diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..a85a2e7 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +FROM zibo.harbor.iluvatar.com.cn:30000/saas/bi100-3.2.1-x86-ubuntu20.04-py3.10-poc-llm-infer:20250731115755 + +RUN pip install --no-cache-dir triton==2.1.0 + +COPY pkgs/triton /usr/local/corex/lib64/python3/dist-packages/triton +COPY pkgs/triton-2.1.0+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/triton-2.1.0+corex.4.1.2.dist-info +COPY pkgs/xformers-0.0.22+corex.4.1.2.dist-info /usr/local/corex/lib64/python3/dist-packages/xformers-0.0.22+corex.4.1.2.dist-info +COPY pkgs/xformers /usr/local/corex/lib64/python3/dist-packages/xformers + +COPY paged_attn.py /usr/local/lib/python3.10/site-packages/vllm/attention/ops/paged_attn.py +COPY __init__.py /usr/local/lib/python3.10/site-packages/vllm/triton_utils/__init__.py +COPY prefix_prefill.py /usr/local/lib/python3.10/site-packages/vllm/attention/ops/prefix_prefill.py + +RUN mkdir /workspace +WORKDIR /workspace/ + +COPY ./launch_service /workspace/launch_service + +ENTRYPOINT ["./launch_service"] diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..4e19581 --- /dev/null +++ b/__init__.py @@ -0,0 +1,9 @@ +from vllm.triton_utils.importing import HAS_TRITON + +__all__ = ["HAS_TRITON"] + +#from vllm.triton_utils.custom_cache_manager import ( +# maybe_set_triton_cache_manager) +#from vllm.triton_utils.libentry import libentry + +__all__ += ["maybe_set_triton_cache_manager", "libentry"] diff --git a/attention.py b/attention.py new file mode 100644 index 0000000..0d10945 --- /dev/null +++ b/attention.py @@ -0,0 +1,542 @@ +"""Multi-head attention.""" +import os +enable_infer_paged_attn = os.getenv("ENABLE_INFER_PAGED_ATTN",None) +from typing import List, Optional + +import importlib +import torch +import torch.nn as nn +from ixformer.contrib.xformers import ops as xops +from ixformer.contrib.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) + +from vllm._C import ops +from vllm._C import cache_ops +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( + context_attention_fwd) +from vllm.utils import is_hip + +# _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +# # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +# _PARTITION_SIZE = 512 +_SUPPORTED_HEAD_SIZES = [64, 128, 256] +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 256 + + +class PagedAttention(nn.Module): + """MHA/MQA/GQA layer with PagedAttention. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Reshape and store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention using either + xformers or the PagedAttention custom op. + 3. Return the output tensor. 
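+
+    A minimal usage sketch (illustrative only; the concrete sizes below are
+    assumptions, not values taken from this patch):
+
+        attn = PagedAttention(num_heads=32, head_size=128,
+                              scale=128 ** -0.5, num_kv_heads=8)
+        out = attn(query, key, value, key_cache, value_cache, input_metadata)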
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + self.use_ref_attention = self.check_use_ref_attention() + + # TODO align vllm do not need those + self.attn_op = xops.fmha.flash.FwOp() + head_mapping = torch.repeat_interleave( + torch.arange(self.num_kv_heads, dtype=torch.int32), + self.num_queries_per_kv) + self.register_buffer("head_mapping", head_mapping, persistent=False) + + def check_use_ref_attention(self) -> bool: + if not is_hip(): + return False + # For ROCm, check whether flash attention is installed or not. + # if not, use_ref_attention needs to be True + return importlib.util.find_spec("flash_attn") is None + + def ref_masked_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query, + key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + """PagedAttention forward pass. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for the inputs. + cache_event: event to wait for the cache operations to finish. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + slot_mapping = input_metadata.slot_mapping + + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. 
+ if key_cache is not None and value_cache is not None: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + ) + + if input_metadata.is_prompt: + # normal attention + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens) + input_metadata.attn_bias = attn_bias + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(num_tokens, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=self.attn_op, + alibi_slopes=self.alibi_slopes + ) + output = out.view_as(query) + else: + # prefix-enabled attention + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) + else: + # Decoding run. + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.head_mapping, # self.num_kv_heads + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) + # TODO align + """ + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + PagedAttention forward pass. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for the inputs. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. 
+ if key_cache is not None and value_cache is not None: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) + + if input_metadata.is_prompt: + # normal attention + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(batch_size, seq_len, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) + else: + # prefix-enabled attention + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) + + else: + # Decoding run. + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(batch_size, seq_len, hidden_size) + """ + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, +) -> LowerTriangularMaskWithTensorBias: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. 
We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + # When using custom attention bias, xformers requires the bias to + # be sliced from a tensor whose length is a multiple of 8. + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + batch_size, + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + return attn_bias + + +def _paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + head_mapping: torch.Tensor, # num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + use_sqrt_alibi: bool = False +) -> torch.Tensor: + output = torch.empty_like(query) + + use_v2 = enable_infer_paged_attn is None and key_cache.dim() == 4 + if not use_v2: + block_size = value_cache.shape[3] + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, # num_kv_heads + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + else: + # Run PagedAttention V2. + block_size = value_cache.shape[2] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, # num_kv_heads + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + return output + + +# ↓ add for smoothquant +class DequantPagedAttention(PagedAttention): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + quant_kv_cache: bool = False, + kv_quant_params: torch.Tensor = None, + quant_scale: float = 1.0, + use_per_token_quant: bool = True, + ) -> None: + super().__init__(num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window) + self.register_parameter( + "quant_scale", + torch.nn.Parameter( + torch.tensor(quant_scale, dtype=torch.float32,requires_grad=False)) + ) + self.use_per_token_quant = use_per_token_quant + + def _apply(self, fn): + super()._apply(fn) + self.quant_scale.data = self.quant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.quant_scale.data = self.quant_scale.to(*args, **kwargs) + self.quant_scale.data = self.quant_scale.to(torch.float32) + return self + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], 
+ input_metadata: InputMetadata, + ) -> torch.Tensor: + out = super().forward( + query, + key, + value, + key_cache, + value_cache, + input_metadata, + ) + quant_out = torch.empty_like(out, dtype=torch.int8) + if self.use_per_token_quant: + scale = torch.empty(out.numel() // out.shape[-1], + dtype=torch.float32, + device=out.device) + ops.quant(quant_out, out, scale) + return quant_out, scale + else: + ops.quant(quant_out, out, self.quant_scale.item()) + return (quant_out, ) diff --git a/launch_service b/launch_service new file mode 100755 index 0000000..f1571ed --- /dev/null +++ b/launch_service @@ -0,0 +1,79 @@ +#!/bin/bash + +export PYTHONPATH=/usr/local/corex/lib64/python3/dist-packages +export LD_LIBRARY_PATH=/usr/local/corex/lib64:/usr/local/openmpi/lib +export PATH=/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin +export JAVA_HOME=/root/apps/jdk1.8.0_411 +export JRE_HOME=/root/apps/jdk1.8.0_411/jre +export JMETER_HOME=/root/apps/apache-jmeter-5.6.3 +export CLASSPATH=.:/root/apps/jdk1.8.0_411/lib/dt.jar:/root/apps/jdk1.8.0_411/lib/tools.jar:/root/apps/apache-jmeter-5.6.3/lib/ext/ApacheJMeter_core.jar:/root/apps/apache-jmeter-5.6.3/lib/jorphan.jar:/root/apps/apache-jmeter-5.6.3/lib/logkit-2.0.jar: +export PATH=/root/apps/apache-jmeter-5.6.3/bin:/root/apps/jdk1.8.0_411/bin:/usr/local/corex/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/local/corex/lib64/python3/dist-packages/bin:/usr/local/openmpi/bin +/iluvatar/welcome.sh + +data +cat /proc/cpuinfo | tail -n 50 +ixsmi +unset CUDA_VISIBLE_DEVICES +export +date + +DEFAULT_HOST="0.0.0.0" +DEFAULT_PORT="80" +DEFAULT_SERVED_MODEL_NAME="llm" +DEFAULT_MODEL_PATH="/model" +DEFAULT_MAX_MODEL_LEN="10000" +DEFAULT_TENSOR_PARALLEL_SIZE="1" +DEFAULT_MAX_NUM_SEQS="64" +DEFAULT_ENFORCE_EAGER="true" +DEFAULT_DISABLE_LOG_REQUESTS="true" +DEFAULT_PREFIX_CACHING="true" + +HOST_VAL=${HOST:-$DEFAULT_HOST} +PORT_VAL=${PORT:-$DEFAULT_PORT} +SERVED_MODEL_NAME_VAL=${SERVED_MODEL_NAME:-$DEFAULT_SERVED_MODEL_NAME} +MODEL_PATH_VAL=${MODEL_PATH:-$DEFAULT_MODEL_PATH} +MAX_MODEL_LEN_VAL=${MAX_MODEL_LEN:-$DEFAULT_MAX_MODEL_LEN} +TENSOR_PARALLEL_SIZE_VAL=${TENSOR_PARALLEL_SIZE:-$DEFAULT_TENSOR_PARALLEL_SIZE} +MAX_NUM_SEQS_VAL=${MAX_NUM_SEQS:-$DEFAULT_MAX_NUM_SEQS} +INCLUDE_ENFORCE_EAGER_FLAG=${ENFORCE_EAGER:-$DEFAULT_ENFORCE_EAGER} +INCLUDE_DISABLE_LOG_REQUESTS_FLAG=${DISABLE_LOG_REQUESTS:-$DEFAULT_DISABLE_LOG_REQUESTS} +INCLUDE_PREFIX_CACHING_FLAG=${PREFIX_CACHING:-$DEFAULT_PREFIX_CACHING} + +CMD_ARGS=() +CMD_ARGS+=(--host "$HOST_VAL") +CMD_ARGS+=(--port "$PORT_VAL") + +if [[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]]; then + CMD_ARGS+=(--enforce-eager) +fi +if [[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]]; then + CMD_ARGS+=(--disable-log-requests) +fi +if [[ "$INCLUDE_PREFIX_CACHING_FLAG" != "false" && "$INCLUDE_PREFIX_CACHING_FLAG" != "0" ]]; then + CMD_ARGS+=(--enable-prefix-caching) +fi + +CMD_ARGS+=(--served-model-name "$SERVED_MODEL_NAME_VAL") +CMD_ARGS+=(--model "$MODEL_PATH_VAL") +CMD_ARGS+=(--max-model-len "$MAX_MODEL_LEN_VAL") +CMD_ARGS+=(--tensor-parallel-size "$TENSOR_PARALLEL_SIZE_VAL") +CMD_ARGS+=(--max-num-seqs "$MAX_NUM_SEQS_VAL") + +echo "--------------------------------------------------" +echo "Starting VLLM OpenAI API Server..." 
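+# All of the values above can be overridden via the environment; an
+# illustrative container invocation (image name is a placeholder):
+#   docker run -e MODEL_PATH=/models/my-model -e TENSOR_PARALLEL_SIZE=2 -e PORT=8000 <image>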
+echo "Using effective arguments:" +echo " Host (--host): $HOST_VAL" +echo " Port (--port): $PORT_VAL" +echo " Enforce Eager (--enforce-eager):" $([[ "$INCLUDE_ENFORCE_EAGER_FLAG" != "false" && "$INCLUDE_ENFORCE_EAGER_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: ENFORCE_EAGER=$ENFORCE_EAGER)") +echo " Disable Log Req (--disable-log-requests):" $([[ "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "false" && "$INCLUDE_DISABLE_LOG_REQUESTS_FLAG" != "0" ]] && echo "Enabled" || echo "Disabled (Env: DISABLE_LOG_REQUESTS=$DISABLE_LOG_REQUESTS)") +echo " Served Model Name (--served-model-name): $SERVED_MODEL_NAME_VAL" +echo " Model Path (--model): $MODEL_PATH_VAL" +echo " Max Model Length (--max-model-len): $MAX_MODEL_LEN_VAL" +echo " Tensor Parallel Size (--tensor-parallel-size): $TENSOR_PARALLEL_SIZE_VAL" +echo " Max Num Seqs (--max-num-seqs): $MAX_NUM_SEQS_VAL" +echo "--------------------------------------------------" +echo "Full cmd:" +echo "python3 -m vllm.entrypoints.openai.api_server ${CMD_ARGS[*]}" +echo "--------------------------------------------------" + +python3 -m vllm.entrypoints.openai.api_server "${CMD_ARGS[@]}" diff --git a/paged_attn.py b/paged_attn.py new file mode 100644 index 0000000..1741dd1 --- /dev/null +++ b/paged_attn.py @@ -0,0 +1,242 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. 
+ block_tables: Optional[torch.Tensor] + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_seq_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + use_v1 = True + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + ) + else: + # Run PagedAttention V2. 
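+            # V2 splits the context into _PARTITION_SIZE-token partitions,
+            # stores per-partition partial results in tmp_output along with
+            # exp_sums/max_logits, and reduces them across partitions into
+            # `output`.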
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: float, + v_scale: float, + ) -> torch.Tensor: + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], + seq_lens_tensor, + context_lens, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/INSTALLER b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/METADATA b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/METADATA new file mode 100644 index 0000000..ee61ccf --- /dev/null +++ b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/METADATA @@ -0,0 +1,33 @@ +Metadata-Version: 2.1 +Name: triton +Version: 2.1.0+corex.4.1.2 +Summary: A language and compiler for custom Deep Learning operations +Home-page: https://github.com/openai/triton/ +Author: Philippe Tillet +Author-email: phil@openai.com +Keywords: Compiler,Deep Learning +Classifier: Development Status :: 4 - Beta +Classifier: Intended Audience :: Developers +Classifier: Topic :: Software Development :: Build Tools +Classifier: License :: OSI Approved :: MIT License +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Requires-Dist: filelock +Provides-Extra: build +Requires-Dist: cmake>=3.18; extra == 
"build" +Requires-Dist: lit; extra == "build" +Provides-Extra: tests +Requires-Dist: autopep8; extra == "tests" +Requires-Dist: flake8; extra == "tests" +Requires-Dist: isort; extra == "tests" +Requires-Dist: numpy; extra == "tests" +Requires-Dist: pytest; extra == "tests" +Requires-Dist: scipy>=1.7.1; extra == "tests" +Provides-Extra: tutorials +Requires-Dist: matplotlib; extra == "tutorials" +Requires-Dist: pandas; extra == "tutorials" +Requires-Dist: tabulate; extra == "tutorials" + diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/RECORD b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/RECORD new file mode 100644 index 0000000..920db50 --- /dev/null +++ b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/RECORD @@ -0,0 +1,156 @@ +triton-2.1.0+corex.4.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +triton-2.1.0+corex.4.1.2.dist-info/METADATA,sha256=WsBBXGUNH3GMUl-Sr8PVEA9wKr8pKyeQQ7BeQEq4Moc,1271 +triton-2.1.0+corex.4.1.2.dist-info/RECORD,, +triton-2.1.0+corex.4.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +triton-2.1.0+corex.4.1.2.dist-info/WHEEL,sha256=jkyzK-PMDGe6NpviSR1tpgon24csQEXDXqxhDgOHeYM,105 +triton-2.1.0+corex.4.1.2.dist-info/direct_url.json,sha256=rC0cgRYuAV96ZhqohICZYGcWgh85fJNAUbExKOBJuaE,284 +triton-2.1.0+corex.4.1.2.dist-info/top_level.txt,sha256=g7iYLuhGmQFd2TwONpkqhhg7KJl6GMjbwGZvSp3_tnE,275 +triton/_C/libtriton.so,sha256=TabhXpykxBhWhTTiPjVzF4qoFgKy2B5-GV04hVDF1HY,256934256 +triton/__init__.py,sha256=J6RjARnKOLfAU59wgewEy82ZoICORUQe_mfHWHQ87lY,1267 +triton/__pycache__/__init__.cpython-310.pyc,, +triton/__pycache__/testing.cpython-310.pyc,, +triton/common/__init__.py,sha256=uG0IrQy1GiwfCkg9FHqVcYHQ9dspmDtma8SSdFqKXew,48 +triton/common/__pycache__/__init__.cpython-310.pyc,, +triton/common/__pycache__/build.cpython-310.pyc,, +triton/common/build.py,sha256=Ssz_enawYCeVbyHsmJRk9e81dvlFCIWNKikWcWJjZBo,4609 +triton/compiler/__init__.py,sha256=VGVviF6fZSXDMUrcO9o0H96j9gQXcYDjnJx8TG18xhU,144 +triton/compiler/__pycache__/__init__.cpython-310.pyc,, +triton/compiler/__pycache__/code_generator.cpython-310.pyc,, +triton/compiler/__pycache__/compiler.cpython-310.pyc,, +triton/compiler/__pycache__/errors.cpython-310.pyc,, +triton/compiler/__pycache__/make_launcher.cpython-310.pyc,, +triton/compiler/code_generator.py,sha256=2T0mFAy4TUOh1v2Mh3csLtztT9Aky0zNH-fjfLqx5W8,49616 +triton/compiler/compiler.py,sha256=0_H9l2n4MlHKA9zfploeC9nF-oWC8sw5oUOvgGirRwU,23506 +triton/compiler/errors.py,sha256=PiquMxHuHayRvdH3hMXSQlOnDswK3TGNWENB29YX_FU,1666 +triton/compiler/make_launcher.py,sha256=dhi2y8cTVkaFUAFfH-s7gBq40ukhq18_Rl674lsS408,12458 +triton/debugger/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +triton/debugger/__pycache__/__init__.cpython-310.pyc,, +triton/debugger/__pycache__/core.cpython-310.pyc,, +triton/debugger/__pycache__/debugger.cpython-310.pyc,, +triton/debugger/__pycache__/memory_map.cpython-310.pyc,, +triton/debugger/__pycache__/tl_lang.cpython-310.pyc,, +triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc,, +triton/debugger/core.py,sha256=zjYVVAleV3EdBMOhMUR-c1X9H3613fu1ttOsG5n0VgQ,150 +triton/debugger/debugger.py,sha256=-G9B4apnYC1woN-dTwdBkk4EdS-bavFnJ763_kvUr6o,5399 +triton/debugger/memory_map.py,sha256=7xvtTJsJoGYqX4bwsJJyxJvl8H5YOnA3nnqwF5zifi8,3165 +triton/debugger/tl_lang.py,sha256=P2ZZa9QNkLxp63PY7kSdTrEWU4mF5zzB3hsBIm1Xum0,17980 +triton/debugger/torch_wrapper.py,sha256=iD0qbPnsR-rMG0anMJRoEhIDvQaVzFmBuFTKkPmnj_g,353 
+triton/language/__init__.py,sha256=lLiZyE7r2PJ3vQcUfw2EAkFaBJHLzG4YRUlJKGlOxrk,2884 +triton/language/__pycache__/__init__.cpython-310.pyc,, +triton/language/__pycache__/core.cpython-310.pyc,, +triton/language/__pycache__/math.cpython-310.pyc,, +triton/language/__pycache__/random.cpython-310.pyc,, +triton/language/__pycache__/semantic.cpython-310.pyc,, +triton/language/__pycache__/standard.cpython-310.pyc,, +triton/language/core.py,sha256=cxhJ3qG4h2T85FUuaVdDTAdB7eXkTHzmaerFZd57cow,54553 +triton/language/extra/__init__.py,sha256=zyhlj6Mo9_KD7E5szGxI18Ko3b31E00-s1ybOM5KzkQ,39 +triton/language/extra/__pycache__/__init__.cpython-310.pyc,, +triton/language/extra/__pycache__/cuda.cpython-310.pyc,, +triton/language/extra/cuda.bc,sha256=03xfoq03DK6PfMS2fm05V90ei5jlnDdZaD6kZffQEas,1808 +triton/language/extra/cuda.py,sha256=7UCCXfkwgLPpl87FRZsXyrHN6tt0xRoQeeBksSMdgyE,641 +triton/language/math.py,sha256=g1z24oHGwyfZZNPkp9OUQ0_nXlnGEQCwAMcKKrRt1kE,75450 +triton/language/random.py,sha256=yxzlT8jCKAv0yo4aVcWL5kBj4OOMyRcpbuGXMUJw0E0,5630 +triton/language/semantic.py,sha256=hmyErm5bWEvlt2NgWPzubVUNBh4Bbfrb5BJ6y42VScM,62939 +triton/language/standard.py,sha256=jpc-ha22ylEHusr-f4OUSLo7iuso3yF4cTjOTpSMqo0,2320 +triton/ops/__init__.py,sha256=KM696wNNUi4gHanlZo54WcfmYnRWkykHKGnuhRoqAgQ,370 +triton/ops/__pycache__/__init__.cpython-310.pyc,, +triton/ops/__pycache__/bmm_matmul.cpython-310.pyc,, +triton/ops/__pycache__/cross_entropy.cpython-310.pyc,, +triton/ops/__pycache__/flash_attention.cpython-310.pyc,, +triton/ops/__pycache__/matmul.cpython-310.pyc,, +triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc,, +triton/ops/blocksparse/__init__.py,sha256=6YEVQNzipgQCpoO_7B8H7ckaSW2Idt1244s7IyLWAwc,100 +triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc,, +triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc,, +triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc,, +triton/ops/blocksparse/matmul.py,sha256=uBOKf2pS5fFtGokQvrWS0GrVHNipIrsTZ_-vPr5lfOY,15617 +triton/ops/blocksparse/softmax.py,sha256=eFvPebQZ__8L4pCmks-PsqpCG2pgheRL2f0Z4WTKrts,7893 +triton/ops/bmm_matmul.py,sha256=grVD0Z800EWMC35TtZZoGgQihZqRUfnjJ9PfOpJmhfA,6381 +triton/ops/cross_entropy.py,sha256=NRZdHy8nh70E0B_7BIHc-3-rkQF5Nb9nBnbfxfjHUAs,3457 +triton/ops/flash_attention.py,sha256=AFjILTHWCkYXHDNkEBdrcsYIQ4OmYvYsKh7G2StXfH8,10335 +triton/ops/matmul.py,sha256=bMoTeSnGBrR5oeSgtUr6syeLNPepydVpy1l9Ts-5E-g,8447 +triton/ops/matmul_perf_model.py,sha256=KpYV6bJ9dRqg2js-1vLm_uPQyOytbUT8jaSVeleIXmI,6672 +triton/runtime/__init__.py,sha256=4N2bT8rp10BAMHQzF9CiIt276t7XhA22c1uJkNB2ejk,516 +triton/runtime/__pycache__/__init__.cpython-310.pyc,, +triton/runtime/__pycache__/autotuner.cpython-310.pyc,, +triton/runtime/__pycache__/cache.cpython-310.pyc,, +triton/runtime/__pycache__/driver.cpython-310.pyc,, +triton/runtime/__pycache__/errors.cpython-310.pyc,, +triton/runtime/__pycache__/jit.cpython-310.pyc,, +triton/runtime/autotuner.py,sha256=CYq8EkWf2mzAgubAvHheLtpsCYbrEdu7BbwKxYIhips,12718 +triton/runtime/backends/cuda.c,sha256=5s9niGGe1tCvt1AW3O66jqiYQtpHQM9aUlKXppnTvO0,4680 +triton/runtime/backends/hip.c,sha256=L37Uz-DsX5ntV_nI3OknmtB_xuufIWrfZAvW4SafFTA,4005 +triton/runtime/cache.py,sha256=kpcFPdEbH3cO60HVdzhYtnOuIycXJnbDqWhbuF5hkFQ,4190 +triton/runtime/driver.py,sha256=UIsodifreXt2d9NLErHR7wu7Fbr4nCpkhfOtYniMhnE,5100 +triton/runtime/errors.py,sha256=XuA6URwCy4e3iYYbX9k35X89wnoasp7_K1LAPvsTbMQ,591 +triton/runtime/jit.py,sha256=9LQhkSlhcKoWzx_x6E9E51eByfAj1U7yjwAyj-Mh57g,21468 
+triton/testing.py,sha256=akS7bFcE0LVMi56WCZSt7Wu4CfkRGsasJ9_tL1Zj3Vo,15520 +triton/third_party/cuda/bin/ptxas,sha256=65yHgyKTqojU5uWOzQ134l3rk9pcGou5HefXzVsCgaw,20495600 +triton/third_party/cuda/include/cuda.h,sha256=yCBNCpZt_qVabzyVattc2qrEq36p5MeL3kDzvNdeV4A,790319 +triton/third_party/cuda/lib/libdevice.10.bc,sha256=XC-uN8huaMOjhgWpX1EtfRLV89uYYxC-R_VzBKpype4,473728 +triton/third_party/rocm/lib/bitcode/asanrtl.bc,sha256=cHu7_b3BolfsDJMlWnGTY71qI6fNNdSVPDMFIojOh9k,24708 +triton/third_party/rocm/lib/bitcode/cuda2gcn.bc,sha256=UDl-zwwWB5Ad__lIlhReRQBx61m0nMORxdppdTxkjQo,41568 +triton/third_party/rocm/lib/bitcode/cuda2gcn.patch,sha256=EuHo5MOHjwbqdnuv53532LaFAuRh3C3J1QfRo0B3NJM,600 +triton/third_party/rocm/lib/bitcode/hip.bc,sha256=yttUFcAzSE3ydohSCvOlTp5QZ3SjSH4juoSOYXz_zGg,2372 +triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh,sha256=9jMtbtlZBIY6yyEp8V12-6Nw4WahdrCPFyFnRb1kpls,348 +triton/third_party/rocm/lib/bitcode/ockl.bc,sha256=S3XZYIoykPyJ8XOFeVdqyRcdqxoBu50-AaU8oTdAhS8,227392 +triton/third_party/rocm/lib/bitcode/oclc_abi_version_400.bc,sha256=K3TaLRC6uoHXTMK6WS7NgLJzjXeKKfvf_8a6lqfNwAs,1920 +triton/third_party/rocm/lib/bitcode/oclc_abi_version_500.bc,sha256=9e6Z5f4lm_5asnMSHuPOrrnwlV6jlStgrwYTEiOU5tc,1920 +triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_off.bc,sha256=DiRiGcUfL17ukJS40RV8OHCWvEU-Oi4oS5NCn-SRAFY,1936 +triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_on.bc,sha256=cVJuiMvffHLxRqvPmzpbMp8nX7yxtUojn9A-iBfCmaY,1936 +triton/third_party/rocm/lib/bitcode/oclc_daz_opt_off.bc,sha256=hcifwxlJIx5HlOcj_LjOUeqezaZ2v1xzX6LHNSQSkbE,1920 +triton/third_party/rocm/lib/bitcode/oclc_daz_opt_on.bc,sha256=jYa1MyGdVRH6etbC6gicJ8egrwwV4wVHbe4aHRQ27R8,1920 +triton/third_party/rocm/lib/bitcode/oclc_finite_only_off.bc,sha256=gWpVuB2OXig5vPczOQshtvZ_oJO72R5fFRiAYTILhgM,1928 +triton/third_party/rocm/lib/bitcode/oclc_finite_only_on.bc,sha256=z9bPfFcxKqMILV3SsRJUIk6ExpZHak2Asv8FP9Ni9bw,1928 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1010.bc,sha256=0vFumpEzejM0Xz7hSSbEYT-iZCrPrDeTRicpD-H-jXM,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1011.bc,sha256=VqawQ_trHRs7rn9NQEkKQ0gDO1BW73hkl9ZdP5Ixybc,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1012.bc,sha256=2C8uRNtmebw0dHh8Vg5_PuFLiY4x19wcydnS7VvadOo,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1013.bc,sha256=lZfO_8vUqSCn1o3VDALvFN-Cclq3eOYDe7-FthD0Gng,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1030.bc,sha256=WIadAgcRBfWphRmcrQJUYewZXWtSN7npAAMAkNQ0C1w,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1031.bc,sha256=ofxROlV79qeW1SyCo9VmJ-W9CVK7zxZtbWLDA-nKRh0,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1032.bc,sha256=IXOzwf4sdsw453PcColeLVjEQek4H1X3G-IfqvoZfhA,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1033.bc,sha256=4KNBcZx7R7e6CJE-4-Ge3v4xzOIuGOOOdWhpga_txbU,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1034.bc,sha256=nmMcRhEdWZcSzMbO0VCj98VTujkBsswVyQiozF4Pgo4,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1035.bc,sha256=EfSk25ZwCSNNBVOhKoGvkKpPXmW5WSYPTCqD0PCwAB0,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1036.bc,sha256=L7hI21xe_IX3IFyQf_A73k1YXEf95w16rLV1BB6FwQE,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1100.bc,sha256=_UxQflZQXU8FJcg7B6VKJtJJwJRKdDY2qiauHkP6nFs,2096 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1101.bc,sha256=Lb83xl4VE1cQjOxz0gMwWS1_p-sq1wHjs67DpEI2RL0,2096 
+triton/third_party/rocm/lib/bitcode/oclc_isa_version_1102.bc,sha256=Xqalbf-GktJPs_ZOrg54FCCdxzMAZ4U7Ny7HNqh4Blc,2096 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_1103.bc,sha256=KMSqCG4J8tDGDRUoYK97xw-Gj0CeFJN0bYVn7Y8ot70,2096 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_600.bc,sha256=mkocG9cx-o5vwHCT98Fcq4v_fol3RUk24VPaJ2KNiJ0,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_601.bc,sha256=Ib-dn5dosOuXVbBdG59-RERNDk2eMX4vvWUm0Z8Qqpc,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_602.bc,sha256=QH5FGGkLH3fFhPJneGRrucoP44FBTKjm7VMKzGSHeDA,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_700.bc,sha256=ibu033_NCSnqAYqqD70CRVLeym0RFpgm-PX0q-Y2yDs,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_701.bc,sha256=P3d62xvb3_OSmACax3BeobuBAOP04X_0FjXQDKIL7bk,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_702.bc,sha256=WmuD3Kq09F2KQ3u6mFCfE0MoDZWQAfjwDkOHwZpBnsA,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_703.bc,sha256=3_iPbveGJ9fJa1Kvu-uEEGbuRAOrbTMauJGJNFPBvHI,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_704.bc,sha256=0PEze_BCWRcbGBeM_Jaote9xmzsW7hDCwbeuHs55PGo,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_705.bc,sha256=QBqWlg1VDWgQPXDvuPzWQ8--WRP3oSyk25nv70yQXMU,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_801.bc,sha256=iQKrQ9rQx2UiET0BQ-1RgibWE4_aouTdPoaTg0gUEZk,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_802.bc,sha256=9euokmFJ4pmx6IfkcBCpnHoRJrHqQ1I_zZMIYY7j7Ks,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_803.bc,sha256=-v7D1qhgp4YwsFgYL2R3gN4cN0vR-6RGxMGxMm7RpKM,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_805.bc,sha256=MalaxEN85FaTuJ9CUahDNKaCR2CZj81YNzVeXOHsDOA,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_810.bc,sha256=8Dbv2N9xQB93CzJ0DWe1YLpstI6SpEH71XABjEJu7BQ,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_900.bc,sha256=j3tpNoOQGtuHd4isYNNPdtYVGSmeXbIiHfLgPRioS8Q,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_902.bc,sha256=3SlZ_VGBX9y8cRm9hF3XdtadPFJFLsIaoOnGdnqtdFY,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_904.bc,sha256=OLEXkdaKEefZgzsommw6c5378MDasYKH_PSkGoU7hGU,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_906.bc,sha256=KeOq_NoSwkh3D-1eYbzEvYcLqx-a4XHlMib3GQhtoFY,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_908.bc,sha256=PFAdOwLxneisb1dnPMU8nkm8kslmNuCGWbVFOP9bgj0,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_909.bc,sha256=J4Cwry1PZErUfvaYu4ZNCct1NQMDle0mz84ycbKJp6g,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_90a.bc,sha256=ojZ7tLg2N7Hq1omZIrCkBRzia0bYtOdGSU4CmkkdfhE,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_90c.bc,sha256=6mVoOg82NG_NfT2Su6fB18DFgRBRtOSQnRL2QgjRdHY,1920 +triton/third_party/rocm/lib/bitcode/oclc_isa_version_940.bc,sha256=Z5ahEe7z2mwBgxuBySEOG7-rRRdpVACwawch5PE6s2I,1920 +triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_off.bc,sha256=evf88YMj0hBicj04zDdTM0Xlu4KG9fpUlj3iGS0F87A,1928 +triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_on.bc,sha256=bYF53Lu769UGewikZaQMfoxOerQBGs3SfLpQXoj3zVk,1928 +triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_off.bc,sha256=XlcwAnD8EeEHQ6FksQG3-Rxc07zMt7aFXVd-xmGBeDo,1928 +triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_on.bc,sha256=s7t5si7PLY1kIAs0oJGY-9uSL-lMKtcXdGZqzXyFteI,1928 +triton/third_party/rocm/lib/bitcode/ocml.bc,sha256=bWAX1FuKc-cJrmLlC6kYhvs5fzRoBbc0Y7BP5qgg0BA,192228 
+triton/third_party/rocm/lib/bitcode/opencl.bc,sha256=mL2MGdVis93KLQWVYYj0pT6d8iS99DKlbnEe-1VGdss,2841028 +triton/tools/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +triton/tools/__pycache__/__init__.cpython-310.pyc,, +triton/tools/__pycache__/aot.cpython-310.pyc,, +triton/tools/__pycache__/build_extern.cpython-310.pyc,, +triton/tools/__pycache__/disasm.cpython-310.pyc,, +triton/tools/aot.py,sha256=DUb9edluQKqfvY6aeIyQk2dN1Ljv3i6GYystyObqc8o,4752 +triton/tools/build_extern.py,sha256=t5vdo6a3bcFfawMQBnTderw7Xi-L49x4udomoOWRUtw,14661 +triton/tools/disasm.py,sha256=FOtqCFjUedLoZtb1hektBD32L1YxC43qnWKZr6VXc5E,4593 diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/REQUESTED b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/REQUESTED new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/WHEEL b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/WHEEL new file mode 100644 index 0000000..591b31a --- /dev/null +++ b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.44.0) +Root-Is-Purelib: false +Tag: cp310-cp310-linux_x86_64 + diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/direct_url.json b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/direct_url.json new file mode 100644 index 0000000..987170b --- /dev/null +++ b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/direct_url.json @@ -0,0 +1 @@ +{"archive_info": {"hash": "sha256=0b2734afe117d3c56cda0a7cd1f4752510dc91d8678638321dcbaee07110da40", "hashes": {"sha256": "0b2734afe117d3c56cda0a7cd1f4752510dc91d8678638321dcbaee07110da40"}}, "url": "file:///usr/local/apps_whl/triton-2.1.0%2Bcorex.4.1.2-cp310-cp310-linux_x86_64.whl"} \ No newline at end of file diff --git a/pkgs/triton-2.1.0+corex.4.1.2.dist-info/top_level.txt b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/top_level.txt new file mode 100644 index 0000000..25a755c --- /dev/null +++ b/pkgs/triton-2.1.0+corex.4.1.2.dist-info/top_level.txt @@ -0,0 +1,15 @@ +triton +triton/_C +triton/common +triton/compiler +triton/debugger +triton/language +triton/language/extra +triton/ops +triton/ops/blocksparse +triton/runtime +triton/runtime/backends +triton/third_party/cuda/bin +triton/third_party/cuda/include +triton/third_party/cuda/lib +triton/tools diff --git a/pkgs/triton/_C/libtriton.so b/pkgs/triton/_C/libtriton.so new file mode 100755 index 0000000..3871d62 Binary files /dev/null and b/pkgs/triton/_C/libtriton.so differ diff --git a/pkgs/triton/__init__.py b/pkgs/triton/__init__.py new file mode 100644 index 0000000..14c9d61 --- /dev/null +++ b/pkgs/triton/__init__.py @@ -0,0 +1,68 @@ +"""isort:skip_file""" +__version__ = '2.1.0' + +# --------------------------------------- +# Note: import order is significant here. + +# submodules +from .runtime import ( + autotune, + Config, + heuristics, + JITFunction, + KernelInterface, + reinterpret, + TensorWrapper, + OutOfResources, + MockTensor, +) +from .runtime.jit import jit +from .compiler import compile, CompilationError +from .debugger.debugger import program_ids_from_grid + +from . import language +from . import testing + +__all__ = [ + "autotune", + "cdiv", + "CompilationError", + "compile", + "Config", + "heuristics", + "impl", + "jit", + "JITFunction", + "KernelInterface", + "language", + "MockTensor", + "next_power_of_2", + "ops", + "OutOfResources", + "reinterpret", + "runtime", + "TensorWrapper", + "testing", + "program_ids_from_grid", +] + + +# ------------------------------------- +# misc. 
utilities that don't fit well +# into any specific module +# ------------------------------------- + +def cdiv(x, y): + return (x + y - 1) // y + + +def next_power_of_2(n): + """Return the smallest power of 2 greater than or equal to n""" + n -= 1 + n |= n >> 1 + n |= n >> 2 + n |= n >> 4 + n |= n >> 8 + n |= n >> 16 + n += 1 + return n diff --git a/pkgs/triton/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..973085a Binary files /dev/null and b/pkgs/triton/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/__pycache__/testing.cpython-310.pyc b/pkgs/triton/__pycache__/testing.cpython-310.pyc new file mode 100644 index 0000000..de983fd Binary files /dev/null and b/pkgs/triton/__pycache__/testing.cpython-310.pyc differ diff --git a/pkgs/triton/common/__init__.py b/pkgs/triton/common/__init__.py new file mode 100644 index 0000000..cc4d1e1 --- /dev/null +++ b/pkgs/triton/common/__init__.py @@ -0,0 +1,3 @@ +from .build import _build + +__all__ = ["_build"] diff --git a/pkgs/triton/common/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/common/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..5b68b30 Binary files /dev/null and b/pkgs/triton/common/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/common/__pycache__/build.cpython-310.pyc b/pkgs/triton/common/__pycache__/build.cpython-310.pyc new file mode 100644 index 0000000..bc37eb4 Binary files /dev/null and b/pkgs/triton/common/__pycache__/build.cpython-310.pyc differ diff --git a/pkgs/triton/common/build.py b/pkgs/triton/common/build.py new file mode 100644 index 0000000..d1659e6 --- /dev/null +++ b/pkgs/triton/common/build.py @@ -0,0 +1,137 @@ +import contextlib +import functools +import io +import os +import shutil +import subprocess +import sys +import sysconfig + +import setuptools + + +# TODO: is_hip shouldn't be here +def is_hip(): + import torch + return torch.version.hip is not None + + +def is_corex(): + import torch + return hasattr(torch, "corex") and torch.corex == True + + +@functools.lru_cache() +def cuda_home_dirs(): + loc = subprocess.check_output(["whereis", "clang++"]).decode().strip().split()[1] + default_dir = os.path.dirname(os.path.dirname(loc)) + return os.getenv("CUDA_HOME", default=default_dir) + + +@functools.lru_cache() +def libcuda_dirs(): + locs = subprocess.check_output(["whereis", "libcuda.so"]).decode().strip().split()[1:] + return [os.path.dirname(loc) for loc in locs] + + +@functools.lru_cache() +def rocm_path_dir(): + return os.getenv("ROCM_PATH", default="/opt/rocm") + + +@contextlib.contextmanager +def quiet(): + old_stdout, old_stderr = sys.stdout, sys.stderr + sys.stdout, sys.stderr = io.StringIO(), io.StringIO() + try: + yield + finally: + sys.stdout, sys.stderr = old_stdout, old_stderr + + +def _build(name, src, srcdir): + if is_hip(): + hip_lib_dir = os.path.join(rocm_path_dir(), "lib") + hip_include_dir = os.path.join(rocm_path_dir(), "include") + else: + if is_corex(): + cuda_path = cuda_home_dirs() + cu_include_dir = os.path.join(cuda_path, "include") + cuda_lib_dirs = [os.path.join(cuda_path, "lib64")] + else: + cuda_lib_dirs = libcuda_dirs() + base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir) + cuda_path = os.path.join(base_dir, "third_party", "cuda") + + cu_include_dir = os.path.join(cuda_path, "include") + triton_include_dir = os.path.join(os.path.dirname(__file__), "include") + cuda_header = os.path.join(cu_include_dir, "cuda.h") + 
triton_cuda_header = os.path.join(triton_include_dir, "cuda.h") + if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header): + cu_include_dir = triton_include_dir + + suffix = sysconfig.get_config_var('EXT_SUFFIX') + so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix)) + # try to avoid setuptools if possible + cc = os.environ.get("CC") + if cc is None: + # TODO: support more things here. + clang = shutil.which("clang") + gcc = shutil.which("gcc") + if is_corex(): + cc = clang if clang is not None else gcc + else: + cc = gcc if gcc is not None else clang + if cc is None: + raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.") + # This function was renamed and made public in Python 3.10 + if hasattr(sysconfig, 'get_default_scheme'): + scheme = sysconfig.get_default_scheme() + else: + scheme = sysconfig._get_default_scheme() + # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install + # path changes to include 'local'. This change is required to use triton with system-wide python. + if scheme == 'posix_local': + scheme = 'posix_prefix' + py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] + + if is_hip(): + ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so]) + else: + cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so] + cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs] + ret = subprocess.check_call(cc_cmd) + + if ret == 0: + return so + # fallback on setuptools + extra_compile_args = [] + library_dirs = cuda_lib_dirs + include_dirs = [srcdir, cu_include_dir] + libraries = ['cuda'] + # extra arguments + extra_link_args = [] + # create extension module + ext = setuptools.Extension( + name=name, + language='c', + sources=[src], + include_dirs=include_dirs, + extra_compile_args=extra_compile_args + ['-O3'], + extra_link_args=extra_link_args, + library_dirs=library_dirs, + libraries=libraries, + ) + # build extension module + args = ['build_ext'] + args.append('--build-temp=' + srcdir) + args.append('--build-lib=' + srcdir) + args.append('-q') + args = dict( + name=name, + ext_modules=[ext], + script_args=args, + ) + with quiet(): + setuptools.setup(**args) + return so diff --git a/pkgs/triton/compiler/__init__.py b/pkgs/triton/compiler/__init__.py new file mode 100644 index 0000000..4d62eee --- /dev/null +++ b/pkgs/triton/compiler/__init__.py @@ -0,0 +1,4 @@ +from .compiler import CompiledKernel, compile +from .errors import CompilationError + +__all__ = ["compile", "CompiledKernel", "CompilationError"] diff --git a/pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..92d270a Binary files /dev/null and b/pkgs/triton/compiler/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc b/pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc new file mode 100644 index 0000000..7170cd3 Binary files /dev/null and b/pkgs/triton/compiler/__pycache__/code_generator.cpython-310.pyc differ diff --git a/pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc b/pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc new file mode 100644 index 0000000..37e5c0e Binary files /dev/null and 
b/pkgs/triton/compiler/__pycache__/compiler.cpython-310.pyc differ diff --git a/pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc b/pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000..9396672 Binary files /dev/null and b/pkgs/triton/compiler/__pycache__/errors.cpython-310.pyc differ diff --git a/pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc b/pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc new file mode 100644 index 0000000..5952faa Binary files /dev/null and b/pkgs/triton/compiler/__pycache__/make_launcher.cpython-310.pyc differ diff --git a/pkgs/triton/compiler/code_generator.py b/pkgs/triton/compiler/code_generator.py new file mode 100644 index 0000000..9df9e10 --- /dev/null +++ b/pkgs/triton/compiler/code_generator.py @@ -0,0 +1,1133 @@ +import ast +import inspect +import re +import sys +import warnings +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union + +from .. import language +from ..language import constexpr, tensor +# ideally we wouldn't need any runtime component +from ..runtime import JITFunction +from .errors import (CompilationError, CompileTimeAssertionFailure, + UnsupportedLanguageConstruct) +from triton._C.libtriton.triton import ir + + +def mangle_ty(ty): + if ty.is_ptr(): + return 'P' + mangle_ty(ty.element_ty) + if ty.is_int(): + SIGNED = language.dtype.SIGNEDNESS.SIGNED + prefix = 'i' if ty.int_signedness == SIGNED else 'u' + return prefix + str(ty.int_bitwidth) + if ty.is_fp8(): + return 'fp8' + if ty.is_fp16(): + return 'fp16' + if ty.is_bf16(): + return 'bf16' + if ty.is_fp32(): + return 'fp32' + if ty.is_fp64(): + return 'fp64' + if ty.is_block(): + elt = mangle_ty(ty.scalar) + shape = '_'.join(map(str, ty.shape)) + return f'{elt}S{shape}S' + if ty.is_void(): + return 'V' + assert False, "Unsupported type" + + +def mangle_fn(name, arg_tys, constants): + # doesn't mangle ret type, which must be a function of arg tys + mangled_arg_names = '_'.join([mangle_ty(ty) for ty in arg_tys]) + mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) + mangled_constants = mangled_constants.replace('.', '_d_') + mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') + ret = f'{name}__{mangled_arg_names}__{mangled_constants}' + return ret + + +def _is_triton_tensor(o: Any) -> bool: + return isinstance(o, tensor) + + +def _is_constexpr(o: Any) -> bool: + return isinstance(o, constexpr) + + +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + +def _unwrap_if_constexpr(o: Any): + return o.value if isinstance(o, constexpr) else o + + +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}') + + +def _get_fn_file_line(fn): + base_fn = fn + while not isinstance(base_fn, JITFunction): + base_fn = base_fn.fn + file_name = base_fn.fn.__code__.co_filename + lines, begin_line = inspect.getsourcelines(base_fn.fn) + for line in lines: + if line.strip().startswith('@'): + begin_line += 1 + else: + break + return file_name, begin_line + + +_condition_types = {bool, int, type(None)} # Python types accepted for conditionals 
inside kernels + + +class enter_sub_region: + def __init__(self, generator): + self.generator = generator + + def __enter__(self): + # record lscope & local_defs in the parent scope + self.liveins = self.generator.lscope.copy() + self.prev_defs = self.generator.local_defs.copy() + self.generator.local_defs = {} + self.insert_block = self.generator.builder.get_insertion_block() + self.insert_point = self.generator.builder.get_insertion_point() + return self.liveins, self.insert_block + + def __exit__(self, *args, **kwargs): + self.generator.builder.restore_insertion_point(self.insert_point) + self.generator.lscope = self.liveins + self.generator.local_defs = self.prev_defs + + +# Check if the given syntax node has an "early" return +class ContainsReturnChecker(ast.NodeVisitor): + def __init__(self, gscope): + self.gscope = gscope + + def _visit_stmts(self, body) -> bool: + for s in body: + if self.visit(s): + return True + return False + + def _visit_function(self, fn) -> bool: + # Currently we only support JITFunctions defined in the global scope + if isinstance(fn, JITFunction) and not fn.noinline: + fn_node = fn.parse() + return ContainsReturnChecker(self.gscope).visit(fn_node) + return False + + def generic_visit(self, node) -> bool: + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.visit(item) + elif isinstance(value, ast.AST): + ret = ret or self.visit(value) + return ret + + def visit_Attribute(self, node: ast.Attribute) -> bool: + # If the left part is a name, it's possible that + # we call triton native function or a jit function from another module. + # If the left part is not a name, it must return a tensor or a constexpr + # whose methods do not contain return statements + # e.g., (tl.load(x)).to(y) + # So we only check if the expressions within value have return or not + if isinstance(node.value, ast.Name): + if node.value.id in self.gscope: + value = self.gscope[node.value.id] + fn = getattr(value, node.attr) + return self._visit_function(fn) + return False + return self.visit(node.value) + + def visit_Name(self, node: ast.Name) -> bool: + if type(node.ctx) == ast.Store: + return False + if node.id in self.gscope: + fn = self.gscope[node.id] + return self._visit_function(fn) + return False + + def visit_Return(self, node: ast.Return) -> bool: + return True + + def visit_Assign(self, node: ast.Assign) -> bool: + # There couldn't be an early return + # x = ... + return False + + def visit_AugAssign(self, node: ast.AugAssign) -> bool: + # There couldn't be an early return + # x += ... 
+ return False + + def visit_Module(self, node: ast.Module) -> bool: + return self._visit_stmts(node.body) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> bool: + return self._visit_stmts(node.body) + + def visit_If(self, node: ast.If) -> bool: + # TODO: optimize the following case in which we actually don't have + # a return when static_cond is false: + # if dynamic_cond + # if static_cond + # func_with_return + # else + # func_without_return + ret = self._visit_stmts(node.body) + if node.orelse: + ret = ret or self._visit_stmts(node.orelse) + return ret + + def visit_IfExp(self, node: ast.IfExp) -> bool: + return self.visit(node.body) or self.visit(node.orelse) + + def visit_Call(self, node: ast.Call) -> bool: + return self.visit(node.func) + + +class CodeGenerator(ast.NodeVisitor): + def __init__(self, context, prototype, gscope, attributes, constants, function_name, + module=None, is_kernel=False, function_types: Optional[Dict] = None, + debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0): + self.context = context + self.builder = ir.builder(context) + self.file_name = file_name + # node.lineno starts from 1, so we need to subtract 1 + self.begin_line = begin_line - 1 + self.builder.set_loc(file_name, begin_line, 0) + # self.builder.arch = arch + self.module = self.builder.create_module() if module is None else module + self.function_ret_types = {} if function_types is None else function_types + self.prototype = prototype + self.gscope = gscope + self.lscope = dict() + self.attributes = attributes + self.constants = constants + self.function_name = function_name + self.is_kernel = is_kernel + self.last_node = None + self.debug = debug + self.noinline = noinline + self.scf_stack = [] + self.last_ret_type = None + # SSA-construction + # name => language.tensor + self.local_defs: Dict[str, tensor] = {} + self.global_uses: Dict[str, tensor] = {} + self.dereference_name: Callable[[str], Any] = self._define_name_lookup() + + builtin_namespace: Dict[str, Any] = {_.__name__: _ for _ in (range, float, int, isinstance, getattr)} + builtin_namespace.update(( + ('print', language.core.device_print), + ('min', language.minimum), + )) + + def _define_name_lookup(self): + def local_lookup(name: str, absent): + value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally + if value is not absent and name not in self.local_defs: + self.global_uses[name] = value + return value + + absent_marker = object() + + def name_lookup(name: str) -> Any: + absent = absent_marker + for lookup_function in local_lookup, self.gscope.get, self.builtin_namespace.get: + value = lookup_function(name, absent) + if value is not absent: + return value + raise NameError(f'{name} is not defined') + + return name_lookup + + def set_value(self, name: str, + value: Union[tensor, constexpr]) -> None: + ''' This function: + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) + 1. record local defined name (FIXME: should consider control flow) + 2. store tensor in self.lvalue + ''' + self.lscope[name] = value + self.local_defs[name] = value + + def _get_insertion_point_and_loc(self): + # XXX: this is a hack to get the location of the insertion point. 
+ # The insertion point's location could be invalid sometimes, + # so we need to explicitly set the location + loc = self.builder.get_loc() + ip = self.builder.get_insertion_point() + return ip, loc + + def _set_insertion_point_and_loc(self, ip, loc): + self.builder.restore_insertion_point(ip) + self.builder.set_loc(loc) + + # + # AST visitor + # + def visit_compound_statement(self, stmts): + for stmt in stmts: + ret_type = self.visit(stmt) + if ret_type is not None and isinstance(stmt, ast.Return): + self.last_ret_type = ret_type + + def visit_Module(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_List(self, node): + ctx = self.visit(node.ctx) + assert ctx is None + elts = [self.visit(elt) for elt in node.elts] + return elts + + # By design, only non-kernel functions can return + def visit_Return(self, node): + ret_value = self.visit(node.value) + # ret_block = self.builder.create_block() + # post_ret_block = self.builder.create_block() + # self.builder.create_branch(ret_block) + # self.builder.set_insertion_point_to_end(ret_block) + if ret_value is None: + self.builder.ret([]) + ret_ty = None + elif isinstance(ret_value, tuple): + ret_values = [language.core._to_tensor(v, self.builder) for v in ret_value] + ret_types = [v.type for v in ret_values] + self.builder.ret([v.handle for v in ret_values]) + ret_ty = tuple(ret_types) + else: + ret = language.core._to_tensor(ret_value, self.builder) + self.builder.ret([ret.handle]) + ret_ty = ret.type + # self.builder.create_branch(post_ret_block) + # self.builder.set_insertion_point_to_end(post_ret_block) + return ret_ty + + def visit_FunctionDef(self, node): + arg_names, kwarg_names = self.visit(node.args) + # initialize defaults + for i, default_value in enumerate(node.args.defaults): + arg_node = node.args.args[-i - 1] + annotation = arg_node.annotation + name = arg_node.arg + st_target = ast.Name(id=name, ctx=ast.Store()) + if annotation is None: + init_node = ast.Assign(targets=[st_target], value=default_value) + else: + init_node = ast.AnnAssign(target=st_target, value=default_value, annotation=annotation) + self.visit(init_node) + # initialize function + visibility = "public" if self.is_kernel else "private" + fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline) + self.module.push_back(fn) + entry = fn.add_entry_block() + arg_values = [] + idx = 0 + for i, arg_name in enumerate(arg_names): + if i in self.constants: + cst = self.constants[i] + if not _is_constexpr(cst): + cst = constexpr(self.constants[i]) + arg_values.append(cst) + continue + else: + if i in self.attributes: + fn.set_arg_attr(idx, "tt.divisibility", self.attributes[i][1]) + arg_values.append(tensor(fn.args(idx), self.prototype.param_types[idx])) + idx += 1 + + insert_pt = self.builder.get_insertion_block() + for arg_name, arg_value in zip(arg_names, arg_values): + self.set_value(arg_name, arg_value) + self.builder.set_insertion_point_to_start(entry) + # visit function body + self.visit_compound_statement(node.body) + # finalize function + if self.last_ret_type is None: + self.builder.ret([]) + else: + # update return type + if isinstance(self.last_ret_type, tuple): + self.prototype.ret_types = list(self.last_ret_type) + fn.reset_type(self.prototype.to_ir(self.builder)) + else: + self.prototype.ret_types = [self.last_ret_type] + fn.reset_type(self.prototype.to_ir(self.builder)) + if insert_pt: + self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code + 
fn.finalize() + + def visit_arguments(self, node): + arg_names = [] + for arg in node.args: + arg_names += [self.visit(arg)] + kwarg_names = self.visit(node.kwarg) + return arg_names, kwarg_names + + def visit_arg(self, node): + ast.NodeVisitor.generic_visit(self, node) + return node.arg + + def visit_AnnAssign(self, node): + # extract attributes + annotation = self.visit(node.annotation) + target = self.visit(node.target) + value = self.visit(node.value) + # constexpr + if annotation == constexpr: + if target in self.lscope: + raise ValueError(f'{target} is already defined.' + f' constexpr cannot be reassigned.') + if not _is_constexpr(value): + value = constexpr(value) + self.lscope[target] = value + return self.lscope[target] + # default: call visit_Assign + return self.visit_Assign(node) + + def visit_Assign(self, node): + _names = [] + for target in node.targets: + _names += [self.visit(target)] + if len(_names) > 1: + raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple assignment is not supported.") + names = _names[0] + values = self.visit(node.value) + if not isinstance(names, tuple): + names = [names] + if not isinstance(values, tuple): + values = [values] + native_nontensor_types = (language.dtype, ) + for name, value in zip(names, values): + # by default, constexpr are assigned into python variable + value = _unwrap_if_constexpr(value) + if value is not None and \ + not _is_triton_tensor(value) and \ + not isinstance(value, native_nontensor_types): + value = language.core._to_tensor(value, self.builder) + self.set_value(name, value) + + def visit_AugAssign(self, node): + name = node.target.id + lhs = ast.Name(id=name, ctx=ast.Load()) + rhs = ast.BinOp(lhs, node.op, node.value) + assign = ast.Assign(targets=[node.target], value=rhs) + self.visit(assign) + return self.dereference_name(name) + + def visit_Name(self, node): + if type(node.ctx) == ast.Store: + return node.id + return self.dereference_name(node.id) + + def visit_Store(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Load(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Tuple(self, node): + args = [self.visit(x) for x in node.elts] + return tuple(args) + + def _apply_binary_method(self, method_name, lhs, rhs): + # TODO: raise something meaningful if getattr fails below, esp for reverse method + if _is_triton_tensor(lhs): + return getattr(lhs, method_name)(rhs, _builder=self.builder) + if _is_triton_tensor(rhs): + reverse_method_name = re.sub(r"__(.*)__", r"__r\1__", method_name) + return getattr(rhs, reverse_method_name)(lhs, _builder=self.builder) + return getattr(lhs, method_name)(rhs) + + def visit_BinOp(self, node): + lhs = self.visit(node.left) + rhs = self.visit(node.right) + method_name = self._method_name_for_bin_op.get(type(node.op)) + if method_name is None: + raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { + ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__', + ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__', + } + + def visit_then_else_blocks(self, node, liveins, then_block, else_block): + # then block + self.builder.set_insertion_point_to_start(then_block) + 
self.visit_compound_statement(node.body) + then_block = self.builder.get_insertion_block() + then_defs = self.local_defs.copy() + # else block + else_defs = {} + if node.orelse: + self.builder.set_insertion_point_to_start(else_block) + self.lscope = liveins.copy() + self.local_defs = {} + self.visit_compound_statement(node.orelse) + else_defs = self.local_defs.copy() + else_block = self.builder.get_insertion_block() + + # update block arguments + names = [] + ret_types = [] + ir_ret_types = [] + # variables in livein whose value is updated in `if` + for name in liveins: + # check type + for defs, block_name in [(then_defs, 'then'), (else_defs, 'else')]: + if name in defs: + assert defs[name].type == liveins[name].type,\ + f'initial value for `{name}` is of type {liveins[name].type}, '\ + f'but the {block_name} block redefines it as {defs[name].type}' + if name in then_defs or name in else_defs: + names.append(name) + ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) + ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type()) + # variable defined in then but not in else + if name in then_defs and name not in else_defs: + else_defs[name] = liveins[name] + # variable defined in else but not in then + if name in else_defs and name not in then_defs: + then_defs[name] = liveins[name] + # variables that are both in then and else but not in liveins + # TODO: could probably be cleaned up + for name in then_defs.keys() & else_defs.keys(): + if name in names: + continue + then_ty = then_defs[name].type + else_ty = else_defs[name].type + assert then_ty == else_ty,\ + f'mismatched type for {name} between then block ({then_ty}) '\ + f'and else block ({else_ty})' + names.append(name) + ret_types.append(then_ty) + ir_ret_types.append(then_defs[name].handle.get_type()) + + return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types + + def visit_if_top_level(self, cond, node): + has_endif_block = True + with enter_sub_region(self) as sr: + liveins, ip_block = sr + then_block = self.builder.create_block() + else_block = self.builder.create_block() + # create basic-block after conditional + endif_block = self.builder.create_block() + # create branch + self.builder.set_insertion_point_to_end(ip_block) + self.builder.create_cond_branch(cond.handle, then_block, else_block) + # visit then and else blocks + then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # then terminator + self.builder.set_insertion_point_to_end(then_block) + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) + # else terminator + self.builder.set_insertion_point_to_end(else_block) + if not else_block.has_terminator() and has_endif_block: + self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) + + # TODO: refactor + def visit_if_scf(self, cond, node): + with enter_sub_region(self) as sr: + 
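+            # liveins holds the values that are live on entry to the `if`; the
+            # parent insertion block returned by enter_sub_region is not needed here.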
liveins, _ = sr + ip, last_loc = self._get_insertion_point_and_loc() + then_block = self.builder.create_block() + else_block = self.builder.create_block() if node.orelse else None + then_defs, else_defs, then_block, else_block, names, ret_types, _ = \ + self.visit_then_else_blocks(node, liveins, then_block, else_block) + # create if op + self._set_insertion_point_and_loc(ip, last_loc) + if_op = self.builder.create_if_op([ty.to_ir(self.builder) for ty in ret_types], cond.handle, True) + then_block.merge_block_before(if_op.get_then_block()) + self.builder.set_insertion_point_to_end(if_op.get_then_block()) + if len(names) > 0: + self.builder.create_yield_op([then_defs[n].handle for n in names]) + if not node.orelse: + else_block = if_op.get_else_block() + else: + else_block.merge_block_before(if_op.get_else_block()) + self.builder.set_insertion_point_to_end(if_op.get_else_block()) + if len(names) > 0: + self.builder.create_yield_op([else_defs[n].handle for n in names]) + # update values + for i, name in enumerate(names): + new_tensor = language.core.tensor(if_op.get_result(i), ret_types[i]) + self.set_value(name, new_tensor) + + def visit_If(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + contains_return = ContainsReturnChecker(self.gscope).visit(node) + if self.scf_stack and contains_return: + raise UnsupportedLanguageConstruct(None, node, + "Cannot have `return` statements inside `while` or `for` statements in triton") + elif self.scf_stack or not contains_return: + self.visit_if_scf(cond, node) + else: + self.visit_if_top_level(cond, node) + else: + cond = _unwrap_if_constexpr(cond) + if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks + raise UnsupportedLanguageConstruct( + None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) + if cond: + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.orelse) + + def visit_IfExp(self, node): + cond = self.visit(node.test) + if _is_triton_tensor(cond): + cond = cond.to(language.int1, _builder=self.builder) + if _unwrap_if_constexpr(cond): + return self.visit(node.body) + else: + return self.visit(node.orelse) + + def visit_Pass(self, node): + pass + + def visit_Compare(self, node): + if not (len(node.comparators) == 1 and len(node.ops) == 1): + raise UnsupportedLanguageConstruct(None, node, "simultaneous multiple comparison is not supported") + lhs = _unwrap_if_constexpr(self.visit(node.left)) + rhs = _unwrap_if_constexpr(self.visit(node.comparators[0])) + if type(node.ops[0]) == ast.Is: + return constexpr(lhs is rhs) + if type(node.ops[0]) == ast.IsNot: + return constexpr(lhs is not rhs) + method_name = self._method_name_for_comp_op.get(type(node.ops[0])) + if method_name is None: + raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { + ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' + } + + def visit_UnaryOp(self, node): + op = self.visit(node.operand) + fn = self._method_name_for_unary_op.get(type(node.op)) + if fn is None: + raise UnsupportedLanguageConstruct(None, node, "AST unary operator 
'{}' is not (currently) implemented.".format(node.op.__name__)) + if _is_triton_tensor(op): + return getattr(op, fn)(_builder=self.builder) + return getattr(op, fn)() + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'} + + def visit_While(self, node): + with enter_sub_region(self) as sr: + liveins, insert_block = sr + + # loop body (the after region) + # loop_block = self.builder.create_block() + dummy = self.builder.create_block() + self.builder.set_insertion_point_to_start(dummy) + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + + # collect loop-carried values + names = [] + ret_types = [] + init_args = [] + for name in loop_defs: + if name in liveins: + # We should not def new constexpr + assert _is_triton_tensor(loop_defs[name]) + assert _is_triton_tensor(liveins[name]) + assert loop_defs[name].type == liveins[name].type + # these are loop-carried values + names.append(name) + ret_types.append(loop_defs[name].type) + init_args.append(liveins[name]) + + self.builder.set_insertion_point_to_end(insert_block) + while_op = self.builder.create_while_op([ty.to_ir(self.builder) for ty in ret_types], + [arg.handle for arg in init_args]) + # merge the condition region + before_block = self.builder.create_block_with_parent(while_op.get_before(), + [ty.to_ir(self.builder) for ty in ret_types]) + self.builder.set_insertion_point_to_start(before_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(before_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + cond = self.visit(node.test) + self.builder.set_insertion_point_to_end(before_block) + # create ConditionOp: e.g., scf.condition(%cond) %arg0, %arg1, ... 
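+            # When %cond is true, the forwarded block arguments are passed on to the
+            # after region (the loop body); when it is false they become the results
+            # of the enclosing scf.while op.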
+ self.builder.create_condition_op(cond.handle, [before_block.arg(i) for i in range(len(init_args))]) + # merge the loop body + after_block = self.builder.create_block_with_parent(while_op.get_after(), + [ty.to_ir(self.builder) for ty in ret_types]) + + # generate loop body + self.builder.set_insertion_point_to_start(after_block) + for i, name in enumerate(names): + self.lscope[name] = language.core.tensor(after_block.arg(i), ret_types[i]) + self.local_defs[name] = self.lscope[name] + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + loop_defs = self.local_defs + yields = [] + for name in loop_defs: + if name in liveins: + yields.append(loop_defs[name]) + self.builder.create_yield_op([y.handle for y in yields]) + + # update global uses in while_op + for i, name in enumerate(names): + after_block.replace_use_in_block_with(init_args[i].handle, after_block.arg(i)) + + # WhileOp defines new values, update the symbol table (lscope, local_defs) + for i, name in enumerate(names): + new_def = language.core.tensor(while_op.get_result(i), ret_types[i]) + self.lscope[name] = new_def + self.local_defs[name] = new_def + + for stmt in node.orelse: + assert False, "Not implemented" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Subscript(self, node): + assert node.ctx.__class__.__name__ == "Load" + lhs = self.visit(node.value) + slices = self.visit(node.slice) + if _is_triton_tensor(lhs): + return lhs.__getitem__(slices, _builder=self.builder) + return lhs[slices] + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) + iter_args = [self.visit(arg) for arg in node.iter.args] + if IteratorClass == language.static_range: + iterator = IteratorClass(*iter_args) + static_range = range(iterator.start.value, + iterator.end.value, + iterator.step.value) + for i in static_range: + self.lscope[node.target.id] = constexpr(i) + self.visit_compound_statement(node.body) + for stmt in node.orelse: + ast.NodeVisitor.generic_visit(self, stmt) + return + + if IteratorClass is not range: + raise RuntimeError('Only `range` and `static_range` iterators are currently supported') + + # visit iterator arguments + # note: only `range` iterator is supported now + # collect lower bound (lb), upper bound (ub), and step + lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0)) + ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0]) + step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1)) + # handle negative constant step (not supported by scf.for in MLIR) + negative_step = False + if _is_constexpr(step) and step.value < 0: + step = constexpr(-step.value) + negative_step = True + lb, ub = ub, lb + lb = language.core._to_tensor(lb, self.builder) + ub = language.core._to_tensor(ub, self.builder) + step = language.core._to_tensor(step, self.builder) + # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") + iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) + iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) + iv_ir_type = iv_type.to_ir(self.builder) + iv_is_signed = iv_type.int_signedness == language.core.dtype.SIGNEDNESS.SIGNED + # lb/ub/step might be constexpr, we need to cast them to tensor + lb = lb.handle + ub = ub.handle + step = 
step.handle + # ForOp can only accept IndexType as lb/ub/step. Cast integer to Index + lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed) + ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed) + step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed) + # Create placeholder for the loop induction variable + iv = self.builder.create_undef(iv_ir_type) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + with enter_sub_region(self) as sr: + liveins, insert_block = sr + ip, last_loc = self._get_insertion_point_and_loc() + + # create loop body block + block = self.builder.create_block() + self.builder.set_insertion_point_to_start(block) + # dry visit loop body + self.scf_stack.append(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + block.erase() + + # If a variable (name) is defined in both its parent & itself, then it's + # a loop-carried variable. (They must be of the same type) + init_args = [] + yields = [] + names = [] + for name in self.local_defs: + if name in liveins: + assert _is_triton_tensor(self.local_defs[name]), f'{name} is not tensor' + assert _is_triton_tensor(liveins[name]) + assert self.local_defs[name].type == liveins[name].type,\ + f'Loop-carried variable {name} has initial type {liveins[name].type} '\ + f'but is re-assigned to {self.local_defs[name].type} in loop! '\ + f'Please make sure that the type stays consistent.' + + names.append(name) + init_args.append(language.core._to_tensor(liveins[name], self.builder)) + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create ForOp + self._set_insertion_point_and_loc(ip, last_loc) + for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) + + self.scf_stack.append(node) + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_body(0).arg(i + 1), yields[i].type)) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yields = [] + for name in self.local_defs: + if name in liveins: + yields.append(language.core._to_tensor(self.local_defs[name], self.builder)) + + # create YieldOp + if len(yields) > 0: + self.builder.create_yield_op([y.handle for y in yields]) + for_op_region = for_op.get_body(0).get_parent() + assert for_op_region.size() == 1, "We use SCF, so the loop body should only have one block" + + # update induction variable with actual value, and replace all uses + self.builder.set_insertion_point_to_start(for_op.get_body(0)) + iv = for_op.get_induction_var() + if negative_step: + iv = self.builder.create_sub(ub, iv) + iv = self.builder.create_add(iv, lb) + self.lscope[node.target.id].handle.replace_all_uses_with(iv) + self.set_value(node.target.id, language.core.tensor(iv, iv_type)) + + # update lscope & local_defs (ForOp defines new values) + for i, name in enumerate(names): + self.set_value(name, language.core.tensor(for_op.get_result(i), yields[i].type)) + + for stmt in node.orelse: + assert False, "Don't know what to do with else after for" + ast.NodeVisitor.generic_visit(self, stmt) + + def visit_Slice(self, node): + lower = self.visit(node.lower) + upper = self.visit(node.upper) + step = self.visit(node.step) + return slice(lower, upper, step) + + def visit_Index(self, node): + return self.visit(node.value) + + def visit_keyword(self, node) -> Tuple[str, Any]: + return node.arg, self.visit(node.value) + + def visit_Assert(self, node) -> Any: + if not self.debug: + 
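+            # Asserts are only lowered to device-side checks in debug mode;
+            # otherwise they are dropped at compile time.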
return + test = self.visit(node.test) + msg = self.visit(node.msg) + # Convert assert to triton's device_assert which happens on the device + return language.core.device_assert(test, msg, _builder=self.builder) + + def call_JitFunction(self, fn: JITFunction, args, kwargs): + args = inspect.getcallargs(fn.fn, *args, **kwargs) + args = [args[name] for name in fn.arg_names] + args = [arg if _is_triton_tensor(arg) + else constexpr(arg) for arg in args] + # generate function def + attributes = dict() + constexprs = [i for i, arg in enumerate(args) if _is_constexpr(arg)] + constants = {i: args[i] for i in constexprs} + # generate call + args = [None if i in constexprs else arg for i, arg in enumerate(args)] + arg_vals = [arg.handle for arg in args if arg is not None] + arg_types = [arg.type for arg in args if arg is not None] + fn_name = mangle_fn(fn.__name__, arg_types, constants) + # generate function def if necessary + if not self.module.has_function(fn_name): + prototype = language.function_type([], arg_types) + gscope = sys.modules[fn.fn.__module__].__dict__ + # If the callee is not set, we use the same debug setting as the caller + debug = self.debug if fn.debug is None else fn.debug + file_name, begin_line = _get_fn_file_line(fn) + generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, + function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline, + file_name=file_name, begin_line=begin_line) #, arch=self.builder.arch) + generator.visit(fn.parse()) + callee_ret_type = generator.last_ret_type + self.function_ret_types[fn_name] = callee_ret_type + else: + callee_ret_type = self.function_ret_types[fn_name] + symbol = self.module.get_function(fn_name) + call_op = self.builder.call(symbol, arg_vals) + if call_op.get_num_results() == 0 or callee_ret_type is None: + return None + elif call_op.get_num_results() == 1: + return tensor(call_op.get_result(0), callee_ret_type) + else: + # should return a tuple of tl.tensor + results = [] + for i in range(call_op.get_num_results()): + results.append(tensor(call_op.get_result(i), callee_ret_type[i])) + return tuple(results) + + def visit_Call(self, node): + fn = _unwrap_if_constexpr(self.visit(node.func)) + + static_implementation = self.statically_implemented_functions.get(fn) + if static_implementation is not None: + return static_implementation(self, node) + + kws = dict(self.visit(keyword) for keyword in node.keywords) + args = [self.visit(arg) for arg in node.args] + if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if not self.debug: + return + if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) + return self.call_JitFunction(fn, args, kws) + if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): + extra_kwargs = dict(_builder=self.builder) + sig = inspect.signature(fn) + if '_generator' in sig.parameters: + extra_kwargs['_generator'] = self + return fn(*args, **extra_kwargs, **kws) + if fn in self.builtin_namespace.values(): + args = map(_unwrap_if_constexpr, args) + return fn(*args, **kws) + + def visit_Constant(self, node): + return constexpr(node.value) + + def visit_BoolOp(self, node: ast.BoolOp): + if len(node.values) != 2: + raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + lhs = self.visit(node.values[0]) + rhs = self.visit(node.values[1]) + method_name = 
self._method_name_for_bool_op.get(type(node.op)) + if method_name is None: + raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} + + if sys.version_info < (3, 8): + def visit_NameConstant(self, node): + return constexpr(node.value) + + def visit_Num(self, node): + return constexpr(node.n) + + def visit_Str(self, node): + return constexpr(ast.literal_eval(node)) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + if _is_triton_tensor(lhs): + if node.attr == "T": + return language.semantic.trans(lhs, builder=self.builder) + return getattr(lhs, node.attr) + + def visit_Expr(self, node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_NoneType(self, node): + return None + + def visit_JoinedStr(self, node): + values = list(node.values) + for i, value in enumerate(values): + if isinstance(value, ast.Constant): + values[i] = str(value.value) + elif isinstance(value, ast.FormattedValue): + conversion_code = value.conversion + evaluated = self.visit(value.value) + if not _is_constexpr(evaluated): + raise UnsupportedLanguageConstruct( + None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated))) + values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) + else: + raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) + return ''.join(values) + + def visit(self, node): + if node is None: + return + with warnings.catch_warnings(): + # The ast library added visit_Constant and deprecated some other + # methods but we can't move to that without breaking Python 3.6 and 3.7. + warnings.simplefilter("ignore", DeprecationWarning) # python 3.9 + warnings.simplefilter("ignore", PendingDeprecationWarning) # python 3.8 + self.last_node = node + last_loc = self.builder.get_loc() + if hasattr(node, 'lineno') and hasattr(node, 'col_offset'): + self.builder.set_loc(self.file_name, self.begin_line + node.lineno, node.col_offset) + last_loc = self.builder.get_loc() + ret = super().visit(node) + # Reset the location to the last one before the visit + if last_loc: + self.builder.set_loc(last_loc) + return ret + + def generic_visit(self, node): + raise UnsupportedLanguageConstruct(None, node, "unsupported AST node type: {}".format(type(node).__name__)) + + def execute_static_print(self, node: ast.Call) -> None: + # TODO: too simplistic? Perhaps do something else with non-constexpr + + kws = {name: _unwrap_if_constexpr(value) for name, value in (self.visit(keyword) for keyword in node.keywords)} + args = [_unwrap_if_constexpr(self.visit(arg)) for arg in node.args] + print(*args, **kws) + + def execute_static_assert(self, node: ast.Call) -> None: + arg_count = len(node.args) + if not (0 < arg_count <= 2) or len(node.keywords): + raise TypeError("`static_assert` requires one or two positional arguments only") + + passed = self.visit(node.args[0]) + if not isinstance(passed, bool): + raise NotImplementedError("Assertion condition could not be determined at compile-time. 
Make sure that it depends only on `constexpr` values") + if not passed: + if arg_count == 1: + message = "" + else: + try: + message = self.visit(node.args[1]) + except Exception as e: + message = "" + + raise CompileTimeAssertionFailure(None, node, _unwrap_if_constexpr(message)) + return None + + statically_implemented_functions: Dict[object, Callable[[ast.Call], Any]] = { + language.core.static_assert: execute_static_assert, + language.core.static_print: execute_static_print, + } + + +def str_to_ty(name): + if name[0] == "*": + ty = str_to_ty(name[1:]) + return language.pointer_type(ty) + tys = { + "fp8e5": language.float8e5, + "fp8e4": language.float8e4, + "fp16": language.float16, + "bf16": language.bfloat16, + "fp32": language.float32, + "fp64": language.float64, + "i1": language.int1, + "i8": language.int8, + "i16": language.int16, + "i32": language.int32, + "i64": language.int64, + "u8": language.uint8, + "u16": language.uint16, + "u32": language.uint32, + "u64": language.uint64, + "B": language.int1, + } + return tys[name] + + +def kernel_suffix(signature, specialization): + # suffix format: + # <'c' if equal to 1><'d' if divisible by 16> + suffix = '' + for i, _ in enumerate(signature): + suffix += str(i) + if i in specialization.equal_to_1: + suffix += 'c' + if i in specialization.divisible_by_16: + suffix += 'd' + return suffix + + +def ast_to_ttir(fn, signature, specialization, constants, debug): + # canonicalize signature + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + context = ir.context() + context.load_triton() + # create kernel prototype + cst_key = lambda i: fn.arg_names.index(i) if isinstance(i, str) else i + constants = {cst_key(key): value for key, value in constants.items()} + # visit kernel AST + gscope = fn.__globals__.copy() + function_name = '_'.join([fn.__name__, kernel_suffix(signature.values(), specialization)]) + tys = list(signature.values()) + new_constants = {k: True if k in tys and tys[k] == "i1" else 1 for k in specialization.equal_to_1} + new_attrs = {k: ("multiple_of", 16) for k in specialization.divisible_by_16} + all_constants = constants.copy() + all_constants.update(new_constants) + arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants] + file_name, begin_line = _get_fn_file_line(fn) + + prototype = language.function_type([], arg_types) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, + function_name=function_name, attributes=new_attrs, + is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line) + # arch=arch) + try: + generator.visit(fn.parse()) + except CompilationError as e: + if e.src is None: + e.set_source_code(fn.src) + raise + except Exception as e: + node = generator.last_node + if node is None: + raise + raise CompilationError(fn.src, node, repr(e)) from e + ret = generator.module + # module takes ownership of the context + ret.context = context + return ret diff --git a/pkgs/triton/compiler/compiler.py b/pkgs/triton/compiler/compiler.py new file mode 100644 index 0000000..c5e5635 --- /dev/null +++ b/pkgs/triton/compiler/compiler.py @@ -0,0 +1,631 @@ +from __future__ import annotations + +import functools +import hashlib +import json +import os +import re +import subprocess +import tempfile +from collections import namedtuple +from pathlib import Path +from typing import Any, Tuple + +import triton +import triton._C.libtriton.triton as _triton +from ..runtime import driver +# TODO: runtime.errors +from 
..runtime.autotuner import OutOfResources +from ..runtime.cache import get_cache_manager +from ..tools.disasm import extract +from .code_generator import ast_to_ttir +from .make_launcher import make_stub + + +def is_corex(): + import torch + return hasattr(torch, "corex") and torch.corex == True + +CUDA_DEFAULT_WARP_SIZE = 64 if is_corex() else 32 + +def inline_triton_ir(mod): + pm = _triton.ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_inliner_pass() + pm.run(mod) + return mod + + +def ttir_compute_capability_rewrite(mod, arch): + # For hardware without support, we must rewrite all load/store + # with block (tensor) pointers into tensors of pointers + pm = _triton.ir.pass_manager(mod.context) + pm.enable_debug() + if _is_cuda(arch): + pm.add_rewrite_tensor_pointer_pass(arch) + pm.run(mod) + return mod + + +def optimize_ttir(mod, arch): + mod = inline_triton_ir(mod) + mod = ttir_compute_capability_rewrite(mod, arch) + pm = _triton.ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_inliner_pass() + pm.add_triton_combine_pass() + pm.add_canonicalizer_pass() + pm.add_cse_pass() + pm.add_licm_pass() + pm.add_symbol_dce_pass() + pm.run(mod) + return mod + + +def ttir_to_ttgir(mod, num_warps, warpsize): + pm = _triton.ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize) + pm.run(mod) + return mod + + +def optimize_ttgir(mod, num_stages, arch, use_sme = 0): + pm = _triton.ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_tritongpu_coalesce_pass() + pm.add_tritongpu_remove_layout_conversions_pass() + if _is_cuda(arch): + pm.add_tritongpu_accelerate_matmul_pass(arch, use_sme) + # TODO change interface of accelerate_matmul_pass + if is_hip() and gpu_has_mfma(): + pm.add_tritongpu_accelerate_matmul_pass(80) + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + if is_corex(): + pm.add_tritongpu_matmul_smeload_pass(arch) #BI 70 MR > 71,only MR support sme + pm.add_tritongpu_remove_layout_conversions_pass() + # TODO enable this pass for AMD GPU when it is ready + if not is_hip(): + pm.add_tritongpu_pipeline_pass(num_stages) + if not is_corex(): + pm.add_tritongpu_prefetch_pass() + pm.add_tritongpu_optimize_dot_operands_pass() + pm.add_tritongpu_remove_layout_conversions_pass() + pm.add_tritongpu_decompose_conversions_pass() + pm.add_tritongpu_reorder_instructions_pass() + pm.add_cse_pass() + pm.add_symbol_dce_pass() + pm.run(mod) + return mod + + +def _add_external_libs(mod, libs): + for name, path in libs.items(): + if len(name) == 0 or len(path) == 0: + return + _triton.add_external_libs(mod, list(libs.keys()), list(libs.values())) + + +def ttgir_to_llir(mod, extern_libs, arch): + if extern_libs: + _add_external_libs(mod, extern_libs) + # TODO: separate tritongpu_to_llvmir for different backends + if _is_cuda(arch): + return _triton.translate_triton_gpu_to_llvmir(mod, arch, False) + else: + return _triton.translate_triton_gpu_to_llvmir(mod, 0, True) + + +# PTX translation + +@functools.lru_cache() +def ptx_get_version(cuda_version) -> int: + ''' + Get the highest PTX version supported by the current CUDA driver. 
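+    For example, CUDA 11.8 maps to PTX ISA version 78 and CUDA 12.2 to 82
+    (see the version arithmetic below).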
+ ''' + assert isinstance(cuda_version, str) + major, minor = map(int, cuda_version.split('.')) + if major == 12: + return 80 + minor + if major == 11: + return 70 + minor + if major == 10: + return 63 + minor + raise RuntimeError("Triton only support CUDA 10.0 or higher") + + +@functools.lru_cache() +def path_to_ptxas(): + base_dir = os.path.join(os.path.dirname(__file__), os.pardir) + paths = [ + os.environ.get("TRITON_PTXAS_PATH", ""), + os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas") + ] + + for ptxas in paths: + if os.path.exists(ptxas) and os.path.isfile(ptxas): + result = subprocess.check_output([ptxas, "--version"], stderr=subprocess.STDOUT) + if result is not None: + version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) + if version is not None: + return ptxas, version.group(1) + raise RuntimeError("Cannot find ptxas") + + +def llir_to_cubin(mod: Any, arch: int): + ''' + Compile LLVM module to cubin. + :param mod: a LLVM module + :return: str + ''' + return _triton.translate_llvmir_to_cubin(mod, arch) + + +def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str: + ''' + Translate TritonGPU module to PTX code. + :param mod: a TritonGPU dialect module + :return: PTX code + ''' + if ptx_version is None: + _, cuda_version = path_to_ptxas() + ptx_version = ptx_get_version(cuda_version) + return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version) + + +def ptx_to_cubin(ptx: str, arch: int): + ''' + Compile TritonGPU module to cubin. + :param ptx: ptx code + :param compute_capability: compute capability + :return: str + ''' + ptxas, _ = path_to_ptxas() + return _triton.compile_ptx_to_cubin(ptx, ptxas, arch) + + +# AMDGCN translation + +def get_amdgcn_bitcode_paths(arch): + gpu_arch_agnostic_bitcode_libraries = ["opencl.bc", + "ocml.bc", + "ockl.bc", + "oclc_finite_only_off.bc", + "oclc_daz_opt_off.bc", + "oclc_correctly_rounded_sqrt_on.bc", + "oclc_unsafe_math_off.bc", + "oclc_wavefrontsize64_on.bc", + "oclc_abi_version_400.bc",] + + gfx_arch = arch[1] + gfx_arch_id = re.search('gfx(\\w+)', gfx_arch).group(1).strip() + + gpu_arch_specific_bitcode_library = 'oclc_isa_version_' + gfx_arch_id + ".bc" + bitcode_path_dir = os.path.join(Path(__file__).parent.parent.resolve(), "third_party/rocm/lib/bitcode/") + + amdgcn_bitcode_paths = {} + i = 0 + for bc_lib in gpu_arch_agnostic_bitcode_libraries: + bc_path = bitcode_path_dir + bc_lib + if os.path.exists(bc_path): + amdgcn_bitcode_paths['library_' + str(i)] = bc_path + i += 1 + bc_gfx_path = bitcode_path_dir + gpu_arch_specific_bitcode_library + if os.path.exists(bc_gfx_path): + amdgcn_bitcode_paths['library_' + str(i)] = bc_gfx_path + + return amdgcn_bitcode_paths + + +def get_amdgpu_arch_fulldetails(): + """ + get the amdgpu fulll ISA details for compiling: + i.e., arch_triple: amdgcn-amd-amdhsa; arch_name: gfx906; arch_features: sramecc+:xnack- + """ + try: + # TODO: package rocm.cc with Triton + arch_info = _triton.get_arch_info() + warpsize = _triton.get_warp_size() + gfx_arch_details = re.search('amd.*', arch_info).group(0).strip().split('--') + arch_triple = gfx_arch_details[0] + arch_name_features = gfx_arch_details[1].split(':') + arch_name = arch_name_features[0] + arch_features = "" + + if (len(arch_name_features) == 3): + arch_features = "+" + re.search('\\w+', arch_name_features[1]).group(0) + ","\ + "-" + re.search('\\w+', arch_name_features[2]).group(0) + return [arch_triple, arch_name, arch_features, warpsize] + except BaseException: + return None + + +def 
llir_to_amdgcn_and_hsaco(mod: Any, gfx_arch: str, gfx_triple: str, gfx_features: str) -> Tuple[str, str]: + ''' + Translate TritonGPU module to HSACO code based on full details of gpu architecture. + :param mod: a TritonGPU dialect module + :return: + - AMDGCN code + - Path to HSACO object + ''' + return _triton.translate_llvmir_to_hsaco(mod, gfx_arch, gfx_triple, gfx_features) + + +# ------------------------------------------------------------------------------ +# compiler +# ------------------------------------------------------------------------------ +def get_kernel_name(src: str, pattern: str, llir: bool = False) -> str: + ''' + Get kernel name from PTX code. + This Kernel name is required when launching the kernel. + ''' + # There is a name mangling in PTX codegen, so the original kernel names in Triton IR are not available in PTX/cubin. + assert src + for line in src.split('\n'): + line = line.strip() + if line.startswith(pattern): + if not llir: + return line.split()[-1] + return line.split("(")[0].split("@")[-1] + + +def convert_type_repr(x): + match = re.search(r'!tt\.ptr<(.*)>', x) + if match is not None: + return '*' + convert_type_repr(match.group(1)) + return x + + +def make_hash(fn, arch, **kwargs): + if isinstance(fn, triton.runtime.JITFunction): + configs = kwargs["configs"] + signature = kwargs["signature"] + constants = kwargs.get("constants", dict()) + num_warps = kwargs.get("num_warps", 4) + num_stages = kwargs.get("num_stages", 3) + debug = kwargs.get("debug", False) + use_sme = kwargs.get("use_sme", 0) + # Get unique key for the compiled code + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) + configs_key = [get_conf_key(conf) for conf in configs] + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}-{use_sme}" + return hashlib.md5(key.encode("utf-8")).hexdigest() + assert isinstance(fn, str) + return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest() + + +# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func, +# and any following whitespace +# - (public\s+)? : optionally match the keyword public and any following whitespace +# - (@\w+) : match an @ symbol followed by one or more word characters +# (letters, digits, or underscores), and capture it as group 1 (the function name) +# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing +# zero or more arguments separated by commas, and capture it as group 2 (the argument list) +mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' +ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" +prototype_pattern = { + "ttir": mlir_prototype_pattern, + "ttgir": mlir_prototype_pattern, + "ptx": ptx_prototype_pattern, +} + +mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?' 
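+# Illustrative example (hypothetical kernel name): a ttir prototype line such as
+#   tt.func public @add_kernel_0d1d(%arg0: !tt.ptr<f32>, %arg1: i32) {
+# matches mlir_prototype_pattern with group 1 = "@add_kernel_0d1d" and
+# group 2 = "(%arg0: !tt.ptr<f32>, %arg1: i32)"; applying mlir_arg_type_pattern to
+# group 2 then yields the argument types ["!tt.ptr<f32>", "i32"].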
+ptx_arg_type_pattern = r"\.param\s+\.(\w+)" +arg_type_pattern = { + "ttir": mlir_arg_type_pattern, + "ttgir": mlir_arg_type_pattern, + "ptx": ptx_arg_type_pattern, +} + +ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' + + +def _get_jsonable_constants(constants): + def _is_jsonable(x): + try: + json.dumps(x) + return True + except (TypeError, OverflowError): + return False + serialized_constants = {} + for constant in constants: + if _is_jsonable(constants[constant]): + serialized_constants[constant] = constants[constant] + return serialized_constants + + +def parse_mlir_module(path, context): + module = _triton.ir.parse_mlir_module(path, context) + # module takes ownership of the context + module.context = context + return module + + +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"], defaults=[set(), set()]) + + +# TODO: architecture descriptor class +def _is_cuda(arch): + return isinstance(arch, int) + +def is_hip(): + try: + import torch + except ImportError: + raise ImportError("Triton requires PyTorch to be installed") + return torch.version.hip is not None + +from ..language.semantic import gpu_has_mfma + +def get_architecture_descriptor(capability): + try: + import torch + except ImportError: + raise ImportError("Triton requires PyTorch to be installed") + if capability is None: + if torch.version.hip is None: + device = triton.runtime.jit.get_current_device() + capability = triton.runtime.jit.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + else: + capability = get_amdgpu_arch_fulldetails() + return capability + + +def add_rocm_stages(arch, extern_libs, stages): + extern_libs.update(get_amdgcn_bitcode_paths(arch)) + + for key in list(extern_libs): + if extern_libs[key] == '' or extern_libs[key] is None: + extern_libs.pop(key) + + gfx_arch_full_details = arch + gfx_arch = os.environ.get('MI_GPU_ARCH', gfx_arch_full_details[1]) + if gfx_arch is None: + raise RuntimeError('gfx_arch is None (not specified)') + stages["amdgcn"] = (lambda path: Path(path).read_text(), + lambda src: llir_to_amdgcn_and_hsaco(src, gfx_arch, + gfx_arch_full_details[0], + gfx_arch_full_details[2])) + + +def add_cuda_stages(arch, extern_libs, stages): + stages["ptx"] = (lambda path: Path(path).read_text(), + lambda src: llir_to_ptx(src, arch)) + stages["cubin"] = (lambda path: Path(path).read_bytes(), + lambda src: ptx_to_cubin(src, arch)) + + +def add_iluvatar_stages(arch, extern_libs, stages): + stages["cubin"] = (lambda path: Path(path).read_bytes(), + lambda src: llir_to_cubin(src, arch)) + + +def compile(fn, **kwargs): + if is_hip(): + capability = None + else: + capability = kwargs.get("cc", None) + arch = get_architecture_descriptor(capability) + is_cuda = _is_cuda(arch) + warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch[3] + context = _triton.ir.context() + asm = dict() + constants = kwargs.get("constants", dict()) + num_warps = kwargs.get("num_warps", 4) + num_stages = kwargs.get("num_stages", 3 if is_cuda and arch >= 75 else 2) + extern_libs = kwargs.get("extern_libs", dict()) + use_sme = kwargs.get("use_sme", 0) + if extern_libs is None: + extern_libs = dict() + debug = kwargs.get("debug", False) + # build compilation stages + stages = dict() + stages["ast"] = (lambda path: fn, None) + stages["ttir"] = (lambda path: parse_mlir_module(path, context), + lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug), arch)) + stages["ttgir"] = (lambda path: 
parse_mlir_module(path, context), + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, warp_size), num_stages, arch, use_sme)) + stages["llir"] = (lambda path: Path(path).read_text(), + lambda src: ttgir_to_llir(src, extern_libs, arch)) + if is_cuda: + if is_corex(): + add_iluvatar_stages(arch, extern_libs, stages) + else: + add_cuda_stages(arch, extern_libs, stages) + else: + add_rocm_stages(arch, extern_libs, stages) + + # find out the signature of the function + if isinstance(fn, triton.runtime.JITFunction): + configs = kwargs.get("configs", None) + signature = kwargs["signature"] + if configs is None: + configs = [instance_descriptor()] + assert len(configs) == 1 + kwargs["configs"] = configs + name = fn.__name__ + first_stage = 0 + if isinstance(signature, str): + signature = {k: v.strip() for k, v in enumerate(signature.split(","))} + kwargs["signature"] = signature + else: + assert isinstance(fn, str) + _, ir = os.path.basename(fn).split(".") + src = Path(fn).read_text() + import re + match = re.search(prototype_pattern[ir], src, re.MULTILINE) + name, signature = match.group(1), match.group(2) + types = re.findall(arg_type_pattern[ir], signature) + if ir == 'ttgir': + num_warps_matches = re.findall(ttgir_num_warps_pattern, src) + assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" + assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" + num_warps = int(num_warps_matches[0]) + param_tys = [convert_type_repr(ty) for ty in types] + signature = {k: v for k, v in enumerate(param_tys)} + first_stage = list(stages.keys()).index(ir) + + # cache manager + so_path = make_stub(name, signature, constants) + # create cache manager + fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs)) + # determine name and extension type of provided function + if isinstance(fn, triton.runtime.JITFunction): + name, ext = fn.__name__, "ast" + else: + name, ext = os.path.basename(fn).split(".") + + # load metadata if any + metadata = None + metadata_filename = f"{name}.json" + + # The group is addressed by the metadata + metadata_group = fn_cache_manager.get_group( + metadata_filename + ) or {} + + metadata_path = metadata_group.get(metadata_filename) + + if metadata_path is not None: + with open(metadata_path) as f: + metadata = json.load(f) + else: + metadata = {"num_warps": num_warps, + "warp_size": warp_size, + "num_stages": num_stages, + "constants": _get_jsonable_constants(constants), + "debug": debug, + "use_sme": use_sme} + if ext == "ptx": + assert "shared" in kwargs, "ptx compilation must provide shared memory size" + metadata["shared"] = kwargs["shared"] + + first_stage = list(stages.keys()).index(ext) + asm = dict() + module = fn + # run compilation pipeline and populate metadata + for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]: + ir_filename = f"{name}.{ir}" + if ir == ext: + next_module = parse(fn) + else: + path = metadata_group.get(ir_filename) + if path is None: + next_module = compile_kernel(module) + if ir == "amdgcn": + extra_file_name = f"{name}.hsaco_path" + metadata_group[ir_filename] = fn_cache_manager.put(next_module[0], ir_filename) + metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) + else: + metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) + fn_cache_manager.put(next_module, ir_filename) + else: + if ir == "amdgcn": + extra_file_name = f"{name}.hsaco_path" + hasco_path = 
metadata_group.get(extra_file_name) + assert hasco_path is not None, "Expected to have hsaco_path in metadata when we have the amdgcn" + next_module = (parse(path), parse(hasco_path)) + else: + next_module = parse(path) + + if ir == "llir": + metadata["name"] = get_kernel_name(next_module, pattern="define iluvatar_kernel void", llir=True) + if ir == "cubin": + asm[ir] = next_module + elif ir == "amdgcn": + asm[ir] = str(next_module[0]) + else: + asm[ir] = str(next_module) + if ir == "llir" and "shared" not in metadata: + metadata["shared"] = _triton.get_shared_memory_size(module) + if ir == "ptx": + metadata["name"] = get_kernel_name(next_module, pattern='// .globl') + if ir == "amdgcn": + metadata["name"] = get_kernel_name(next_module[0], pattern='.globl') + asm["hsaco_path"] = next_module[1] + module = next_module + # write-back metadata, if it didn't come from the cache + if metadata_path is None: + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata), metadata_filename, binary=False) + fn_cache_manager.put_group(metadata_filename, metadata_group) + + # return handle to compiled kernel + return CompiledKernel(fn, so_path, metadata, asm) + + +class CompiledKernel: + + # Hooks for external tools to monitor the execution of triton kernels + launch_enter_hook = None + launch_exit_hook = None + + def __init__(self, fn, so_path, metadata, asm): + # initialize launcher + import importlib.util + spec = importlib.util.spec_from_file_location("__triton_launcher", so_path) + mod = importlib.util.module_from_spec(spec) + self.fn = fn + spec.loader.exec_module(mod) + self.c_wrapper = getattr(mod, "launch") + # initialize metadata + self.shared = metadata["shared"] + self.num_warps = metadata["num_warps"] + self.warp_size = metadata["warp_size"] + self.num_stages = metadata["num_stages"] + self.constants = metadata["constants"] + # initialize asm dict + self.asm = asm + # binaries are lazily initialized + # because it involves doing runtime things + # (e.g., checking amount of shared memory on current device) + self.metadata = metadata + self.cu_module = None + self.cu_function = None + + def _init_handles(self): + if self.cu_module is not None: + return + device = triton.runtime.jit.get_current_device() + bin_path = { + driver.HIP: "hsaco_path", + driver.CUDA: "cubin" + }[driver.backend] + max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] + if self.shared > max_shared: + raise OutOfResources(self.shared, max_shared, "shared memory") + mod, func, n_regs, n_spills = driver.utils.load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device) + + self.n_spills = n_spills + self.n_regs = n_regs + self.cu_module = mod + self.cu_function = func + + def __getattribute__(self, name): + if name == 'c_wrapper': + self._init_handles() + return super().__getattribute__(name) + + def __getitem__(self, grid): + self._init_handles() + + def runner(*args, stream=None): + if stream is None: + stream = triton.runtime.jit.get_cuda_stream() + self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.shared, stream, self.cu_function, + CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args) + return runner + + def get_sass(self, fun=None): + if 'sass' in self.asm: + return self.asm['sass'] + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(self.asm['cubin']) + self.sass = extract(path, fun) + finally: + os.remove(path) + self.asm['sass'] = self.sass + return self.sass diff --git 
a/pkgs/triton/compiler/errors.py b/pkgs/triton/compiler/errors.py new file mode 100644 index 0000000..2930117 --- /dev/null +++ b/pkgs/triton/compiler/errors.py @@ -0,0 +1,52 @@ +import ast +from typing import Optional, Union + + +class CompilationError(Exception): + source_line_count_max_in_message = 12 + + def _format_message(self) -> str: + node = self.node + if self.src is None: + source_excerpt = " " + else: + source_excerpt = self.src.split('\n')[:node.lineno][-self.source_line_count_max_in_message:] + if source_excerpt: + source_excerpt.append(' ' * node.col_offset + '^') + source_excerpt = '\n'.join(source_excerpt) + else: + source_excerpt = " " + + message = "at {}:{}:{}".format(node.lineno, node.col_offset, source_excerpt) + if self.error_message: + message += '\n' + self.error_message + return message + + def __init__(self, src: Optional[str], node: ast.AST, error_message: Union[str, None]): + self.src = src + self.node = node + self.error_message = error_message + self.message = self._format_message() + + def set_source_code(self, src: Optional[str]): + self.src = src + self.message = self._format_message() + + def __str__(self): + return self.message + + def __repr__(self): + return "{}({!r})".format(type(self).__name__, self.message) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return type(self), (self.src, self.node, self.error_message) + + +class CompileTimeAssertionFailure(CompilationError): + """Specific exception for failed tests in `static_assert` invocations""" + pass + + +class UnsupportedLanguageConstruct(CompilationError): + pass diff --git a/pkgs/triton/compiler/make_launcher.py b/pkgs/triton/compiler/make_launcher.py new file mode 100644 index 0000000..0ab04b1 --- /dev/null +++ b/pkgs/triton/compiler/make_launcher.py @@ -0,0 +1,392 @@ +import hashlib +import os +import tempfile + +from ..common import _build +from ..runtime.cache import get_cache_manager +from ..runtime.jit import version_key + + +def is_hip(): + import torch + return torch.version.hip is not None + + +def is_corex(): + import torch + return hasattr(torch, "corex") and torch.corex == True + + +# ----- stub -------- + + +def make_so_cache_key(version_hash, signature, constants): + # Get unique key for the compiled code + signature = {k: 'ptr' if v[0] == '*' else v for k, v in signature.items()} + key = f"{version_hash}-{''.join(signature.values())}{constants}" + key = hashlib.md5(key.encode("utf-8")).hexdigest() + return key + + +def make_stub(name, signature, constants): + # name of files that are cached + so_cache_key = make_so_cache_key(version_key(), signature, constants) + so_cache_manager = get_cache_manager(so_cache_key) + so_name = f"{name}.so" + # retrieve stub from cache if it exists + cache_path = so_cache_manager.get_file(so_name) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src = generate_launcher(constants, signature) + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build(name, src_path, tmpdir) + so_cache_manager.put(src, f"{name}.c", binary=False) + with open(so, "rb") as f: + return so_cache_manager.put(f.read(), so_name, binary=True) + else: + return cache_path + +# ----- source code generation -------- + + +def ty_to_cpp(ty): + if ty[0] == '*': + return "hipDeviceptr_t" if is_hip() else "CUdeviceptr" + return { + "i1": "int32_t", + "i8": "int8_t", + "i16": "int16_t", + "i32": "int32_t", + "i64": "int64_t", + "u32": "uint32_t", + "u64": "uint64_t", + "fp16": 
"float", + "bf16": "float", + "fp32": "float", + "f32": "float", + "fp64": "double", + }[ty] + + +def generate_launcher(constants, signature): + arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) + + def _extracted_type(ty): + if ty[0] == '*': + return "PyObject*" + return { + 'i1': 'int32_t', + 'i32': 'int32_t', + 'i64': 'int64_t', + 'u32': 'uint32_t', + 'u64': 'uint64_t', + 'fp16': 'float', + 'bf16': 'float', + 'fp32': 'float', + 'f32': 'float', + 'fp64': 'double', + }[ty] + + def format_of(ty): + return { + "PyObject*": "O", + "float": "f", + "double": "d", + "long": "l", + "uint32_t": "I", + "int32_t": "i", + "uint64_t": "K", + "int64_t": "L", + }[ty] + + format = "iiiiiKKOOO" + ''.join([format_of(_extracted_type(ty)) for ty in signature.values()]) + + # generate glue code + if is_hip(): + src = f""" + #define __HIP_PLATFORM_AMD__ + #include + #include + #include + + static inline void gpuAssert(hipError_t code, const char *file, int line) + {{ + if (code != HIP_SUCCESS) + {{ + const char* prefix = "Triton Error [HIP]: "; + const char* str = hipGetErrorString(code); + char err[1024] = {{0}}; + snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str ); + PyErr_SetString(PyExc_RuntimeError, err); + }} + }} + + #define HIP_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + + static int getWarpSize(hipStream_t stream) + {{ + int device_id = hipGetStreamDeviceId(stream); + gpuAssert(device_id >= 0 ? hipSuccess : hipErrorInvalidDevice, __FILE__, __LINE__); + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, device_id)); + return prop.warpSize; + }} + + static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, hipStream_t stream, hipFunction_t function, {arg_decls}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; + if (gridX*gridY*gridZ > 0) {{ + int warp_size = getWarpSize(stream); + HIP_CHECK(hipModuleLaunchKernel(function, gridX, gridY, gridZ, num_warps * warp_size, 1, 1, shared_memory, stream, params, 0)); + }} + }} + + typedef struct _DevicePtrInfo {{ + hipDeviceptr_t dev_ptr; + bool valid; + }} DevicePtrInfo; + + static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + + if (ptr) {{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + return ptr_info; + }} + + ptr_info.dev_ptr = (hipDeviceptr_t)PyLong_AsUnsignedLongLong(ret); + + if (!ptr_info.dev_ptr) + return ptr_info; + + uint64_t dev_ptr; + hipError_t status = hipPointerGetAttribute(&dev_ptr, HIP_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == hipErrorInvalidValue) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + + ptr_info.dev_ptr = (hipDeviceptr_t)dev_ptr; + return ptr_info; + }} + + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return 
ptr_info; + }} + + static PyObject* launch(PyObject* self, PyObject* args) {{ + + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int num_warps; + int shared_memory; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if (!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None) {{ + PyObject_CallObject(launch_enter_hook, args); + }} + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, shared_memory, (hipStream_t)_stream, (hipFunction_t)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}" for i, ty in signature.items())}); + if (launch_exit_hook != Py_None) {{ + PyObject_CallObject(launch_exit_hook, args); + }} + if (PyErr_Occurred()) {{ + return NULL; + }} + + // return None + Py_INCREF(Py_None); + return Py_None; + }} + + static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel + }}; + + static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods + }}; + + PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; + }} + """ + else: + warp_size = 64 if is_corex() else 32 + src = f""" +#include \"cuda.h\" +#include +#include + +static inline void gpuAssert(CUresult code, const char *file, int line) +{{ + if (code != CUDA_SUCCESS) + {{ + const char* prefix = "Triton Error [CUDA]: "; + const char* str; + cuGetErrorString(code, &str); + char err[1024] = {{0}}; + strcat(err, prefix); + strcat(err, str); + PyErr_SetString(PyExc_RuntimeError, err); + }} +}} + +#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }} + +static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{ + void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }}; + if(gridX*gridY*gridZ > 0){{ + CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, num_warps * {warp_size}, 1, 1, shared_memory, stream, params, 0)); + }} +}} + +typedef struct _DevicePtrInfo {{ + CUdeviceptr dev_ptr; + bool valid; +}} DevicePtrInfo; + +static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{ + DevicePtrInfo ptr_info; + ptr_info.dev_ptr = 0; + ptr_info.valid = true; + if (PyLong_Check(obj)) {{ + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj); + return ptr_info; + }} + if (obj == Py_None) {{ + // valid nullptr + return ptr_info; + }} + PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr"); + if(ptr){{ + PyObject *empty_tuple = PyTuple_New(0); + PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL); + Py_DECREF(empty_tuple); + Py_DECREF(ptr); + if (!PyLong_Check(ret)) {{ + PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int"); + ptr_info.valid = false; + 
return ptr_info; + }} + ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret); + if(!ptr_info.dev_ptr) + return ptr_info; + /* + uint64_t dev_ptr; + int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr); + if (status == CUDA_ERROR_INVALID_VALUE) {{ + PyErr_Format(PyExc_ValueError, + "Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx); + ptr_info.valid = false; + }} + ptr_info.dev_ptr = dev_ptr; + */ + Py_DECREF(ret); // Thanks ChatGPT! + return ptr_info; + }} + PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method"); + return ptr_info; +}} + +static PyObject* launch(PyObject* self, PyObject* args) {{ + int gridX, gridY, gridZ; + uint64_t _stream; + uint64_t _function; + int num_warps; + int shared_memory; + PyObject *launch_enter_hook = NULL; + PyObject *launch_exit_hook = NULL; + PyObject *compiled_kernel = NULL; + {' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])} + if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{ + return NULL; + }} + + if (launch_enter_hook != Py_None) {{ + PyObject_CallObject(launch_enter_hook, args); + }} + + + // raise exception asap + {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])}; + _launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())}); + + if (launch_exit_hook != Py_None) {{ + PyObject_CallObject(launch_exit_hook, args); + }} + + if(PyErr_Occurred()) {{ + return NULL; + }} + // return None + Py_INCREF(Py_None); + return Py_None; +}} + +static PyMethodDef ModuleMethods[] = {{ + {{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, + {{NULL, NULL, 0, NULL}} // sentinel +}}; + +static struct PyModuleDef ModuleDef = {{ + PyModuleDef_HEAD_INIT, + \"__triton_launcher\", + NULL, //documentation + -1, //size + ModuleMethods +}}; + +PyMODINIT_FUNC PyInit___triton_launcher(void) {{ + PyObject *m = PyModule_Create(&ModuleDef); + if(m == NULL) {{ + return NULL; + }} + PyModule_AddFunctions(m, ModuleMethods); + return m; +}} +""" + return src diff --git a/pkgs/triton/debugger/__init__.py b/pkgs/triton/debugger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/triton/debugger/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/debugger/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..2db90c9 Binary files /dev/null and b/pkgs/triton/debugger/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/debugger/__pycache__/core.cpython-310.pyc b/pkgs/triton/debugger/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000..3e639df Binary files /dev/null and b/pkgs/triton/debugger/__pycache__/core.cpython-310.pyc differ diff --git a/pkgs/triton/debugger/__pycache__/debugger.cpython-310.pyc b/pkgs/triton/debugger/__pycache__/debugger.cpython-310.pyc new file mode 100644 index 0000000..161360b Binary files /dev/null and b/pkgs/triton/debugger/__pycache__/debugger.cpython-310.pyc differ diff --git a/pkgs/triton/debugger/__pycache__/memory_map.cpython-310.pyc 
b/pkgs/triton/debugger/__pycache__/memory_map.cpython-310.pyc new file mode 100644 index 0000000..9701f71 Binary files /dev/null and b/pkgs/triton/debugger/__pycache__/memory_map.cpython-310.pyc differ diff --git a/pkgs/triton/debugger/__pycache__/tl_lang.cpython-310.pyc b/pkgs/triton/debugger/__pycache__/tl_lang.cpython-310.pyc new file mode 100644 index 0000000..2077d73 Binary files /dev/null and b/pkgs/triton/debugger/__pycache__/tl_lang.cpython-310.pyc differ diff --git a/pkgs/triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc b/pkgs/triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc new file mode 100644 index 0000000..7e90420 Binary files /dev/null and b/pkgs/triton/debugger/__pycache__/torch_wrapper.cpython-310.pyc differ diff --git a/pkgs/triton/debugger/core.py b/pkgs/triton/debugger/core.py new file mode 100644 index 0000000..82f3f43 --- /dev/null +++ b/pkgs/triton/debugger/core.py @@ -0,0 +1,9 @@ +from typing import Tuple + +import dataclasses + + +@dataclasses.dataclass +class ExecutionContext: + program_id: Tuple[int] + program_size: Tuple[int] diff --git a/pkgs/triton/debugger/debugger.py b/pkgs/triton/debugger/debugger.py new file mode 100644 index 0000000..5c5b972 --- /dev/null +++ b/pkgs/triton/debugger/debugger.py @@ -0,0 +1,170 @@ +import itertools +import random +from typing import Tuple + +import triton +import triton.language as tl +from .core import ExecutionContext +from .memory_map import MemoryMap +from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor, + debugger_constexpr) +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch +tl_method_backup = {} + + +def get_proxy_method(proxy, name): + method = getattr(proxy, name) + + def fun(*args, **kwarg): + return method(*args, **kwarg) + + return fun + + +def attach_triton(module, proxy): + method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"] + for name in method_list: + if hasattr(module, name): + attr = getattr(module, name) + tl_method_backup[name] = attr + if callable(attr): + setattr(module, name, get_proxy_method(proxy, name)) + else: + setattr(module, name, getattr(proxy, name)) + + +def detach_triton(module): + for name, method in tl_method_backup.items(): + setattr(module, name, method) + + +def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]: + # reverse the grid dimensions and generate the range for each dimension + reversed_grid = reversed(grid) + ranges_for_each_dimension = [range(dim) for dim in reversed_grid] + + # gen all combinations + index_combinations = list(itertools.product(*ranges_for_each_dimension)) + random.shuffle(index_combinations) + + for index_combination in index_combinations: + yield index_combination + + +class DebuggerFunction: + def __init__(self, func, grid=(1,)): + self.func = func + self.grid = grid + + def _is_constexpr(self, name): + return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr + + def _get_constexpr(self): + result = [] + for name, annotation in self.func.__annotations__.items(): + if annotation is triton.language.core.constexpr: + result.append(name) + return result + + def _assert_constexpr(self, **kwargs): + constexp = self._get_constexpr() + missing = [i for i in constexp if i not in kwargs.keys()] + assert len(missing) == 0, f"You must specify constexpr {missing}" + + def _get_grid(self, **kwargs): + if callable(self.grid): + return self.grid(kwargs) + else: + return self.grid + + def __call__(self, *args, **kwargs): + 
self._assert_constexpr(**kwargs) + + memory = MemoryMap() + + def convert_arg(v): + name, arg = v + if torch.is_tensor(arg): + ptr = memory.add_tensor(arg) + return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda")) + if self._is_constexpr(name): + return debugger_constexpr(arg) + return WrappedTensor(_primitive_to_tensor(arg)) + + new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args))) + new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]} + + grid = self._get_grid(**kwargs) + for program_id in program_ids_from_grid(grid): + proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid)) + attach_triton(tl, proxy) + self.func(*new_args, **new_kwargs) + detach_triton(tl) + + +class GridSelector: + """ + Entry point of the debugger + """ + + def __init__(self, func): + version = torch.__version__ + assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}" + self.func = func + + def __getitem__(self, grid): + return DebuggerFunction(self.func, grid) + + def __call__(self, *args, **kwargs): + return DebuggerFunction(self.func)(*args, **kwargs) + + +class AutotuneGridSelector: + def __init__(self, func, autotune_params): + self.func = func + self.autotune_params = autotune_params + + def __getitem__(self, grid): + return AutotuneRunner(self.func, self.autotune_params, grid) + + def __call__(self, *args, **kwargs): + return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs) + + +class AutotuneRunner: + def __init__(self, func, autotune_params, grid=None): + self.func = func + self.autotune_params = autotune_params + self.grid = grid + + def __call__(self, *args, **kwargs): + assert len(self.autotune_params["configs"]) >= 1 + + for config in self.autotune_params["configs"][1:]: + + def convert_arg(v): + if torch.is_tensor(v): + return torch.clone(v) + return v + + new_args = tuple(map(convert_arg, args)) + new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} + if self.grid: + self.func[self.grid](*new_args, **new_kwargs, **config.kwargs) + else: + self.func(*new_args, **new_kwargs, **config.kwargs) + + main_config = self.autotune_params["configs"][0] + if self.grid: + self.func[self.grid](*args, **kwargs, **main_config.kwargs) + else: + self.func(*args, **kwargs, **main_config.kwargs) + + +def triton_debug_autotune(**kwars): + def wrapper(func): + return AutotuneGridSelector(func, kwars) + + return wrapper diff --git a/pkgs/triton/debugger/memory_map.py b/pkgs/triton/debugger/memory_map.py new file mode 100644 index 0000000..edf4c3f --- /dev/null +++ b/pkgs/triton/debugger/memory_map.py @@ -0,0 +1,100 @@ +import dataclasses + +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch + + +@dataclasses.dataclass +class RegisteredStorage: + storage: torch.Storage + dtype: torch.dtype + size: int + ptr: int + + @property + def end_ptr(self) -> int: + return self.ptr + self.size + + @property + def access_tensor(self) -> torch.Tensor: + return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device) + + def ensure_immutable(self): + assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size + + +class MemoryMap: + storages: [RegisteredStorage] + + def __init__(self): + self.storages = [] + + def _get_registered_storage(self, pointer: torch.Tensor): + max_pointer = torch.max(pointer).item() + min_pointer = torch.min(pointer).item() + + registered_storage = next( + filter( + lambda registered: 
min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages + ), + None, + ) + if registered_storage is None: + raise Exception("Storage not found or pointers spanning multiple tensors") + registered_storage.ensure_immutable() + return registered_storage + + def add_tensor(self, t: torch.Tensor): + storage = t.untyped_storage() + self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr())) + return t.data_ptr() + + def load( + self, + pointer: torch.Tensor, + mask: torch.Tensor = None, + other=0.0, + ): + assert pointer.is_cuda + assert 0 < pointer.dim() < 3 + assert pointer.dtype == torch.int64 + + if mask is None: + mask = torch.ones_like(pointer).bool() + assert mask.is_cuda + assert 0 < mask.dim() < 3 + assert mask.dtype == torch.bool + mask = mask.expand(pointer.size()) + + if torch.all(~mask): + # Todo: The type is wrong here, we can't determine the correct type + return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda") + + registered_storage = self._get_registered_storage(pointer[mask]) + access_tensor = registered_storage.access_tensor + + index_tensor = pointer - registered_storage.ptr + + block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda") + block[mask] = access_tensor[index_tensor[mask]] + return block + + def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): + assert 0 < pointer.dim() < 3 + assert pointer.dtype == torch.int64 + + if mask is None: + mask = torch.ones_like(pointer).bool() + assert 0 < mask.dim() < 3 + assert mask.dtype == torch.bool + mask = mask.expand(pointer.size()) + + if torch.all(~mask): + return + + registered_storage = self._get_registered_storage(pointer[mask]) + access_tensor = registered_storage.access_tensor + + index_tensor = pointer - registered_storage.ptr + access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype) diff --git a/pkgs/triton/debugger/tl_lang.py b/pkgs/triton/debugger/tl_lang.py new file mode 100644 index 0000000..6364b77 --- /dev/null +++ b/pkgs/triton/debugger/tl_lang.py @@ -0,0 +1,621 @@ +import triton +from .core import ExecutionContext +from .memory_map import MemoryMap +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch + + +def _primitive_to_tensor(x): + """ + Converts various Python primitive data types to PyTorch tensor. 
+ """ + tensor_args = {"device": "cuda"} + if isinstance(x, bool): + return torch.tensor([x], dtype=torch.bool, **tensor_args) + elif isinstance(x, int): + if -(2**31) <= x < 2**31: + return torch.tensor([x], dtype=torch.int32, **tensor_args) + elif -(2**63) <= x < 2**63: + return torch.tensor([x], dtype=torch.int64, **tensor_args) + else: + raise RuntimeError(f"Nonrepresentable integer {x}.") + elif isinstance(x, float): + return torch.tensor([x], dtype=torch.float32, **tensor_args) + elif torch.is_tensor(x): + return x + elif isinstance(x, WrappedTensor): + return x + elif isinstance(x, debugger_constexpr): + if x.value is None: + return None + return _primitive_to_tensor(x.value) + elif x is None: + return None + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +def _infer_tensor(func): + """ + A decorator function to harmonize function args: + - converts primitives to PyTorch tensors + - wraps PyTorch tensors with WrappedTensors + """ + def wrapper(*args): + new_args = tuple(map(lambda v: _primitive_to_tensor(v), args)) + new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args)) + + return func(*new_args) + + return wrapper + + +def _tensor_operation(func): + """ + A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function. + Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor). + """ + def wrapper(*args, **kwargs): + for arg in args: + assert not torch.is_tensor(arg), "unexpected tensor argument" + + def unwrap_tensor(v): + if isinstance(v, WrappedTensor): + return v.tensor + if isinstance(v, debugger_constexpr): + return v.value + return v + + new_args = tuple(map(unwrap_tensor, args)) + new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()} + + result = func(args[0], *new_args[1:], **new_kwargs) + return WrappedTensor(result) if torch.is_tensor(result) else result + + return wrapper + + +class debugger_constexpr: + def __init__(self, value): + if isinstance(value, debugger_constexpr): + self.value = value.value + else: + self.value = value + + def __str__(self) -> str: + return "debugger_constexpr(" + str(self.value) + ")" + + def __index__(self) -> int: + return self.value + + def __bool__(self): + return bool(self.value) + + def __ge__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value >= other + + def __gt__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value > other + + def __le__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value <= other + + def __lt__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value < other + + def __eq__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value == other + + def __or__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value | other + + def __ror__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value | other + + def __and__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value & other + + def __rand__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value & other + + def to(self, dtype, bitcast=False, _builder=None): + if dtype 
in [torch.int64]: + ret_ty = int + elif dtype == torch.bool: + ret_ty = bool + elif dtype in [torch.float64]: + ret_ty = float + else: + raise ValueError("dtype not supported in debugger") + return debugger_constexpr(ret_ty(self.value)) + + +class WrappedTensor: + def __init__(self, tensor): + self.tensor = tensor + + def __index__(self) -> int: + return self.tensor.item() + + def __str__(self) -> str: + return "wrapped_" + str(self.tensor) + + def __bool__(self) -> bool: + return torch.all(self.tensor == True).item() # noqa: E712 + + @property + def dtype(self): + return self.tensor.dtype + + @_infer_tensor + @_tensor_operation + def __add__(self, other): + return torch.add(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __radd__(self, other): + return self.__add__(other) + + @_infer_tensor + @_tensor_operation + def __sub__(self, other): + return torch.sub(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rsub__(self, other): + return torch.sub(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __mul__(self, other): + return torch.mul(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rmul__(self, other): + return self.__mul__(other) + + @_infer_tensor + @_tensor_operation + def __truediv__(self, other): + return torch.div(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rtruediv__(self, other): + return torch.div(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __floordiv__(self, other): + return torch.floor_divide(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rfloordiv__(self, other): + return torch.floor_divide(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __mod__(self, other): + return torch.remainder(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rmod__(self, other): + return torch.remainder(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __neg__(self): + return -self.tensor + + @_infer_tensor + @_tensor_operation + def __invert__(self): + return ~self.tensor + + @_infer_tensor + @_tensor_operation + def __and__(self, other): + return torch.bitwise_and(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __or__(self, other): + return torch.bitwise_or(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __xor__(self, other): + return torch.bitwise_xor(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __lshift__(self, other): + return torch.bitwise_left_shift(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rshift__(self, other): + return torch.bitwise_right_shift(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __gt__(self, other): + return self.tensor > other + + @_infer_tensor + @_tensor_operation + def __rgt__(self, other): + return other > self.tensor + + @_infer_tensor + @_tensor_operation + def __ge__(self, other): + return self.tensor >= other + + @_infer_tensor + @_tensor_operation + def __rge__(self, other): + return other >= self.tensor + + @_infer_tensor + @_tensor_operation + def __lt__(self, other): + return self.tensor < other + + @_infer_tensor + @_tensor_operation + def __rlt__(self, other): + return other < self.tensor + + @_infer_tensor + @_tensor_operation + def __le__(self, other): + return self.tensor <= other + + @_infer_tensor + @_tensor_operation + def __rle__(self, other): + return other <= self.tensor + + @_infer_tensor + @_tensor_operation + def __eq__(self, other): + return torch.equal(self.tensor, 
other) + + @_infer_tensor + @_tensor_operation + def __ne__(self, other): + return not torch.equal(self.tensor, other) + + @_tensor_operation + def __getitem__(self, slices): + return self.tensor.__getitem__(slices) + # if isinstance(slices, slice): + # slices = [slices] + # src_shape = self.shape + # dst_shape = [] + # curr = 0 + # for sl in slices: + # if isinstance(sl, constexpr) and sl.value is None: + # dst_shape.append(1) + # elif sl == slice(None, None, None): + # dst_shape.append(src_shape[curr].value) + # curr += 1 + # ret = torch.reshape(self.tensor, dst_shape, ) + # return ret + + @_tensor_operation + def to(self, dtype, bitcast=False): + return self.tensor.to(dtype) + # if isinstance(bitcast, constexpr): + # bitcast = bitcast.value + # if bitcast: + # return semantic.bitcast(self, dtype, ) + # return semantic.cast(self, dtype, ) + + +def _constexpr_to_value(v): + if isinstance(v, debugger_constexpr): + return v.value + return v + + +class TritonLangProxy: + _memory_map: MemoryMap + _context: ExecutionContext + + def __init__(self, memory_map: MemoryMap, context: ExecutionContext): + self._memory_map = memory_map + self._context = context + + # Types + # Removed void, int1, float8, uint16, uint32, uint64, pi32_t + + # constexpr = debugger_constexpr + + # Program functions + + @_tensor_operation + def load( + self, + pointer: torch.Tensor, + mask: torch.Tensor = None, + other=0.0, + cache_modifier="", + eviction_policy="", + volatile=False, + ): + return self._memory_map.load(pointer, mask, other) + + @_tensor_operation + def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): + return self._memory_map.store(pointer, value, mask) + + @_tensor_operation + def program_id(self, axis): + assert axis < len(self._context.program_id) + return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda") + + @_tensor_operation + def num_programs(self, axis): + assert axis < len(self._context.program_size) + return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda") + + @_tensor_operation + def arange(self, start, end): + return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda") + + @_tensor_operation + def zeros(self, shape, dtype): + for i, d in enumerate(shape): + if not isinstance(d, debugger_constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + shape = [x.value for x in shape] + if isinstance(dtype, triton.language.core.dtype): + if dtype.is_fp32(): + dtype = torch.float32 + elif dtype.is_fp16(): + dtype = torch.float16 + elif dtype.is_bf16(): + dtype = torch.bfloat16 + elif dtype.is_int32(): + dtype = torch.int32 + elif dtype.is_int16(): + dtype = torch.int16 + elif dtype.is_int8(): + dtype = torch.int8 + else: + raise TypeError(f"Unsupported dtype {dtype}") + return torch.zeros(size=shape, dtype=dtype, device="cuda") + + @_tensor_operation + def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16): + raise NotImplementedError() + + @_tensor_operation + def broadcast(self, input, other): + raise NotImplementedError() + + @_tensor_operation + def broadcast_to(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def cat(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def reshape(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def dot(self, input, 
other, trans_a=False, trans_b=False, allow_tf32=True): + assert input.dtype == other.dtype + if trans_a: + input = input.T + if trans_b: + other = other.T + return torch.matmul(input=input, other=other) + + @_tensor_operation + def atomic_cas(self, pointer, cmp, val): + stored = self._memory_map.load(pointer, None, 0.0) + if not isinstance(cmp, torch.Tensor): + cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda") + if not isinstance(val, torch.Tensor): + val = torch.tensor([val], dtype=stored.dtype, device="cuda") + if stored == cmp: + self._memory_map.store(pointer, val, None) + return stored + + @_tensor_operation + def atomic_xchg(self, pointer, val, mask=None): + if isinstance(val, int): + val = torch.tensor([val], dtype=torch.int32, device="cuda") + stored = self._memory_map.load(pointer, mask, 0.0) + self._memory_map.store(pointer, val, mask) + return stored + + @_tensor_operation + def atomic_add(self, pointer, val, mask=None): + # arbitrary other value as it will masked during storing + stored = self._memory_map.load(pointer, mask, 0.0) + result = stored + val + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_max(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0.0) + result = torch.maximum(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_min(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0.0) + result = torch.minimum(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_and(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_and(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_or(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_or(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_xor(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_xor(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def where(self, condition, x, y): + condition = _primitive_to_tensor(condition) + x = _primitive_to_tensor(x) + y = _primitive_to_tensor(y) + return torch.where(condition, x, y) + + @_tensor_operation + def umulhi(self, x, y): + raise NotImplementedError() + + @_tensor_operation + def fdiv(self, x, y, ieee_rounding=False): + raise NotImplementedError() + + @_tensor_operation + def exp(self, x): + return torch.exp(x) + + @_tensor_operation + def log(self, x): + return torch.log(x) + + @_tensor_operation + def cos(self, x): + return torch.cos(x) + + @_tensor_operation + def sin(self, x): + return torch.sin(x) + + @_tensor_operation + def sqrt(self, x): + return torch.sqrt(x) + + @_tensor_operation + def globaltimer(self): + raise NotImplementedError() + + @_tensor_operation + def clock(self): + raise NotImplementedError() + + @_tensor_operation + def debug_barrier(self): + raise NotImplementedError() + + @_tensor_operation + def multiple_of(self, input, values): + return input + + @_tensor_operation + def max_contiguous(self, input, values): + return input + + @_tensor_operation + def abs(self, x): + return torch.abs(x) + + @_tensor_operation + def cdiv(self, x, div): + return (x + div - 1) // div + + 
@_tensor_operation + def minimum(self, x, y): + if isinstance(x, int): + x = torch.tensor(x, device="cuda") + if isinstance(y, int): + y = torch.tensor(y, device="cuda") + return torch.minimum(x, y) + + @_tensor_operation + def maximum(self, x, y): + return torch.maximum(x, y) + + @_tensor_operation + def sigmoid(self, x): + raise NotImplementedError() + + @_tensor_operation + def softmax(self, x, ieee_rounding=False): + raise NotImplementedError() + + @_tensor_operation + def ravel(self, x): + raise NotImplementedError() + + @_tensor_operation + def swizzle2d(self, i, j, size_i, size_j, size_g): + raise NotImplementedError() + + @_tensor_operation + def zeros_like(self, input): + raise NotImplementedError() + + @_tensor_operation + def max(self, input, axis=None): + if axis is None: + return torch.max(input) + return torch.max(input, dim=axis).values + + @_tensor_operation + def argmax(self, input, axis): + raise NotImplementedError() + + @_tensor_operation + def min(self, input, axis=None): + if axis is None: + return torch.min(input) + return torch.min(input, dim=axis).values + + @_tensor_operation + def argmin(self, input, axis): + raise NotImplementedError() + + @_tensor_operation + def sum(self, input, axis=None): + if axis is None: + return torch.sum(input) + return torch.sum(input, dim=axis) + + @_tensor_operation + def xor_sum(self, input, axis): + raise NotImplementedError() diff --git a/pkgs/triton/debugger/torch_wrapper.py b/pkgs/triton/debugger/torch_wrapper.py new file mode 100644 index 0000000..44aa17e --- /dev/null +++ b/pkgs/triton/debugger/torch_wrapper.py @@ -0,0 +1,18 @@ +try: + import torch as _torch +except ImportError: + _torch = None + + +class TorchWrapper: + """ + Helps in making torch an optional dependency + """ + + def __getattr__(self, name): + if _torch is None: + raise ImportError("Triton requires PyTorch to be installed") + return getattr(_torch, name) + + +torch = TorchWrapper() diff --git a/pkgs/triton/language/__init__.py b/pkgs/triton/language/__init__.py new file mode 100644 index 0000000..7485f37 --- /dev/null +++ b/pkgs/triton/language/__init__.py @@ -0,0 +1,201 @@ +"""isort:skip_file""" +# Import order is significant here. + +from . import math +from . 
import extra +from .standard import ( + cdiv, + sigmoid, + softmax, + ravel, + swizzle2d, + zeros, + zeros_like, +) +from .core import ( + abs, + advance, + arange, + argmin, + argmax, + atomic_add, + atomic_and, + atomic_cas, + atomic_max, + atomic_min, + atomic_or, + atomic_xchg, + atomic_xor, + bfloat16, + block_type, + broadcast, + broadcast_to, + cat, + constexpr, + cos, + debug_barrier, + device_assert, + device_print, + dot, + dtype, + exp, + expand_dims, + full, + fdiv, + float16, + float32, + float64, + float8e4, + float8e5, + function_type, + int1, + int16, + int32, + int64, + int8, + load, + log, + make_block_ptr, + max, + max_contiguous, + maximum, + min, + minimum, + multiple_of, + num_programs, + pi32_t, + pointer_type, + program_id, + reduce, + reshape, + sin, + sqrt, + static_assert, + static_print, + store, + sum, + static_range, + tensor, + trans, + triton, + uint16, + uint32, + uint64, + uint8, + umulhi, + view, + void, + where, + xor_sum, +) +from .random import ( + pair_uniform_to_normal, + philox, + philox_impl, + rand, + rand4x, + randint, + randint4x, + randn, + randn4x, + uint32_to_uniform_float, +) + + +__all__ = [ + "abs", + "advance", + "arange", + "argmin", + "argmax", + "atomic_add", + "atomic_and", + "atomic_cas", + "atomic_max", + "atomic_min", + "atomic_or", + "atomic_xchg", + "atomic_xor", + "bfloat16", + "block_type", + "broadcast", + "broadcast_to", + "builtin", + "cat", + "cdiv", + "constexpr", + "cos", + "debug_barrier", + "device_assert", + "device_print", + "dot", + "dtype", + "exp", + "expand_dims", + "extra", + "fdiv", + "float16", + "float32", + "float64", + "float8e4", + "float8e5", + "full", + "function_type", + "int1", + "int16", + "int32", + "int64", + "int8", + "ir", + "math", + "load", + "log", + "make_block_ptr", + "max", + "max_contiguous", + "maximum", + "min", + "minimum", + "multiple_of", + "num_programs", + "pair_uniform_to_normal", + "philox", + "philox_impl", + "pi32_t", + "pointer_type", + "program_id", + "rand", + "rand4x", + "randint", + "randint4x", + "randn", + "randn4x", + "ravel", + "reduce", + "reshape", + "sigmoid", + "sin", + "softmax", + "sqrt", + "static_range", + "static_assert", + "static_print", + "store", + "sum", + "swizzle2d", + "tensor", + "trans", + "triton", + "uint16", + "uint32", + "uint32_to_uniform_float", + "uint64", + "uint8", + "umulhi", + "view", + "void", + "where", + "xor_sum", + "zeros", + "zeros_like", +] diff --git a/pkgs/triton/language/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/language/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6030007 Binary files /dev/null and b/pkgs/triton/language/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/language/__pycache__/core.cpython-310.pyc b/pkgs/triton/language/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000..e2de90e Binary files /dev/null and b/pkgs/triton/language/__pycache__/core.cpython-310.pyc differ diff --git a/pkgs/triton/language/__pycache__/math.cpython-310.pyc b/pkgs/triton/language/__pycache__/math.cpython-310.pyc new file mode 100644 index 0000000..bd7e12d Binary files /dev/null and b/pkgs/triton/language/__pycache__/math.cpython-310.pyc differ diff --git a/pkgs/triton/language/__pycache__/random.cpython-310.pyc b/pkgs/triton/language/__pycache__/random.cpython-310.pyc new file mode 100644 index 0000000..f8c9f32 Binary files /dev/null and b/pkgs/triton/language/__pycache__/random.cpython-310.pyc differ diff --git a/pkgs/triton/language/__pycache__/semantic.cpython-310.pyc 
b/pkgs/triton/language/__pycache__/semantic.cpython-310.pyc new file mode 100644 index 0000000..dbfbc42 Binary files /dev/null and b/pkgs/triton/language/__pycache__/semantic.cpython-310.pyc differ diff --git a/pkgs/triton/language/__pycache__/standard.cpython-310.pyc b/pkgs/triton/language/__pycache__/standard.cpython-310.pyc new file mode 100644 index 0000000..3bc1f6e Binary files /dev/null and b/pkgs/triton/language/__pycache__/standard.cpython-310.pyc differ diff --git a/pkgs/triton/language/core.py b/pkgs/triton/language/core.py new file mode 100644 index 0000000..6720442 --- /dev/null +++ b/pkgs/triton/language/core.py @@ -0,0 +1,1729 @@ +from __future__ import annotations + +from contextlib import contextmanager +from enum import Enum +from functools import wraps +from typing import Callable, List, Sequence, TypeVar + +import triton +from . import semantic +from triton._C.libtriton.triton import ir + +T = TypeVar('T') + +TRITON_MAX_TENSOR_NUMEL = 131072 + +TRITON_BUILTIN = "__triton_builtin__" + + +def builtin(fn: T) -> T: + """Mark a function as a builtin.""" + assert callable(fn) + + @wraps(fn) + def wrapper(*args, **kwargs): + if "_builder" not in kwargs or kwargs["_builder"] is None: + raise ValueError( + "Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)" + ) + return fn(*args, **kwargs) + + setattr(wrapper, TRITON_BUILTIN, True) + + return wrapper + + +def is_builtin(fn) -> bool: + """Is this a registered triton builtin function?""" + return getattr(fn, TRITON_BUILTIN, False) + + +def _to_tensor(x, builder): + if isinstance(x, bool): + return tensor(builder.get_int1(x), int1) + # Note: compile-time const integers are represented by unsigned values + elif isinstance(x, int): + if -2**31 <= x < 2**31: + return tensor(builder.get_int32(x), int32) + elif 2**31 <= x < 2**32: + return tensor(builder.get_int32(x), uint32) + elif -2**63 <= x < 2**63: + return tensor(builder.get_int64(x), int64) + elif 2**63 <= x < 2**64: + return tensor(builder.get_int64(x), uint64) + else: + raise RuntimeError(f'Nonrepresentable integer {x}.') + elif isinstance(x, float): + min_float32 = 2 ** -126 + max_float32 = (2 - 2**-23) * 2**127 + abs_x = __builtins__['abs'](x) + if abs_x == float("inf") or\ + abs_x == 0.0 or \ + x != x or \ + min_float32 <= abs_x <= max_float32: + return tensor(builder.get_fp32(x), float32) + else: + return tensor(builder.get_fp64(x), float64) + + elif isinstance(x, constexpr): + return _to_tensor(x.value, builder) + elif isinstance(x, tensor): + return x + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +class dtype: + SINT_TYPES = ['int8', 'int16', 'int32', 'int64'] + UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64'] + FP_TYPES = ['fp8e4', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64'] + STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64'] + OTHER_TYPES = ['void'] + + class SIGNEDNESS(Enum): + SIGNED = 0 + UNSIGNED = 1 + + def __init__(self, name): + self.name = name + assert name in dtype.SINT_TYPES + dtype.UINT_TYPES + dtype.FP_TYPES + dtype.OTHER_TYPES, name + if name in dtype.SINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.SIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.UINT_TYPES: + self.int_signedness = dtype.SIGNEDNESS.UNSIGNED + self.int_bitwidth = int(name.split('int')[-1]) + self.primitive_bitwidth = self.int_bitwidth + elif name in dtype.FP_TYPES: + if name == 'fp8e4': + self.fp_mantissa_width = 3 + 
self.primitive_bitwidth = 8 + elif name == 'fp8e5': + self.fp_mantissa_width = 2 + self.primitive_bitwidth = 8 + elif name == 'fp16': + self.fp_mantissa_width = 10 + self.primitive_bitwidth = 16 + elif name == 'bf16': + self.fp_mantissa_width = 7 + self.primitive_bitwidth = 16 + elif name == 'fp32': + self.fp_mantissa_width = 23 + self.primitive_bitwidth = 32 + elif name == 'fp64': + self.fp_mantissa_width = 53 + self.primitive_bitwidth = 64 + else: + raise RuntimeError(f'Unsupported floating-point type {name}') + elif name == 'void': + self.primitive_bitwidth = 0 + + def is_fp8(self): + return 'fp8' in self.name + + def is_fp16(self): + return self.name == 'fp16' + + def is_bf16(self): + return self.name == 'bf16' + + def is_fp32(self): + return self.name == 'fp32' + + def is_fp64(self): + return self.name == 'fp64' + + def is_int1(self): + return self.name == 'int1' + + def is_int8(self): + return self.name == 'int8' + + def is_int16(self): + return self.name == 'int16' + + def is_int32(self): + return self.name == 'int32' + + def is_int64(self): + return self.name == 'int64' + + def is_uint8(self): + return self.name == 'uint8' + + def is_uint16(self): + return self.name == 'uint16' + + def is_uint32(self): + return self.name == 'uint32' + + def is_uint64(self): + return self.name == 'uint64' + + def is_floating(self): + return self.name in dtype.FP_TYPES + + def is_standard_floating(self): + return self.name in dtype.STANDARD_FP_TYPES + + def is_int_signed(self): + return self.name in dtype.SINT_TYPES + + def is_int_unsigned(self): + return self.name in dtype.UINT_TYPES + + def is_int(self): + return self.name in dtype.SINT_TYPES + dtype.UINT_TYPES + + def is_bool(self): + return self.is_int1() + + @staticmethod + def is_void(): + raise RuntimeError("Not implemented") + + @staticmethod + def is_block(): + return False + + @staticmethod + def is_ptr(): + return False + + def __eq__(self, other: dtype): + if not isinstance(other, dtype): + return False + return self.name == other.name + + def __ne__(self, other: dtype): + return not self.__eq__(other) + + def __hash__(self): + return hash((self.name,)) + + @property + def scalar(self): + return self + + def to_ir(self, builder: ir.builder) -> ir.type: + if self.name == 'void': + return builder.get_void_ty() + elif self.name == 'int1': + return builder.get_int1_ty() + elif self.name in ('int8', 'uint8'): + return builder.get_int8_ty() + elif self.name in ('int16', 'uint16'): + return builder.get_int16_ty() + elif self.name in ('int32', 'uint32'): + return builder.get_int32_ty() + elif self.name in ('int64', 'uint64'): + return builder.get_int64_ty() + elif self.name == 'fp8e5': + return builder.get_fp8e5_ty() + elif self.name == 'fp8e4': + return builder.get_fp8e4_ty() + elif self.name == 'fp16': + return builder.get_half_ty() + elif self.name == 'bf16': + return builder.get_bf16_ty() + elif self.name == 'fp32': + return builder.get_float_ty() + elif self.name == 'fp64': + return builder.get_double_ty() + raise ValueError(f'fail to convert {self} to ir type') + + def __str__(self): + return self.name + + @property + def cache_key_part(self) -> str: + """See cache_key_part() in triton.cc.""" + return self.name + + def __repr__(self): + return f'triton.language.{self.name}' + + +class pointer_type(dtype): + def __init__(self, element_ty: dtype, address_space: int = 1): + if not isinstance(element_ty, dtype): + raise TypeError('element_ty is a {type(element_ty).__name__}.') + self.element_ty = element_ty + self.address_space = address_space 
+ + self.name = self.__str__() + + def to_ir(self, builder: ir.builder) -> ir.pointer_type: + return builder.get_ptr_ty(self.element_ty.to_ir(builder), 1) + + def __str__(self): + return f'pointer<{self.element_ty}>' + + def __repr__(self): + return self.__str__() + + def is_ptr(self): + return True + + def __eq__(self, other: pointer_type) -> bool: + if not isinstance(other, pointer_type): + return False + return self.element_ty == other.element_ty and self.address_space == other.address_space + + def __ne__(self, other: pointer_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self + + +class block_type(dtype): + def __init__(self, element_ty: dtype, shape: List): + self.element_ty = element_ty + + # Note that block_type's shape is a list of int + # while tensor's shape is a list of constexpr. + + # shape can be empty ([]) when an input is a 0D tensor. + if not shape: + raise TypeError('0d block_type is forbidden') + if isinstance(shape[0], constexpr): + shape = [s.value for s in shape] + + self.shape = shape + self.numel = 1 + for s in self.shape: + self.numel *= s + if self.numel > TRITON_MAX_TENSOR_NUMEL: + raise ValueError(f"numel ({self.numel}) exceeds triton maximum tensor numel ({TRITON_MAX_TENSOR_NUMEL})") + + self.name = self.__str__() + + def to_ir(self, builder: ir.builder) -> ir.block_type: + return builder.get_block_ty(self.element_ty.to_ir(builder), self.shape) + + def __str__(self): + return f'<{self.shape}, {self.element_ty}>' + + def __repr__(self): + return self.__str__() + + def is_block(self): + return True + + def get_block_shapes(self) -> List[int]: + return self.shape + + def __eq__(self, other: block_type) -> bool: + if not isinstance(other, block_type): + return False + return self.element_ty == other.element_ty and self.shape == other.shape + + def __ne__(self, other: block_type) -> bool: + return not self.__eq__(other) + + @property + def scalar(self): + return self.element_ty + + +class function_type(dtype): + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: + self.ret_types = ret_types + self.param_types = param_types + + def __str__(self): + return f'fn ({self.param_types}) -> {self.ret_types}' + + def to_ir(self, builder: ir.builder): + ir_param_types = [ty.to_ir(builder) for ty in self.param_types] + ret_types = [ret_type.to_ir(builder) for ret_type in self.ret_types] + return builder.get_function_ty(ir_param_types, ret_types) + + +# scalar types +void = dtype('void') +int1 = dtype('int1') +int8 = dtype('int8') +int16 = dtype('int16') +int32 = dtype('int32') +int64 = dtype('int64') +uint8 = dtype('uint8') +uint16 = dtype('uint16') +uint32 = dtype('uint32') +uint64 = dtype('uint64') +float8e5 = dtype('fp8e5') +float8e4 = dtype('fp8e4') +float16 = dtype('fp16') +bfloat16 = dtype('bf16') +float32 = dtype('fp32') +float64 = dtype('fp64') +# pointer types +pi32_t = pointer_type(int32) + +# ----------------------- +# constexpr +# ----------------------- + + +class constexpr: + """ + This class is used to store a value that is known at compile-time. 
+ """ + + def __init__(self, value): + if isinstance(value, constexpr): + self.value = value.value + else: + self.value = value + + def __repr__(self) -> str: + return f"constexpr[{self.value}]" + + def __add__(self, other): + return constexpr(self.value + other.value) + + def __radd__(self, other): + return constexpr(other.value + self.value) + + def __sub__(self, other): + return constexpr(self.value - other.value) + + def __rsub__(self, other): + return constexpr(other.value - self.value) + + def __mul__(self, other): + return constexpr(self.value * other.value) + + def __mod__(self, other): + return constexpr(self.value % other.value) + + def __rmul__(self, other): + return constexpr(other.value * self.value) + + def __truediv__(self, other): + return constexpr(self.value / other.value) + + def __rtruediv__(self, other): + return constexpr(other.value / self.value) + + def __floordiv__(self, other): + return constexpr(self.value // other.value) + + def __rfloordiv__(self, other): + return constexpr(other.value // self.value) + + def __gt__(self, other): + return constexpr(self.value > other.value) + + def __rgt__(self, other): + return constexpr(other.value > self.value) + + def __ge__(self, other): + return constexpr(self.value >= other.value) + + def __rge__(self, other): + return constexpr(other.value >= self.value) + + def __lt__(self, other): + return constexpr(self.value < other.value) + + def __rlt__(self, other): + return constexpr(other.value < self.value) + + def __le__(self, other): + return constexpr(self.value <= other.value) + + def __rle__(self, other): + return constexpr(other.value <= self.value) + + def __eq__(self, other): + return constexpr(self.value == other.value) + + def __ne__(self, other): + return constexpr(self.value != other.value) + + def __bool__(self): + return bool(self.value) + + def __neg__(self): + return constexpr(-self.value) + + def __and__(self, other): + return constexpr(self.value & other.value) + + def logical_and(self, other): + return constexpr(self.value and other.value) + + def __or__(self, other): + return constexpr(self.value | other.value) + + def __xor__(self, other): + return constexpr(self.value ^ other.value) + + def logical_or(self, other): + return constexpr(self.value or other.value) + + def __pos__(self): + return constexpr(+self.value) + + def __invert__(self): + return constexpr(~self.value) + + def __pow__(self, other): + return constexpr(self.value ** other.value) + + def __rshift__(self, other): + return constexpr(self.value >> other.value) + + def __lshift__(self, other): + return constexpr(self.value << other.value) + + def __not__(self): + return constexpr(not self.value) + + def __call__(self, *args, **kwds): + return self.value(*args, **kwds) + + +class tensor: + def __init__(self, handle, type: dtype): + # IR handle + self.handle = handle + # Block shape + self.shape = (1, ) + if type.is_block(): + self.shape = type.shape + self.numel = 1 + for s in self.shape: + self.numel *= s + self.numel = constexpr(self.numel) + self.type = type # Tensor type (can be block_type) + # Following the practice in pytorch, dtype is scalar type + self.dtype = type.scalar + self.shape = [constexpr(s) for s in self.shape] + + def __hash__(self): + return hash('language.core.tensor') + + def __str__(self) -> str: + # ex. 
"float32[3,4]" + return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']' + + @builtin + def __add__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.add(self, other, _builder) + + def __radd__(self, other, _builder=None): + return self.__add__(other, _builder=_builder) + + @builtin + def __sub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(self, other, _builder) + + def __rsub__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.sub(other, self, _builder) + + @builtin + def __mul__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mul(self, other, _builder) + + def __rmul__(self, other, _builder=None): + return self.__mul__(other, _builder=_builder) + + @builtin + def __truediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(self, other, _builder) + + def __rtruediv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.truediv(other, self, _builder) + + @builtin + def __floordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(self, other, _builder) + + @builtin + def __rfloordiv__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.floordiv(other, self, _builder) + + @builtin + def __mod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(self, other, _builder) + + @builtin + def __rmod__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.mod(other, self, _builder) + + # unary operators + @builtin + def __neg__(self, _builder=None): + return semantic.minus(self, _builder) + + @builtin + def __invert__(self, _builder=None): + return semantic.invert(self, _builder) + + # bitwise operators + + @builtin + def __and__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(self, other, _builder) + + @builtin + def __rand__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.and_(other, self, _builder) + + @builtin + def __or__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(self, other, _builder) + + @builtin + def __ror__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.or_(other, self, _builder) + + @builtin + def __xor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(self, other, _builder) + + @builtin + def __rxor__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.xor_(other, self, _builder) + + @builtin + def __lshift__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.shl(self, other, _builder) + + @builtin + def __rlshift__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.shl(other, self, _builder) + + @builtin + def __rshift__(self, other, _builder=None): + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(self, other, _builder) + else: + return semantic.lshr(self, other, _builder) + + @builtin + def __rrshift__(self, other, _builder=None): + other = _to_tensor(other, _builder) + if self.dtype.is_int_signed(): + return semantic.ashr(other, self, _builder) + else: + return semantic.lshr(other, self, _builder) + + # comparison operators + + # > + @builtin + def 
__gt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(self, other, _builder) + + @builtin + def __rgt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_than(other, self, _builder) + + # >= + @builtin + def __ge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(self, other, _builder) + + @builtin + def __rge__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.greater_equal(other, self, _builder) + + # < + @builtin + def __lt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(self, other, _builder) + + @builtin + def __rlt__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_than(other, self, _builder) + + # <= + @builtin + def __le__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(self, other, _builder) + + @builtin + def __rle__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.less_equal(other, self, _builder) + + # == + @builtin + def __eq__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.equal(self, other, _builder) + + @builtin + def __ne__(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.not_equal(self, other, _builder) + + @builtin + def logical_and(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_and(self, other, _builder) + + @builtin + def logical_or(self, other, _builder=None): + other = _to_tensor(other, _builder) + return semantic.logical_or(self, other, _builder) + + # note: __not__ isn't actually a magic method in python + # but it's ok because our ASTVisitor handles it + @builtin + def __not__(self, _builder=None): + return semantic.not_(self, _builder) + + @builtin + def __getitem__(self, slices, _builder=None): + if isinstance(slices, slice): + slices = [slices] + ret = self + for dim, sl in enumerate(slices): + if isinstance(sl, constexpr) and sl.value is None: + ret = semantic.expand_dims(ret, dim, _builder) + elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: + pass + else: + assert False, f"unsupported tensor index: {sl}" + return ret + + @property + def T(self): + assert False, "Transposition must be created by the AST Visitor" + + @builtin + def to(self, dtype, bitcast=False, _builder=None): + if isinstance(bitcast, constexpr): + bitcast = bitcast.value + if bitcast: + return semantic.bitcast(self, dtype, _builder) + return semantic.cast(self, dtype, _builder) + + +# ----------------------- +# SPMD Programming Model +# ----------------------- +def _constexpr_to_value(v): + if isinstance(v, constexpr): + return v.value + return v + + +@builtin +def program_id(axis, _builder=None): + """ + Returns the id of the current program instance along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. 
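As a hedged illustration of how the launch grid and program_id fit together, a minimal vector-add kernel is sketched below; add_kernel and BLOCK are placeholder names and a CUDA device is assumed.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
        # Each program instance handles one BLOCK-sized slice, selected by program_id(0).
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)

    x = torch.randn(4096, device="cuda")
    y = torch.randn(4096, device="cuda")
    out = torch.empty_like(x)
    # 1D launch grid: one program per BLOCK elements.
    add_kernel[(triton.cdiv(x.numel(), 1024),)](x, y, out, x.numel(), BLOCK=1024)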
+ :type axis: int + """ + # if axis == -1: + # pid0 = program_id(0, _builder) + # pid1 = program_id(1, _builder) + # pid2 = program_id(2, _builder) + # npg0 = num_programs(0, _builder) + # npg1 = num_programs(0, _builder) + # return pid0 + pid1*npg0 + pid2*npg0*npg1 + axis = _constexpr_to_value(axis) + return semantic.program_id(axis, _builder) + + +@builtin +def num_programs(axis, _builder=None): + """ + Returns the number of program instances launched along the given :code:`axis`. + + :param axis: The axis of the 3D launch grid. Has to be either 0, 1 or 2. + :type axis: int + """ + axis = _constexpr_to_value(axis) + return semantic.num_programs(axis, _builder) + + +# ----------------------- +# Block Initialization +# ----------------------- + + +@builtin +def arange(start, end, _builder=None): + """ + Returns contiguous values within the left-closed and right-open interval [:code:`start`, :code:`end`). \ + End - Start must be less than or equal to TRITON_MAX_TENSOR_NUMEL = 131072 + + :param start: Start of the interval. Must be a power of two. + :type start: int32 + :param end: End of the interval. Must be a power of two > start. + :type end: int32 + """ + start = _constexpr_to_value(start) + end = _constexpr_to_value(end) + return semantic.arange(start, end, _builder) + + +def _shape_check_impl(shape): + shape = _constexpr_to_value(shape) + for i, d in enumerate(shape): + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + return [_constexpr_to_value(x) for x in shape] + + +@builtin +def full(shape, value, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :value value: A scalar value to fill the array with + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + shape = _shape_check_impl(shape) + value = _constexpr_to_value(value) + dtype = _constexpr_to_value(dtype) + return semantic.full(shape, value, dtype, _builder) + + +@builtin +def ones(shape, dtype, _builder=None): + """ + Returns a tensor filled with the scalar value 1 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + for i, d in enumerate(shape): + if not isinstance(d, constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + shape = [x.value for x in shape] + dtype = _constexpr_to_value(dtype) + return semantic.ones(shape, dtype, _builder) + + +# ----------------------- +# Shape Manipulation +# ----------------------- + + +@builtin +def broadcast(input, other, _builder=None): + """ + Tries to broadcast the two given blocks to a common compatible shape. + + :param input: The first input tensor. + :type input: Block + :param other: The second input tensor. + :type other: Block + """ + return semantic.broadcast_impl_value(input, other, _builder) + + +@builtin +def broadcast_to(input, shape, _builder=None): + """ + Tries to broadcast the given tensor to a new :code:`shape`. 
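A small sketch of the common broadcasting idiom these shape helpers support: indexing with None inserts length-1 dimensions (lowered to expand_dims), and arithmetic then broadcasts the operands to a 2D tile. The kernel name tile_offsets_kernel and the block sizes are illustrative.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def tile_offsets_kernel(out_ptr, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)      # shape [BLOCK_M]
        offs_n = tl.arange(0, BLOCK_N)      # shape [BLOCK_N]
        # offs_m[:, None] has shape [BLOCK_M, 1]; offs_n[None, :] has shape [1, BLOCK_N];
        # the addition broadcasts both operands to a [BLOCK_M, BLOCK_N] index tile.
        offs = offs_m[:, None] * stride_m + offs_n[None, :]
        tl.store(out_ptr + offs, offs)

    out = torch.empty((16, 32), device="cuda", dtype=torch.int32)
    tile_offsets_kernel[(1,)](out, out.stride(0), BLOCK_M=16, BLOCK_N=32)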
+ + :param input: The input tensor. + :type input: Block + :param shape: The desired shape. + :type shape: Tuple[int] + """ + shape = _shape_check_impl(shape) + return semantic.broadcast_impl_shape(input, shape, _builder) + + +@builtin +def trans(input, _builder=None): + return semantic.trans(input, _builder) + + +@builtin +def cat(input, other, can_reorder=False, _builder=None): + """ + Concatenate the given blocks + + :param input: The first input tensor. + :type input: + :param other: The second input tensor. + :type other: + :param reorder: Compiler hint. If true, the compiler is + allowed to reorder elements while concatenating inputs. + Only use if the order does not matter (e.g., result is + only used in reduction ops) + """ + return semantic.cat(input, other, can_reorder, _builder) + + +@builtin +def view(input, shape, _builder=None): + """ + Returns a tensor with the same elements as `input` but a different shape. + The order of the elements may not be preserved. + + :param input: The input tensor. + :type input: + :param shape: The desired shape. + :type shape: Tuple[int] + + """ + shape = _shape_check_impl(shape) + return semantic.view(input, shape, _builder) + + +@builtin +def reshape(input, shape, _builder=None): + shape = _shape_check_impl(shape) + return semantic.reshape(input, shape, _builder) + + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. + :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims recieved duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + +# ----------------------- +# Linear Algebra +# ----------------------- + + +@builtin +def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None): + """ + Returns the matrix product of two blocks. + + The two blocks must be two-dimensional and have compatible inner dimensions. + + :param input: The first tensor to be multiplied. + :type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + :param other: The second tensor to be multiplied. 
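For context, a hedged sketch of the usual tl.dot pattern: fp16 operand tiles accumulated into an fp32 tile. The name matmul_tile_kernel and its arguments are illustrative and assume row-major contiguous inputs whose sizes are multiples of the block sizes.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def matmul_tile_kernel(a_ptr, b_ptr, c_ptr, K, stride_am, stride_bk, stride_cm,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
        offs_m = tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)   # accumulate in fp32
        for k in range(0, K, BLOCK_K):
            offs_ak = k + offs_k
            a = tl.load(a_ptr + offs_m[:, None] * stride_am + offs_ak[None, :])
            b = tl.load(b_ptr + offs_ak[:, None] * stride_bk + offs_n[None, :])
            acc += tl.dot(a, b)   # inner dimensions (BLOCK_K) must match
        tl.store(c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :], acc)

    a = torch.randn(64, 64, device="cuda", dtype=torch.float16)
    b = torch.randn(64, 64, device="cuda", dtype=torch.float16)
    c = torch.empty(64, 64, device="cuda", dtype=torch.float32)
    matmul_tile_kernel[(1,)](a, b, c, 64, a.stride(0), b.stride(0), c.stride(0),
                             BLOCK_M=64, BLOCK_N=64, BLOCK_K=32)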
+ :type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} + """ + allow_tf32 = _constexpr_to_value(allow_tf32) + out_dtype = _constexpr_to_value(out_dtype) + return semantic.dot(input, other, allow_tf32, out_dtype, _builder) + + +# ----------------------- +# Non-Atomic Memory Operations +# ----------------------- + + +@builtin +def load(pointer, mask=None, other=None, boundary_check=tuple(), padding_option="", cache_modifier="", + eviction_policy="", volatile=False, _builder=None): + """ + Return a tensor of data whose values are loaded from memory at location defined by `pointer`: + (1) `pointer` could be a single element pointer, then a scalar will be loaded + - `mask` and `other` must be scalar too + - `other` is implicitly typecast to `pointer.dtype.element_ty` + - `boundary_check` and `padding_option` must be empty + (2) `pointer` could be element-wise tensor of pointers, in which case: + - `mask` and `other` are implicitly broadcast to `pointer.shape` + - `other` is implicitly typecast to `pointer.dtype.element_ty` + - `boundary_check` and `padding_option` must be empty + (3) `pointer` could be a block pointer defined by `make_block_ptr`, in which case: + - `mask` and `other` must be None + - `boundary_check` and `padding_option` can be specified to control the behavior of out-of-bound access + + :param pointer: Pointer to the data to be loaded + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param mask: if `mask[idx]` is false, do not load the data at address `pointer[idx]` + (must be `None` with block pointers) + :type mask: Block of `triton.int1`, optional + :param other: if `mask[idx]` is false, return `other[idx]` + :type other: Block, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param padding_option: should be one of {"", "zero", "nan"}, do padding while out of bound + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + :param volatile: changes volatile option in NVIDIA PTX + :type volatile: bool, optional + """ + # `mask` and `other` can be constexpr + if _constexpr_to_value(mask) is not None: + mask = _to_tensor(mask, _builder) + if _constexpr_to_value(other) is not None: + other = _to_tensor(other, _builder) + padding_option = _constexpr_to_value(padding_option) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + volatile = _constexpr_to_value(volatile) + return semantic.load(pointer, mask, other, boundary_check, padding_option, cache_modifier, eviction_policy, + volatile, _builder) + + +@builtin +def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", eviction_policy="", _builder=None): + """ + Store a tensor of data into memory locations defined by `pointer`: + (1) `pointer` could be a single element pointer, then a scalar will be stored + - `mask` must be scalar too + - `boundary_check` and `padding_option` must be empty + (2) `pointer` could be element-wise tensor of pointers, in which case: + - `mask` is implicitly broadcast to `pointer.shape` + - `boundary_check` must be empty + (3) or `pointer` could be a block pointer defined by `make_block_ptr`, in which case: + - `mask` must be None + - `boundary_check` can be specified to control the behavior of 
out-of-bound access + `value` is implicitly broadcast to `pointer.shape` and typecast to `pointer.dtype.element_ty`. + + :param pointer: The memory location where the elements of `value` are stored + :type pointer: `triton.PointerType`, or block of `dtype=triton.PointerType` + :param value: The tensor of elements to be stored + :type value: Block + :param mask: If `mask[idx]` is false, do not store `value[idx]` at `pointer[idx]` + :type mask: Block of triton.int1, optional + :param boundary_check: tuple of integers, indicating the dimensions which should do the boundary check + :type boundary_check: tuple of ints, optional + :param cache_modifier: changes cache option in NVIDIA PTX + :type cache_modifier: str, optional + :param eviction_policy: changes eviction policy in NVIDIA PTX + :type eviction_policy: str, optional + """ + # `value` can be constexpr + value = _to_tensor(value, _builder) + if _constexpr_to_value(mask) is not None: + mask = _to_tensor(mask, _builder) + cache_modifier = _constexpr_to_value(cache_modifier) + eviction_policy = _constexpr_to_value(eviction_policy) + return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder) + + +@builtin +def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None): + """ + Returns a pointer to a block in a parent tensor + + :param base: The base pointer to the parent tensor + :param shape: The shape of the parent tensor + :param strides: The strides of the parent tensor + :param offsets: The offsets to the block + :param block_shape: The shape of the block + :param order: The order of the original data format + """ + return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder) + + +@builtin +def advance(base: tensor, offsets, _builder=None): + """ + Advance a block pointer + + :param base: the block pointer to advance + :param offsets: the offsets to advance, a tuple by dimension + """ + return semantic.advance(base, offsets, _builder) + +# ----------------------- +# Atomic Memory Operations +# ----------------------- + + +def _add_atomic_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Performs an atomic {name} at the memory location specified by :code:`pointer`. + + Return the data stored at :code:`pointer` before the atomic operation. + + :param pointer: The memory locations to compare-and-swap. + :type pointer: Block of dtype=triton.PointerDType + :param cmp: The values expected to be found in the atomic object + :type cmp: Block of dtype=`pointer.dtype.element_ty` + :param val: The values to copy in case the expected value matches the contained value. 
+ :type val: Block of dtype=`pointer.dtype.element_ty` + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@builtin +@_add_atomic_docstr("compare-and-swap") +def atomic_cas(pointer, cmp, val, _builder=None): + cmp = _to_tensor(cmp, _builder) + val = _to_tensor(val, _builder) + return semantic.atomic_cas(pointer, cmp, val, _builder) + + +@builtin +@_add_atomic_docstr("exchange") +def atomic_xchg(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_xchg(pointer, val, mask, _builder) + + +@builtin +@_add_atomic_docstr("add") +def atomic_add(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_add(pointer, val, mask, _builder) + + +@builtin +@_add_atomic_docstr("max") +def atomic_max(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_max(pointer, val, mask, _builder) + + +@builtin +@_add_atomic_docstr("min") +def atomic_min(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_min(pointer, val, mask, _builder) + + +@builtin +@_add_atomic_docstr("logical and") +def atomic_and(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_and(pointer, val, mask, _builder) + + +@builtin +@_add_atomic_docstr("logical or") +def atomic_or(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_or(pointer, val, mask, _builder) + + +@builtin +@_add_atomic_docstr("logical xor") +def atomic_xor(pointer, val, mask=None, _builder=None): + val = _to_tensor(val, _builder) + return semantic.atomic_xor(pointer, val, mask, _builder) + + +# ----------------------- +# Conditioning +# ----------------------- + +@builtin +def where(condition, x, y, _builder=None): + """ + Returns a tensor of elements from either :code:`x` or :code:`y`, depending on :code:`condition`. + + Note that :code:`x` and :code:`y` are always evaluated regardless of the value of :code:`condition`. + + If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead. + + The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`. + :code:`x` and :code:`y` must have the same data type. + + :param condition: When True (nonzero), yield x, otherwise yield y. + :type condition: Block of triton.bool + :param x: values selected at indices where condition is True. + :param y: values selected at indices where condition is False. + """ + condition = _to_tensor(condition, _builder) + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.where(condition, x, y, _builder) + + +# ----------------------- +# Math +# ----------------------- + +@builtin +def umulhi(x, y, _builder=None): + x = _to_tensor(x, _builder) + y = _to_tensor(y, _builder) + return semantic.umulhi(x, y, _builder) + + +@builtin +def fdiv(x, y, ieee_rounding=False, _builder=None): + ieee_rounding = _constexpr_to_value(ieee_rounding) + return semantic.fdiv(x, y, ieee_rounding, _builder) + + +def _add_math_1arg_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Computes the element-wise {name} of :code:`x`. 
+ + :param x: the input values + :type x: Block + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@builtin +@_add_math_1arg_docstr("exponential") +def exp(x, _builder=None): + return semantic.exp(x, _builder) + + +@builtin +@_add_math_1arg_docstr("natural logarithm") +def log(x, _builder=None): + return semantic.log(x, _builder) + + +@builtin +@_add_math_1arg_docstr("cosine") +def cos(x, _builder=None): + return semantic.cos(x, _builder) + + +@builtin +@_add_math_1arg_docstr("sine") +def sin(x, _builder=None): + return semantic.sin(x, _builder) + + +@builtin +@_add_math_1arg_docstr("square root") +def sqrt(x, _builder=None): + return semantic.sqrt(x, _builder) + + +@builtin +@_add_math_1arg_docstr("absolute value") +def abs(x, _builder=None): + return semantic.abs(x, _builder) + + +# ----------------------- +# Reductions +# ----------------------- + +def _add_reduction_docstr(name: str) -> Callable[[T], T]: + + def _decorator(func: T) -> T: + docstr = """ + Returns the {name} of all elements in the :code:`input` tensor along the provided :code:`axis` + + :param input: the input values + :param axis: the dimension along which the reduction should be done + """ + func.__doc__ = docstr.format(name=name) + return func + + return _decorator + + +@contextmanager +def _insertion_guard(builder): + ip = builder.get_insertion_point() + yield + builder.restore_insertion_point(ip) + + +@builtin +def reduce(input, axis, combine_fn, _builder=None, _generator=None): + """Applies the combine_fn to all elements in :code:`input` tensors along the provided :code:`axis` + + :param input: the input tensor, or tuple of tensors + :param axis: the dimension along which the reduction should be done + :param combine_fn: a function to combine two groups of scalar tensors (must be marked with @triton.jit) + + """ + if isinstance(input, tensor): + return reduce((input,), axis, combine_fn, + _builder=_builder, _generator=_generator)[0] + + def make_combine_region(reduce_op): + in_scalar_tys = [t.type.scalar for t in input] + prototype = function_type(in_scalar_tys, in_scalar_tys * 2) + + region = reduce_op.get_region(0) + with _insertion_guard(_builder): + param_types = [ty.to_ir(_builder) for ty in prototype.param_types] + block = _builder.create_block_with_parent(region, param_types) + args = [tensor(block.arg(i), ty) + for i, ty in enumerate(prototype.param_types)] + results = _generator.call_JitFunction(combine_fn, args, kwargs={}) + if isinstance(results, tensor): + handles = [results.handle] + else: + handles = [r.handle for r in results] + _builder.create_reduce_ret(*handles) + if axis is not None: + axis = _constexpr_to_value(axis) + return semantic.reduction(input, axis, make_combine_region, _builder) + + +@builtin +def _promote_reduction_input(t, _builder=None): + scalar_ty = t.type.scalar + # input is extended to 32-bits if necessary + # this increases numerical accuracy and can be done pretty much for free + # on GPUs + if scalar_ty.is_int() and scalar_ty.int_bitwidth < 32: + return t.to(int32, _builder=_builder) + + # hardware doesn't support FMAX, FMIN, CMP for bfloat16 + if scalar_ty is bfloat16: + return t.to(float32, _builder=_builder) + + return t + + +@builtin +def _argreduce(input, axis, combine_fn, _builder=None, _generator=None): + axis = _constexpr_to_value(axis) + n = input.shape[axis] + index = arange(0, n, _builder=_builder) + + if len(input.shape) > 1: + # Broadcast index across the non-reduced axes + axes_to_expand = [constexpr(d) for d in 
range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) + index = broadcast_to(index, input.shape, _builder=_builder) + + rvalue, rindices = reduce((input, index), axis, combine_fn, + _builder=_builder, _generator=_generator) + return rindices + + +@triton.jit +def minimum(x, y): + """ + Computes the element-wise minimum of :code:`x` and :code:`y`. + + :param input: the first input tensor + :type input: Block + :param other: the second input tensor + :type other: Block + """ + return where(x < y, x, y) + + +@triton.jit +def maximum(x, y): + """ + Computes the element-wise maximum of :code:`x` and :code:`y`. + + :param input: the first input tensor + :type input: Block + :param other: the second input tensor + :type other: Block + """ + return where(x > y, x, y) + + +@triton.jit +def _max_combine(a, b): + return maximum(a, b) + + +@triton.jit +@_add_reduction_docstr("maximum") +def max(input, axis=None): + input = _promote_reduction_input(input) + return reduce(input, axis, _max_combine) + + +@triton.jit +def _argmax_combine(value1, index1, value2, index2): + gt = value1 > value2 + lt = value1 < value2 + index_min = minimum(index1, index2) + index_ret = where(gt, index1, where(lt, index2, index_min)) + value_ret = maximum(value1, value2) + return value_ret, index_ret + + +@triton.jit +@_add_reduction_docstr("maximum index") +def argmax(input, axis): + input = _promote_reduction_input(input) + return _argreduce(input, axis, _argmax_combine) + + +@triton.jit +def _min_combine(a, b): + # TODO: minimum/maximum doesn't get lowered to fmin/fmax... + return minimum(a, b) + + +@triton.jit +@_add_reduction_docstr("minimum") +def min(input, axis=None): + input = _promote_reduction_input(input) + return reduce(input, axis, _min_combine) + + +@triton.jit +def _argmin_combine(value1, index1, value2, index2): + lt = value1 < value2 + gt = value1 > value2 + index_min = minimum(index1, index2) + index_ret = where(lt, index1, where(gt, index2, index_min)) + value_ret = minimum(value1, value2) + return value_ret, index_ret + + +@triton.jit +@_add_reduction_docstr("minimum index") +def argmin(input, axis): + input = _promote_reduction_input(input) + return _argreduce(input, axis, _argmin_combine) + + +@triton.jit +def _sum_combine(a, b): + return a + b + + +@triton.jit +@_add_reduction_docstr("sum") +def sum(input, axis=None): + input = _promote_reduction_input(input) + return reduce(input, axis, _sum_combine) + + +@triton.jit +def _xor_combine(a, b): + return a ^ b + + +@builtin +@_add_reduction_docstr("xor sum") +def xor_sum(input, axis=None, _builder=None, _generator=None): + scalar_ty = input.type.scalar + if not scalar_ty.is_int(): + raise ValueError("xor_sum only supported for integers") + + input = _promote_reduction_input(input, _builder=_builder) + return reduce(input, axis, _xor_combine, + _builder=_builder, _generator=_generator) + + +# ----------------------- +# Compiler Hint Ops +# ----------------------- + + +@builtin +def debug_barrier(_builder=None): + ''' + Insert a barrier to synchronize all threads in a block. + ''' + return semantic.debug_barrier(_builder) + + +@builtin +def multiple_of(input, values, _builder=None): + """ + Let the compiler knows that the values in :code:`input` are all multiples of :code:`value`. 
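A short sketch of how these compiler hints are commonly applied to block-aligned offset vectors, following the pattern used in the upstream matmul tutorial; copy_aligned_kernel is an illustrative name and the hints are assumptions the caller is responsible for keeping true.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_aligned_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
        pid = tl.program_id(0)
        offs = pid * BLOCK + tl.arange(0, BLOCK)
        # Promise the compiler that the offsets are BLOCK-aligned and that BLOCK
        # consecutive values are contiguous, which enables wider vectorized accesses.
        offs = tl.max_contiguous(tl.multiple_of(offs, BLOCK), BLOCK)
        mask = offs < n_elements
        tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

    src = torch.arange(4096, device="cuda", dtype=torch.float32)
    dst = torch.empty_like(src)
    copy_aligned_kernel[(triton.cdiv(src.numel(), 512),)](src, dst, src.numel(), BLOCK=512)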
+ """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.multiple_of(input, values) + + +@builtin +def max_contiguous(input, values, _builder=None): + """ + Let the compiler knows that the `value` first values in :code:`input` are contiguous. + """ + if isinstance(values, constexpr): + values = [values] + for i, d in enumerate(values): + if not isinstance(d, constexpr): + raise TypeError(f"values element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") + values = [x.value for x in values] + return semantic.max_contiguous(input, values) + +# ----------------------- +# Debugging functions +# ----------------------- + + +@builtin +def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None): + ''' + Print the values at compile time. The parameters are the same as the builtin :code:`print`. + ''' + pass + + +@builtin +def static_assert(cond, msg="", _builder=None): + ''' + Assert the condition at compile time. The parameters are the same as the builtin :code:`assert`. + ''' + pass + + +@builtin +def device_print(prefix, *args, _builder=None): + ''' + Print the values at runtime from the device. + + :param prefix: a prefix to print before the values. This is required to be a string literal. + :param args: the values to print. They can be any tensor or scalar. + ''' + import string + prefix = _constexpr_to_value(prefix) + assert isinstance(prefix, str), f"{prefix} is not string" + b_ascii = True + for ch in prefix: + if ch not in string.printable: + b_ascii = False + break + assert b_ascii, f"{prefix} is not an ascii string" + new_args = [] + for arg in args: + new_args.append(_to_tensor(arg, _builder)) + return semantic.device_print(prefix, new_args, _builder) + + +@builtin +def device_assert(cond, msg="", _builder=None): + ''' + Assert the condition at runtime from the device. + + :param cond: the condition to assert. This is required to be a boolean tensor. + :param msg: the message to print if the assertion fails. This is required to be a string literal. + ''' + msg = _constexpr_to_value(msg) + import inspect + frame = inspect.currentframe() + module = inspect.getmodule(frame) + # The triton function module doesn't have the name attribute. + # We use this trick to find the caller. + while hasattr(module, "__name__"): + frame = frame.f_back + module = inspect.getmodule(frame) + lineno = 0 + func_name = 'unknown' + file_name = 'unknown' + if frame is not None and frame.f_back is not None: + func_name = frame.f_code.co_name + file_name = frame.f_back.f_code.co_filename + # TODO: The line number currently indicates the line + # where the triton function is called but not where the + # device_assert is called. Need to enhance this. + lineno = frame.f_back.f_lineno + return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder) + + +# ----------------------- +# Iterators +# ----------------------- + + +class static_range: + + """ + Iterator that counts upward forever. + + .. highlight:: python + .. 
code-block:: python + + @triton.jit + def kernel(...): + for i in tl.static_range(10): + ... + :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of + :code:`triton.jit` functions. In addition, it also guides the compiler to unroll the loop aggressively. + :param arg1: the start value. + :param arg2: the end value. + :param step: the step value. + """ + + def __init__(self, arg1, arg2=None, step=None): + assert isinstance(arg1, constexpr) + if step is None: + self.step = constexpr(1) + else: + assert isinstance(step, constexpr) + self.step = step + if arg2 is None: + self.start = constexpr(0) + self.end = arg1 + else: + assert isinstance(arg2, constexpr) + self.start = arg1 + self.end = arg2 + + def __iter__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + def __next__(self): + raise RuntimeError("static_range can only be used in @triton.jit'd functions") + + +# ----------------------- +# Extern functions +# ----------------------- + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _builder: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." 
+ f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_type = arg_type_symbol_dict[arg_types][1] + if ret_shape: + ret_type = block_type(ret_type, ret_shape) + return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) + + +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None): + ''' + Dispatch an elementwise function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _builder: the builder + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + ret_shape = None + arg_types = [] + for i in range(len(dispatch_args)): + dispatch_args[i] = _to_tensor(dispatch_args[i], _builder) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if len(arg_types) > 0: + arg_types = tuple(arg_types) + arithmetic_check = True + # If there's a type tuple that is not supported by the library, we will do arithmetic check + if arg_types in arg_type_symbol_dict: + arithmetic_check = False + broadcast_arg = dispatch_args[0] + # Get the broadcast shape over all the arguments + for i, item in enumerate(dispatch_args): + _, broadcast_arg = semantic.binary_op_type_checking_impl( + item, broadcast_arg, _builder, arithmetic_check=arithmetic_check) + # Change the shape of each argument based on the broadcast shape + for i in range(len(dispatch_args)): + dispatch_args[i], _ = semantic.binary_op_type_checking_impl( + dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check) + if not all_scalar: + ret_shape = broadcast_arg.shape + func = getattr(_builder, "create_extern_elementwise") + return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder) + + +def extern(fn): + """A decorator for external functions.""" + return builtin(fn) diff --git a/pkgs/triton/language/extra/__init__.py b/pkgs/triton/language/extra/__init__.py new file mode 100644 index 0000000..2fd0ff3 --- /dev/null +++ b/pkgs/triton/language/extra/__init__.py @@ -0,0 +1,3 @@ +from . import cuda + +__all__ = ['cuda'] diff --git a/pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..682a743 Binary files /dev/null and b/pkgs/triton/language/extra/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc b/pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc new file mode 100644 index 0000000..d572724 Binary files /dev/null and b/pkgs/triton/language/extra/__pycache__/cuda.cpython-310.pyc differ diff --git a/pkgs/triton/language/extra/cuda.bc b/pkgs/triton/language/extra/cuda.bc new file mode 100644 index 0000000..4538ac3 Binary files /dev/null and b/pkgs/triton/language/extra/cuda.bc differ diff --git a/pkgs/triton/language/extra/cuda.py b/pkgs/triton/language/extra/cuda.py new file mode 100644 index 0000000..92df37a --- /dev/null +++ b/pkgs/triton/language/extra/cuda.py @@ -0,0 +1,19 @@ +import os + +from .. 
import core + +__path__ = os.path.dirname(os.path.abspath(__file__)) + + +@core.extern +def globaltimer(_builder=None): + return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [], + {tuple(): ("globaltimer", core.dtype("int64")), + }, is_pure=False, _builder=_builder) + + +@core.extern +def smid(_builder=None): + return core.extern_elementwise("cuda", os.path.join(__path__, "cuda.bc"), [], + {tuple(): ("smid", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/pkgs/triton/language/math.py b/pkgs/triton/language/math.py new file mode 100644 index 0000000..dd3c65d --- /dev/null +++ b/pkgs/triton/language/math.py @@ -0,0 +1,1537 @@ +import functools +import os + +from . import core +from triton.common.build import cuda_home_dirs, is_corex + + +@functools.lru_cache() +def libdevice_path(): + import torch + third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party") + if is_corex(): + default = cuda_home_dirs() + "/nvvm/libdevice/libdevice.compute_bi.10.bc" + # default = cuda_home_dirs() + "/../../ixsdk/nvvm/libdevice/libdevice.compute_bi.10.bc" + elif torch.version.hip is None: + default = os.path.join(third_party_dir, "cuda", "lib", "libdevice.10.bc") + else: + default = os.path.join(third_party_dir, "rocm", "lib", "bitcode", "cuda2gcn.bc") + return os.getenv("TRITON_LIBDEVICE_PATH", default) + +@core.extern +def clz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def popc(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def byte_perm(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def min(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_min", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umin", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmin", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmin", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fminf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def max(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_max", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umax", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_llmax", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_ullmax", core.dtype("uint64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaxf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmax", 
core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mulhi(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")), + (core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")), + (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul24(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def brev(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sad(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def abs(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")), + (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def floor(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp64h(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ceil(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def trunc(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp2(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): 
("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def saturatef(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rn(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rz(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_rd(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma_ru(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_dividef(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def div_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), + }, 
is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcp_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sqrt(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", 
core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def add_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def mul_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2float_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2int_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, 
_builder=_builder) + + +@core.extern +def double2int_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2uint_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2double_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2double_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2int_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2uint_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + 
{(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int2float_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint2float_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hiloint2double(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2loint(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2hiint(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ll_ru(arg0, 
_builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float2ull_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ll_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double2ull_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, 
_builder=_builder) + + +@core.extern +def ll2float_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2float_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2float_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ll2double_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rz(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_rd(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ull2double_ru(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def int_as_float(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): 
("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_int(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def uint_as_float(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def float_as_uint(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def longlong_as_double(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def double_as_longlong(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_sinf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_cosf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log2f(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_logf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_expf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_tanf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_exp10f(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_log10f(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fast_powf(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hadd(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", 
libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhadd(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rz(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_rd(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sub_ru(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rsqrt_rn(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ffs(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rint(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llrint(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nearbyint(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isnan(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), + 
(core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def signbit(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def copysign(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def finitef(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isinf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def nextafter(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sin(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cos(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinpi(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cospi(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tan(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log2(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp", 
core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def exp10(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cosh(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def sinh(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tanh(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan2(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atan(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def asin(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acos(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log10(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def log1p(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def acosh(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def 
asinh(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def atanh(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def expm1(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def hypot(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rhypot(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm3d(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def norm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, arg3, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cbrt(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def rcbrt(arg0, _builder=None): + return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j0(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def j1(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y0(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def y1(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def yn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def jn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i0(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def cyl_bessel_i1(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfinv(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfc(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcx(arg0, _builder=None): + return 
core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def erfcinv(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdfinv(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def normcdf(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def lgamma(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ldexp(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def scalbn(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fmod(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def remainder(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fma(arg0, arg1, arg2, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], + {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def pow(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), + 
(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def tgamma(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def round(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def llround(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def fdim(arg0, arg1, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], + {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def ilogb(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def logb(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) + + +@core.extern +def isfinited(arg0, _builder=None): + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], + {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/pkgs/triton/language/random.py b/pkgs/triton/language/random.py new file mode 100644 index 0000000..a9ddbd8 --- /dev/null +++ b/pkgs/triton/language/random.py @@ -0,0 +1,178 @@ +import triton +from . import core as tl + +PHILOX_KEY_A: tl.constexpr = 0x9E3779B9 +PHILOX_KEY_B: tl.constexpr = 0xBB67AE85 +PHILOX_ROUND_A: tl.constexpr = 0xD2511F53 +PHILOX_ROUND_B: tl.constexpr = 0xCD9E8D57 +N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox + +# ------------------- +# randint +# ------------------- + + +@triton.jit +def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1). 
+ """ + for _ in tl.static_range(n_rounds): + # for _ in range(n_rounds): + # update random state + A = PHILOX_ROUND_A + B = PHILOX_ROUND_B + _c0, _c2 = c0, c2 + c0 = tl.umulhi(B, _c2) ^ c1 ^ k0 + c2 = tl.umulhi(A, _c0) ^ c3 ^ k1 + c1 = B * _c2 + c3 = A * _c0 + # raise key + k0 = k0 + PHILOX_KEY_A + k1 = k1 + PHILOX_KEY_B + return c0, c1, c2, c3 + + +@triton.jit +def philox(seed, c0, c1, c2, c3, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + seed = seed.to(tl.uint64) + seed_hi = ((seed >> 32) & 0xffffffff).to(tl.uint32) + seed_lo = (seed & 0xffffffff).to(tl.uint32) + c0 = c0.to(tl.uint32, bitcast=True) + c1 = c1.to(tl.uint32, bitcast=True) + c2 = c2.to(tl.uint32, bitcast=True) + c3 = c3.to(tl.uint32, bitcast=True) + return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds) + + +@triton.jit +def randint(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns a single + block of random :code:`int32`. + + If you need multiple streams of random numbers, + using `randint4x` is likely to be faster than calling `randint` 4 times. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + ret, _, _, _ = randint4x(seed, offset, n_rounds) + return ret + + +@triton.jit +def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, returns four + blocks of random :code:`int32`. + + This is the maximally efficient entry point + to Triton's Philox pseudo-random number generator. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + # _0 = tl.zeros(offset.shape, offset.dtype) + _0 = offset * 0 + return philox(seed, offset, _0, _0, _0, n_rounds) + + +# ------------------- +# rand +# ------------------- + +# @triton.jit +# def uint32_to_uniform_float(x): +# """ +# Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). +# """ +# two_to_the_minus_32: tl.constexpr = 2.328306e-10 +# return x * two_to_the_minus_32 + +@triton.jit +def uint32_to_uniform_float(x): + """ + Numerically stable function to convert a random uint32 into a random float uniformly sampled in [0, 1). + """ + x = x.to(tl.int32, bitcast=True) + # maximum value such that `MAX_INT * scale < 1.0` (with float rounding) + scale = 4.6566127342e-10 + x = tl.where(x < 0, -x - 1, x) + return x * scale + + +@triton.jit +def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + offset = offset.to(tl.uint32, bitcast=True) + source = randint(seed, offset, n_rounds) + return uint32_to_uniform_float(source) + + +@triton.jit +def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offsets` block, + returns a 4 blocks of random :code:`float32` in :math:`U(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. 
+ """ + offsets = offsets.to(tl.uint32, bitcast=True) + i1, i2, i3, i4 = randint4x(seed, offsets, n_rounds) + u1 = uint32_to_uniform_float(i1) + u2 = uint32_to_uniform_float(i2) + u3 = uint32_to_uniform_float(i3) + u4 = uint32_to_uniform_float(i4) + return u1, u2, u3, u4 + +# ------------------- +# randn +# ------------------- + + +@triton.jit +def pair_uniform_to_normal(u1, u2): + """Box-Muller transform""" + u1 = tl.maximum(1.0e-7, u1) + th = 6.283185307179586 * u2 + r = tl.sqrt(-2.0 * tl.log(u1)) + return r * tl.cos(th), r * tl.sin(th) + + +@triton.jit +def randn(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a block of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + i1, i2, _, _ = randint4x(seed, offset, n_rounds) + u1 = uint32_to_uniform_float(i1) + u2 = uint32_to_uniform_float(i2) + n1, _ = pair_uniform_to_normal(u1, u2) + return n1 + + +@triton.jit +def randn4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): + """ + Given a :code:`seed` scalar and an :code:`offset` block, + returns a 4 blocks of random :code:`float32` in :math:`\\mathcal{N}(0, 1)`. + + :param seed: The seed for generating random numbers. + :param offsets: The offsets to generate random numbers for. + """ + u1, u2, u3, u4 = rand4x(seed, offset, n_rounds) + n1, n2 = pair_uniform_to_normal(u1, u2) + n3, n4 = pair_uniform_to_normal(u3, u4) + return n1, n2, n3, n4 diff --git a/pkgs/triton/language/semantic.py b/pkgs/triton/language/semantic.py new file mode 100644 index 0000000..43f0a5e --- /dev/null +++ b/pkgs/triton/language/semantic.py @@ -0,0 +1,1491 @@ +from __future__ import annotations # remove after python 3.11 + +from functools import wraps +from typing import List, Optional, Sequence, Tuple, TypeVar + +from . import core as tl +from triton._C.libtriton.triton import ir + +import triton._C.libtriton.triton as _triton +import re + +T = TypeVar('T') + +# Create custom exception that prints message "hello" + + +class IncompatibleTypeErrorImpl(Exception): + def __init__(self, type_a, type_b): + self.type_a = type_a + self.type_b = type_b + self.message = "invalid operands of type " + self.type_a.__repr__() + " and " + self.type_b.__repr__() + super(IncompatibleTypeErrorImpl, self).__init__(self.message) + + +# ===----------------------------------------------------------------------===## +# Programming Model +# ===----------------------------------------------------------------------===## + +def program_id(axis: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_get_program_id(axis), tl.int32) + + +def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + +# ===----------------------------------------------------------------------===// +# Implicit Casting Utilities +# ===----------------------------------------------------------------------===// + + +def integer_promote_impl(a_ty: tl.dtype, b_ty: tl.dtype) -> tl.dtype: + a_rank = a_ty.int_bitwidth + b_rank = b_ty.int_bitwidth + a_sn = a_ty.int_signedness + b_sn = b_ty.int_signedness + # Rules for signedness taken from "Usual arithmetic conversions" on + # https://en.cppreference.com/w/c/language/conversion. 
+ if a_sn == b_sn: + return a_ty if a_rank > b_rank else b_ty + elif a_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return a_ty if a_rank >= b_rank else b_ty + elif b_sn == tl.dtype.SIGNEDNESS.UNSIGNED: + return b_ty if b_rank >= a_rank else a_ty + assert False + + +def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> tl.dtype: + # 1) if one operand is double, the other is implicitly + # converted to double + if a_ty.is_fp64() or b_ty.is_fp64(): + return tl.float64 + # 2) if one operand is float, the other is implicitly + # converted to float + if a_ty.is_fp32() or b_ty.is_fp32(): + return tl.float32 + # 3 ) if one operand is half, the other is implicitly converted to half + # unless we're doing / or %, which do not exist natively in PTX for fp16. + # Supported PTX op: add, sub, mul, fma, neg, abs, min, max, tanh, ex2, setp + if a_ty.is_fp16() or b_ty.is_fp16(): + if div_or_mod: + return tl.float32 + else: + return tl.float16 + # 4) return bf16 only if both operands are of bf16 + if a_ty.is_bf16() or b_ty.is_bf16(): + if div_or_mod: + return tl.float32 + if a_ty.is_bf16() and b_ty.is_bf16(): + return tl.bfloat16 + return tl.float32 + if not a_ty.is_int() or not b_ty.is_int(): + assert False + # 5 ) both operands are integer and undergo + # integer promotion + if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: + raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + return integer_promote_impl(a_ty, b_ty) + +# ===----------------------------------------------------------------------===// +# Binary Operators +# ===----------------------------------------------------------------------===// + + +def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) -> None: + if type_a.is_ptr(): + if not allow_ptr_a: + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + U* with T != U + if type_b.is_ptr() and (type_a != type_b): + raise IncompatibleTypeErrorImpl(type_a, type_b) + # T* + float + if type_b.is_floating(): + raise IncompatibleTypeErrorImpl(type_a, type_b) + + +def binary_op_type_checking_impl(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder, + allow_lhs_ptr=False, allow_rhs_ptr=False, + arithmetic_check=True, div_or_mod=False + ) -> Tuple[tl.tensor, tl.tensor]: + # implicit broadcasting + lhs, rhs = broadcast_impl_value(lhs, rhs, builder) + # implicit typecasting + lhs_sca_ty = lhs.type.scalar + rhs_sca_ty = rhs.type.scalar + check_ptr_type_impl(lhs_sca_ty, rhs_sca_ty, allow_lhs_ptr) + check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr) + if arithmetic_check and not lhs_sca_ty.is_ptr() and not rhs_sca_ty.is_ptr(): + ret_sca_ty = computation_type_impl(lhs_sca_ty, rhs_sca_ty, div_or_mod) + lhs = cast(lhs, ret_sca_ty, builder) + rhs = cast(rhs, ret_sca_ty, builder) + return lhs, rhs + + +def add(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + + # offset + ptr + # ptr + offset + if other_scalar_ty.is_ptr() and not input_scalar_ty.is_ptr(): + input, other = other, input + if input_scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, other.handle), input.type) + # float + float + elif input_scalar_ty.is_floating(): + return 
tl.tensor(builder.create_fadd(input.handle, other.handle), input.type) + # int + int + elif input_scalar_ty.is_int(): + return tl.tensor(builder.create_add(input.handle, other.handle), input.type) + assert False + + +def sub(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, True, False) + scalar_ty = input.type.scalar + # ptr - offset + if scalar_ty.is_ptr(): + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), + input.type) + # float - float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) + # int - int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_sub(input.handle, other.handle), input.type) + assert False + + +def mul(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float * float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fmul(input.handle, other.handle), input.type) + # * int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_mul(input.handle, other.handle), input.type) + assert False + + +def truediv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float / int + if input_scalar_ty.is_floating() and other_scalar_ty.is_int(): + other = cast(other, input_scalar_ty, builder) + # int / float + elif input_scalar_ty.is_int() and other_scalar_ty.is_floating(): + input = cast(input, other_scalar_ty, builder) + # int / int (cast to tl.float32) + elif input_scalar_ty.is_int() and other_scalar_ty.is_int(): + input = cast(input, tl.float32, builder) + other = cast(other, tl.float32, builder) + # float / float (cast to highest exponent type) + elif input_scalar_ty.is_floating() and other_scalar_ty.is_floating(): + if input_scalar_ty.fp_mantissa_width > other_scalar_ty.fp_mantissa_width: + other = cast(other, input_scalar_ty, builder) + else: + input = cast(input, other_scalar_ty, builder) + # unreachable + else: + assert False + return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) + + +def floordiv(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if input_scalar_ty.is_int() and other_scalar_ty.is_int(): + ret_ty = integer_promote_impl(input_scalar_ty, other_scalar_ty) + input = cast(input, ret_ty, builder) + other = cast(other, ret_ty, builder) + if ret_ty.is_int_signed(): + return tl.tensor(builder.create_sdiv(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_udiv(input.handle, other.handle), input.type) + assert False + + +def fdiv(input: tl.tensor, + other: tl.tensor, + ieee_rounding: bool, + builder: ir.builder) -> tl.tensor: + input_scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): + raise ValueError("both operands of fdiv must have floating scalar type") + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False, True) + ret = builder.create_fdiv(input.handle, 
other.handle) + return tl.tensor(ret, input.type) + + +def mod(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) + scalar_ty = input.type.scalar + other_scalar_ty = other.type.scalar + # float % float + if scalar_ty.is_floating(): + # input - input.div(other, rounding_mode="floor") * other + ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), + other, builder), + builder) + return ret + # % int + elif scalar_ty.is_int(): + if scalar_ty.int_signedness != other_scalar_ty.int_signedness: + raise ValueError("Cannot mod " + scalar_ty.__repr__() + " by " + other_scalar_ty.__repr__() + " " + "because they have different signedness;" + "this is unlikely to result in a useful answer. Cast them to the same signedness.") + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_srem(input.handle, other.handle), input.type) + else: + return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) + assert False + +############## +# bitwise ops +############## + + +def bitwise_op_type_checking_impl(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: + input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) + input_sca_ty = input.type.scalar + other_sca_ty = other.type.scalar + if not input_sca_ty.is_int() or not other_sca_ty.is_int(): + raise IncompatibleTypeErrorImpl(input_sca_ty, other_sca_ty) + ret_sca_ty = integer_promote_impl(input_sca_ty, other_sca_ty) + if ret_sca_ty != input_sca_ty: + input = cast(input, ret_sca_ty, builder) + if ret_sca_ty != other_sca_ty: + other = cast(other, ret_sca_ty, builder) + return input, other + + +def and_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_and(input.handle, other.handle), input.type) + + +def or_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_or(input.handle, other.handle), input.type) + + +def xor_(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) + + +def logical_and(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return and_(input, other, builder) + + +def logical_or(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + if not other.type.is_int1(): + other = bitcast(other, tl.dtype("int1"), builder) + return or_(input, other, builder) + + +def not_(input: tl.tensor, builder: ir.builder): + if not input.type.is_int1(): + input = bitcast(input, tl.dtype("int1"), builder) + return invert(input, builder) + + +def lshr(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) + + +def ashr(input: tl.tensor, + other: tl.tensor, + builder: 
ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) + + +def shl(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = bitwise_op_type_checking_impl(input, other, builder) + return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + +# ===----------------------------------------------------------------------===// +# Unary Operators +# ===----------------------------------------------------------------------===// + + +def plus(input: tl.tensor) -> tl.tensor: + return input + + +def minus(input: tl.tensor, + builder: ir.builder) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr(): + raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") + _0 = tl.tensor(builder.get_null_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return sub(_0, input, builder) + + +def invert(input: tl.tensor, + builder: tl.tensor) -> tl.tensor: + input_sca_ty = input.type.scalar + if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): + raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") + _1 = tl.tensor(builder.get_all_ones_value(input_sca_ty.to_ir(builder)), input_sca_ty) + return xor_(input, _1, builder) + + +# ===----------------------------------------------------------------------===// +# Comparison Operators +# ===----------------------------------------------------------------------===// +def _bool_like(v: tl.tensor) -> tl.block_type: + if not v.type.is_block(): + return tl.int1 + shape = v.type.shape + return tl.block_type(tl.int1, shape) + + +def greater_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float > float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGT(input.handle, other.handle), _bool_like(input)) + # > int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGT(input.handle, other.handle), _bool_like(input)) + assert False + + +def greater_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float >= float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOGE(input.handle, other.handle), _bool_like(input)) + # >= int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSGE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpUGE(input.handle, other.handle), _bool_like(input)) + assert False + + +def less_than(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLT(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLT(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULT(input.handle, other.handle), _bool_like(input)) + assert False + + 
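+# Illustrative note (added annotation, not part of the vendored Triton code):
+# the comparison helpers in this section all follow one dispatch pattern --
+# floating-point operands lower to ordered fcmp predicates (OGT/OGE/OLT/OLE,
+# OEQ for ==, UNE for !=), signed integers to signed icmp predicates
+# (SGT/SGE/SLT/SLE), and unsigned integers to the unsigned variants
+# (UGT/UGE/ULT/ULE). A kernel-level expression such as `mask = offs < n` is
+# what ends up routed through `less_than` above; a hypothetical sketch (names
+# are made up, requires a CUDA-capable device):
+#
+#     @triton.jit
+#     def relu_kernel(x_ptr, n, BLOCK: tl.constexpr):
+#         offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
+#         mask = offs < n                       # -> semantic.less_than
+#         x = tl.load(x_ptr + offs, mask=mask)
+#         tl.store(x_ptr + offs, tl.where(x < 0.0, 0.0, x), mask=mask)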
+def less_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float < float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOLE(input.handle, other.handle), _bool_like(input)) + # < int + elif scalar_ty.is_int(): + if scalar_ty.is_int_signed(): + return tl.tensor(builder.create_icmpSLE(input.handle, other.handle), _bool_like(input)) + else: + return tl.tensor(builder.create_icmpULE(input.handle, other.handle), _bool_like(input)) + assert False + + +def equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpOEQ(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpEQ(input.handle, other.handle), _bool_like(input)) + assert False + + +def not_equal(input: tl.tensor, + other: tl.tensor, + builder: ir.builder) -> tl.tensor: + input, other = binary_op_type_checking_impl(input, other, builder) + scalar_ty = input.type.scalar + # float == float + if scalar_ty.is_floating(): + return tl.tensor(builder.create_fcmpUNE(input.handle, other.handle), _bool_like(input)) + # == int + elif scalar_ty.is_int(): + return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) + assert False + +# ===----------------------------------------------------------------------===// +# Block Creation +# ===----------------------------------------------------------------------===// + + +def arange(start: int, end: int, builder: ir.builder) -> tl.tensor: + if not isinstance(start, int) or not isinstance(end, int): + raise ValueError("arange's arguments must be of type tl.constexpr") + is_start_int64 = bool(start >> 32) + is_end_int64 = bool(end >> 32) + if is_start_int64 or is_end_int64: + raise ValueError("arange must fit in int32") + if end <= start: + raise ValueError("arange's end argument must be greater than the start argument") + + shape = [end - start] + ret_ty = tl.block_type(tl.int32, shape) + return tl.tensor(builder.create_make_range(start, end), ret_ty) + + +def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + if isinstance(value, tl.tensor): + assert value.numel.value == 1, "only accepts size-1 tensor" + value = cast(value, dtype, builder) + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + else: + # scalar + if value == 0: + value = builder.get_null_value(dtype.to_ir(builder)) + else: + get_value_fn = getattr(builder, f"get_{dtype.name}") + value = get_value_fn(value) + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") + ret_ty = tl.block_type(dtype, shape) + return tl.tensor(builder.create_splat(value, shape), ret_ty) + + + +def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + _1 = builder.get_one_value(dtype.to_ir(builder)) + ret_ty = tl.block_type(dtype, shape) + return tl.tensor(builder.create_splat(_1, shape), ret_ty) + +# ===----------------------------------------------------------------------===// +# Shape Manipulation +# ===----------------------------------------------------------------------===// + + +def view(input: tl.tensor, + dst_shape: List[int], + builder: ir.builder) -> 
tl.tensor: + # TODO: disable when TritonToTritonGPU handles views properly + + # assert len(input.shape) == len(dst_shape) + numel = 1 + for s in dst_shape: + numel *= s + if input.type.numel != numel: + raise ValueError("cannot view block of different shape") + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty) + + +def reshape(input: tl.tensor, + dst_shape: List[int], + builder: ir.builder) -> tl.tensor: + raise ValueError("`reshape` is not supported yet. Please use `view` instead if applicable. " + "Note that view may reorder elements in an implementation- and context- dependent way.") + + +def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: + dst_shape = list(input.type.shape) + dst_shape.insert(axis, 1) + ret_ty = tl.block_type(input.type.scalar, dst_shape) + return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) + + +def cat(lhs: tl.tensor, rhs: tl.tensor, can_reorder: bool, builder: ir.builder) -> tl.tensor: + assert can_reorder, "current implementation of `cat` always may reorder elements" + assert len(lhs.shape) == 1 + ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) + return tl.tensor(builder.create_cat(lhs.handle, rhs.handle), ret_type) + + +def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor: + if len(input.shape) != 2: + raise ValueError("Only 2D tensors can be transposed") + ret_type = tl.block_type(input.type.scalar, [input.shape[1], input.shape[0]]) + return tl.tensor(builder.create_trans(input.handle), ret_type) + + +def broadcast_impl_shape(input: tl.tensor, + shape: List[int], + builder: ir.builder) -> tl.tensor: + if not input.type.is_block(): + ret_ty = tl.block_type(input.type, shape) + return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) + src_shape = input.type.get_block_shapes() + if len(src_shape) != len(shape): + raise ValueError(f"Cannot broadcast, rank mismatch: {src_shape}, {shape}") + if shape == src_shape: + return input + for i, item in enumerate(src_shape): + if shape[i] != item and item != 1: + raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})" + f" must match the existing size ({item}) at non-singleton dimension" + f" {i}: {src_shape}, {shape}") + ret_ty = tl.block_type(input.type.scalar, shape) + return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) + + +def broadcast_impl_value(lhs: tl.tensor, + rhs: tl.tensor, + builder: ir.builder) -> tl.tensor: + lhs_ty = lhs.type + rhs_ty = rhs.type + + # make_shape_compatible(block, scalar) + if lhs_ty.is_block() and not rhs_ty.is_block(): + rhs_ty = tl.block_type(rhs_ty.scalar, lhs_ty.shape) + rhs = tl.tensor(builder.create_splat(rhs.handle, lhs_ty.get_block_shapes()), rhs_ty) + # make_shape_compatible(scalar, block) + elif not lhs_ty.is_block() and rhs_ty.is_block(): + lhs_ty = tl.block_type(lhs_ty.scalar, rhs_ty.shape) + lhs = tl.tensor(builder.create_splat(lhs.handle, rhs_ty.get_block_shapes()), lhs_ty) + # make_shape_compatible(block, block) + elif lhs_ty.is_block() and rhs_ty.is_block(): + lhs_shape = lhs_ty.get_block_shapes() + rhs_shape = rhs_ty.get_block_shapes() + + if len(lhs_shape) < len(rhs_shape): + # Add new axes to lhs + for dim in range(len(lhs_shape), len(rhs_shape)): + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs_ty = lhs.type + lhs_shape = lhs_ty.get_block_shapes() + elif len(rhs_shape) < len(lhs_shape): + # Add 
new axes to rhs + for dim in range(len(rhs_shape), len(lhs_shape)): + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs_ty = rhs.type + rhs_shape = rhs_ty.get_block_shapes() + assert len(rhs_shape) == len(lhs_shape) + + ret_shape = [] + for i, left in enumerate(lhs_shape): + right = rhs_shape[i] + if left == 1: + ret_shape.append(right) + elif right == 1: + ret_shape.append(left) + elif left == right: + ret_shape.append(left) + else: + raise ValueError("Cannot make_shape_compatible: incompatible dimensions " + "at index " + str(i) + ": " + str(left) + " and " + str(right)) + if lhs_shape != ret_shape: + ret_ty = tl.block_type(lhs_ty.scalar, ret_shape) + lhs = tl.tensor(builder.create_broadcast(lhs.handle, ret_shape), ret_ty) + if rhs_shape != ret_shape: + ret_ty = tl.block_type(rhs_ty.scalar, ret_shape) + rhs = tl.tensor(builder.create_broadcast(rhs.handle, ret_shape), ret_ty) + # (scalar, scalar) => returns original blocks + return lhs, rhs + +####### +# cast +####### + + +def bitcast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + if src_sca_ty.is_ptr() or dst_sca_ty.is_ptr(): + return cast(input, dst_ty, builder) + # Bitcast + src_bits = src_sca_ty.primitive_bitwidth + dst_bits = dst_sca_ty.primitive_bitwidth + if src_bits != dst_bits: + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " + "data-type of size " + str(dst_bits)) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), + dst_ty) + + +def cast(input: tl.tensor, + dst_ty: tl.dtype, + builder: ir.builder) -> tl.tensor: + src_ty = input.type + if isinstance(dst_ty, tl.constexpr): + dst_ty = dst_ty.value + if src_ty.is_block(): + dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) + if src_ty == dst_ty: + return input + + src_sca_ty = src_ty.scalar + dst_sca_ty = dst_ty.scalar + + # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 + if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ + (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()): + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), + dst_ty) + + # bf16 <=> (not fp32) + if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ + (src_sca_ty.is_bf16() and not dst_sca_ty.is_fp32()): + return cast(cast(input, tl.float32, builder), dst_sca_ty, builder) + + # Standard floating types' casting: truncation + # fp64 => fp32, fp16, bf16 + # fp32 => fp16, bf16 + truncate_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth + if truncate_fp: + return tl.tensor(builder.create_fp_trunc(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Standard floating types' casting: extension + # fp32 => fp64 + # fp16 => fp32, fp64 + # bf16 => fp32, fp64 + ext_fp = src_sca_ty.is_floating() and \ + dst_sca_ty.is_floating() and \ + src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth + if ext_fp: + return tl.tensor(builder.create_fp_ext(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Casting between integer types + if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ + (src_sca_ty.int_bitwidth != dst_sca_ty.int_bitwidth or src_sca_ty.int_signedness != 
dst_sca_ty.int_signedness): + sign_extend = src_sca_ty.is_int_signed() and not src_sca_ty.is_bool() + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + else: + return tl.tensor(builder.create_int_cast(input.handle, + dst_ty.to_ir(builder), sign_extend), + dst_ty) + + # Casting standard floating types to integer types + if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): + if dst_sca_ty.is_bool(): + ty = input.dtype.to_ir(builder) + _0 = tl.tensor(builder.get_null_value(ty), input.dtype) + return not_equal(input, _0, builder) + elif dst_sca_ty.is_int_signed(): + return tl.tensor(builder.create_fp_to_si(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_fp_to_ui(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Casting integer types to standard floating types + if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): + if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): + return tl.tensor(builder.create_ui_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + else: + return tl.tensor(builder.create_si_to_fp(input.handle, + dst_ty.to_ir(builder)), + dst_ty) + + # Casting pointer types to integer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): + bitwidth = dst_sca_ty.int_bitwidth + if bitwidth == 64: + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), + dst_ty) + if bitwidth == 1: + return not_equal(cast(input, tl.int64, builder), + tl.tensor(builder.get_int64(0), tl.int64), + builder) + + # Casting integer types to pointer types + if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_int_to_ptr(input.handle, dst_ty.to_ir(builder)), dst_ty) + + # Casting pointer types to pointer types + if src_sca_ty.is_ptr() and dst_sca_ty.is_ptr(): + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) + + assert False, f'cannot cast {input} to {dst_ty}' + +# ===----------------------------------------------------------------------===// +# Memory Operators +# ===----------------------------------------------------------------------===// + + +def _str_to_cache_modifier(cache_modifier): + cache = ir.CACHE_MODIFIER.NONE # default + if cache_modifier: + if cache_modifier == ".ca": + cache = ir.CACHE_MODIFIER.CA + elif cache_modifier == ".cg": + cache = ir.CACHE_MODIFIER.CG + else: + raise ValueError(f"Cache modifier {cache_modifier} not supported") + return cache + + +def _str_to_eviction_policy(eviction_policy): + eviction = ir.EVICTION_POLICY.NORMAL # default + if eviction_policy: + if eviction_policy == "evict_last": + eviction = ir.EVICTION_POLICY.EVICT_LAST + elif eviction_policy == "evict_first": + eviction = ir.EVICTION_POLICY.EVICT_FIRST + else: + raise ValueError(f"Eviction policy {eviction_policy} not supported") + return eviction + + +def _str_to_padding_option(padding_option): + padding = None # default + if padding_option: + if padding_option == "zero": + padding = ir.PADDING_OPTION.PAD_ZERO + elif padding_option == "nan": + padding = ir.PADDING_OPTION.PAD_NAN + else: + raise ValueError(f"Padding option {padding_option} not supported") + return padding + + +def _canonicalize_boundary_check(boundary_check, block_shape): + if boundary_check: + if not hasattr(boundary_check, "__iter__"): + boundary_check = [boundary_check] + boundary_check = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in boundary_check] + 
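+ # After unwrapping constexpr values, validate that each entry is an in-range
+ # integer dimension, that the list is non-empty and free of duplicates, and
+ # return it sorted; an empty `boundary_check` canonicalizes to an empty tuple.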
for dim in boundary_check: + assert isinstance(dim, int) and 0 <= dim < len(block_shape) + assert len(boundary_check) > 0 + assert len(boundary_check) == len(set(boundary_check)), "Duplicate dimension in `boundary_check`" + return sorted(boundary_check) + return tuple() + + +def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a block pointer: `pointer_type>` + # Block pointer can not have `mask` and `other` arguments + if mask or other: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + if elt_ty.is_int() and padding == ir.PADDING_OPTION.PAD_NAN: + raise ValueError("Padding option `nan` is not supported for integer block pointers") + + # `dst_ty` is de-referenced type of the pointer type + dst_ty = ptr.type.element_ty + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, + is_volatile), dst_ty) + + +def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.load`") + + # Check `mask`, `other`, `boundary_check`, and `padding` arguments + if not mask and other: + raise ValueError("`other` cannot be provided without `mask`") + if padding or boundary_check: + raise ValueError("`padding_option` or `boundary_check` argument is not supported for loading a tensor of" + "pointers or loading a scalar. 
Because the compiler does not know the boundary; please " + "use block pointers (defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `mask` and `other` + if not ptr.type.is_block(): + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + if other and other.type.is_block(): + raise ValueError("Other argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `other` into the same shape as `ptr` + if ptr.type.is_block(): + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if other: + other = broadcast_impl_shape(other, ptr.type.get_block_shapes(), builder) + + # Get `pointer_type` and `elt_ty` + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast `other` into `ele_ty` type + if other: + other = cast(other, elt_ty, builder) + + # Create loaded result type `dst_ty` + if ptr.type.is_block(): + shape = ptr.type.get_block_shapes() + dst_ty = tl.block_type(elt_ty, shape) + else: + # Load by de-referencing the pointer of scalar + dst_ty = elt_ty + + # Build IR + if not mask: + return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) + else: + return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, + eviction, is_volatile), dst_ty) + + +def load(ptr: tl.tensor, + mask: Optional[tl.tensor], + other: Optional[tl.tensor], + boundary_check, + padding_option: str, + cache_modifier: str, + eviction_policy: str, + is_volatile: bool, + builder: ir.builder) -> tl.tensor: + # Cache, eviction and padding options + cache = _str_to_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + padding = _str_to_padding_option(padding_option) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Load by a block pointer: `pointer_type>` + return _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + else: + # Load by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder) + + +def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a block pointer: `pointer_type>` + # Block pointers can not have the `mask` argument + if mask: + raise ValueError("`mask` and `other` arguments cannot be specified for loading block pointers") + + # Check same shape and element type + block_shape = ptr.type.element_ty.get_block_shapes() + if not val.type.is_block(): + val = broadcast_impl_shape(val, block_shape, builder) + assert val.type.is_block(), "Value argument must be block type or a scalar" + assert block_shape == val.type.get_block_shapes(), "Block shape and value shape mismatch" + assert ptr.type.element_ty.element_ty == val.type.element_ty, "Block element type and value element type mismatch" + + elt_ty = ptr.type.element_ty.element_ty + assert elt_ty != tl.int1, "`tl.int1` should be rewrited in `tl.make_block_ptr`" + + # Check `boundary_check` argument + boundary_check = _canonicalize_boundary_check(boundary_check, block_shape) + + # Build IR + return tl.tensor(builder.create_tensor_pointer_store(ptr.handle, val.handle, boundary_check, 
cache, eviction), + tl.void) + + +def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + if not ptr.type.scalar.is_ptr(): + raise ValueError(f"Unsupported ptr type {ptr.type.__repr__()} in `tl.store`") + + # Check `boundary_check` argument + if boundary_check: + raise ValueError("`boundary_check` argument is not supported for storing a tensor of pointers or storing a " + "scalar. Because the compiler does not know the boundary; please use block pointers " + "(defined by `make_block_ptr`) instead") + + # For a pointer of scalar, check the type of `val` and `mask` + if not ptr.type.is_block(): + if val.type.is_block(): + raise ValueError("Value argument cannot be block type if pointer argument is not a block") + if mask and mask.type.is_block(): + raise ValueError("Mask argument cannot be block type if pointer argument is not a block") + + # Make `mask` and `val` into the same shape as `ptr` + if ptr.type.is_block(): + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + if mask: + mask = broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + + ptr_ty = ptr.type.scalar + elt_ty = ptr_ty.element_ty + + # Treat `pointer_type` as `pointer_type` + if elt_ty == tl.int1: + elt_ty = tl.int8 + ptr_ty = tl.pointer_type(elt_ty, ptr_ty.address_space) + ptr = cast(ptr, ptr_ty, builder) + + # Cast to target data type + val = cast(val, elt_ty, builder) + + # Build IR + if not mask: + return tl.tensor(builder.create_store(ptr.handle, val.handle, cache, eviction), tl.void) + if not mask.type.scalar.is_bool(): + raise ValueError("Mask must have boolean scalar type") + return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) + + +def store(ptr: tl.tensor, + val: tl.tensor, + mask: Optional[tl.tensor], + boundary_check, + cache_modifier: str, + eviction_policy: str, + builder: ir.builder) -> tl.tensor: + # Cache and eviction options + cache = _str_to_cache_modifier(cache_modifier) + eviction = _str_to_eviction_policy(eviction_policy) + + if ptr.type.is_ptr() and ptr.type.element_ty.is_block(): + # Store by a block pointer: `pointer_type>` + return _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builder) + else: + # Store by a tensor of pointers or a pointer of scalar: `block_type>` or `pointer_type<>` + return _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder) + + +######### +# atomic +######### + + +def atomic_cas(ptr: tl.tensor, + cmp: tl.tensor, + val: tl.tensor, + builder: ir.builder) -> tl.tensor: + element_ty = ptr.type.scalar.element_ty + if element_ty.primitive_bitwidth not in [16, 32, 64]: + raise ValueError("atomic_cas only supports elements with width {16, 32, 64}") + return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle), val.type) + + +def atom_red_typechecking_impl(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + op: str, + builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: + if not ptr.type.scalar.is_ptr(): + raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) + + element_ty = ptr.type.scalar.element_ty + if element_ty is tl.float16 and op != 'add': + raise ValueError("atomic_" + op + " does not support fp16") + if element_ty in [tl.int1, tl.int8, tl.int16]: + raise ValueError("atomic_" + op + " does not support " + str(element_ty)) + if ptr.type.is_block(): + if mask: + mask = 
broadcast_impl_shape(mask, ptr.type.get_block_shapes(), builder) + if val: + val = broadcast_impl_shape(val, ptr.type.get_block_shapes(), builder) + val = cast(val, ptr.type.scalar.element_ty, builder) + if not mask: + mask_ir = builder.get_int1(True) + mask_ty = tl.int1 + if ptr.type.is_block(): + mask_ir = builder.create_splat(mask_ir, ptr.type.get_block_shapes()) + mask_ty = tl.block_type(tl.int1, ptr.type.get_block_shapes()) + mask = tl.tensor(mask_ir, mask_ty) + return ptr, val, mask + + +def atomic_max(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) + sca_ty = val.type.scalar + # direct call to atomic_max for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + ptr.handle, + val.handle, + mask.handle), + val.type) + # ROCM TODO: implement atomic_max/min for f32 as they are supported by MI cards. + # for float + # return atomic_smax(i_ptr, i_val) if val >= 0 + # return atomic_umin(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle), i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle), i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_min(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) + sca_ty = val.type.scalar + # direct call to atomic_min for integers + if sca_ty.is_int(): + if sca_ty.is_int_signed(): + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + else: + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, + ptr.handle, + val.handle, + mask.handle), + val.type) + # for float + # return atomic_smin(i_ptr, i_val) if val >= 0 + # return atomic_umax(i_ptr, i_val) if val < 0 + i_val = bitcast(val, tl.int32, builder) + i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) + pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, + i_ptr.handle, + i_val.handle, + and_(mask, pos, builder).handle), + i_val.type) + neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, + i_ptr.handle, + i_val.handle, + and_(mask, neg, builder).handle), + i_val.type) + return where(pos, pos_ret, neg_ret, builder) + + +def atomic_add(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) + sca_ty = val.type.scalar + op = ir.ATOMIC_OP.FADD if sca_ty.is_floating() else ir.ATOMIC_OP.ADD + return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_and(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: 
ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_or(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xor(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle), val.type) + + +def atomic_xchg(ptr: tl.tensor, + val: tl.tensor, + mask: tl.tensor, + builder: ir.builder) -> tl.tensor: + ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle), val.type) + +# ===----------------------------------------------------------------------===// +# Linear Algebra +# ===----------------------------------------------------------------------===// + +def is_hip(): + try: + import torch + except ImportError: + raise ImportError("Triton requires PyTorch to be installed") + return torch.version.hip is not None + +def gpu_has_mfma() -> bool: + if not is_hip(): + return False + arch_info = _triton.get_arch_info() + gfx_arch_details = re.search('amd.*', arch_info) + if gfx_arch_details is None: + return False + gfx_arch_details = gfx_arch_details.group(0).strip().split('--') + return gfx_arch_details[1].split(':')[0] in ['gfx908', 'gfx90a'] + +def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: + if not gpu_has_mfma(): + return False + # TODO: Add check for configurations and types. + return True + +def dot(lhs: tl.tensor, + rhs: tl.tensor, + allow_tf32: bool, + out_dtype: tl.dtype, + builder: ir.builder) -> tl.tensor: + assert lhs.type.is_block() and rhs.type.is_block() + assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!" + assert len(lhs.shape) == 2 and len(rhs.shape) == 2 + assert lhs.shape[1].value == rhs.shape[0].value + assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \ + and rhs.shape[1].value >= 16,\ + "small blocks not supported!" + if lhs.type.scalar.is_int(): + assert lhs.type.scalar == tl.int8, "only int8 supported!" + # TODO: This is CUDA specific, check if ROCm has the same limitation + assert lhs.shape[1].value >= 32, "small blocks not supported!" + _0 = builder.get_int32(0) + ret_scalar_ty = tl.int32 + elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): + _0 = builder.get_fp32(0) + ret_scalar_ty = tl.float32 + else: + _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0) + ret_scalar_ty = out_dtype + + M = lhs.type.shape[0] + N = rhs.type.shape[1] + + # Cast operands of types f16 and i8 for configurations where FMA only supported. 
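+ # (Without MFMA support on HIP the dot falls back to FMA: int8 operands are
+ # promoted to fp32, f16 operands are cast to the requested output dtype, and
+ # the result is cast back to `ret_scalar_ty` before returning.)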
+ if is_hip() and not mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty): + ret_cast_scalar_ty = tl.float32 if lhs.type.scalar.is_int() else ret_scalar_ty + lhs = cast(lhs, ret_cast_scalar_ty, builder) + rhs = cast(rhs, ret_cast_scalar_ty, builder) + _0 = builder.create_splat(builder.get_fp32(0), [M, N]) + ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + return cast(ret, ret_scalar_ty, builder) + + _0 = builder.create_splat(_0, [M, N]) + ret_ty = tl.block_type(ret_scalar_ty, [M, N]) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + ret_ty) + + +# ===----------------------------------------------------------------------===// +# Indexing +# ===----------------------------------------------------------------------===// + +def where(condition: tl.tensor, + x: tl.tensor, + y: tl.tensor, + builder: ir.builder) -> tl.tensor: + condition = cast(condition, tl.int1, builder) + if condition.type.is_block(): + condition, x = broadcast_impl_value(condition, x, builder) + x, y = broadcast_impl_value(x, y, builder) + condition, x = broadcast_impl_value(condition, x, builder) + + x, y = binary_op_type_checking_impl(x, y, builder, True, True) + if not condition.type.is_block(): + condition, _ = broadcast_impl_value(condition, x, builder) + ret_ty = x.type + return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + +# ===----------------------------------------------------------------------===// +# Reduction +# ===----------------------------------------------------------------------=== + + +def reduction( + inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder +) -> Tuple[tl.tensor, ...]: + if axis is None: + new_inputs = [] + for i in range(len(inputs)): + new_shape = [inputs[i].numel.value] + new_inputs.append(view(inputs[i], new_shape, builder)) + inputs = tuple(new_inputs) + axis = 0 + # get result shape + shape = inputs[0].type.shape + ret_shape = [s for i, s in enumerate(shape) if i != axis] + for t in inputs: + assert t.type.shape == shape + + def wrap_tensor(x, scalar_ty): + if ret_shape: + res_ty = tl.block_type(scalar_ty, ret_shape) + else: + # 0d-tensor -> scalar + res_ty = scalar_ty + return tl.tensor(x, res_ty) + + reduce_op = builder.create_reduce([t.handle for t in inputs], axis) + region_builder_fn(reduce_op) + reduce_op.verify() + + return tuple( + wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) + for i in range(len(inputs)) + ) + + +# ===----------------------------------------------------------------------=== +# Math +# ===----------------------------------------------------------------------=== + +def _check_dtype(dtypes: List[str]) -> T: + """ + We following libdevice's convention to check accepted data types for math functions. + It is not a good practice to support all data types as accelerators/GPUs don't support + many float16 and bfloat16 math operations. + We should let the users know that they are using and invoke explicit cast to convert + the data type to the supported one. 
+ """ + def wrapper(fn): + @wraps(fn) + def check(*args, **kwargs): + # concatenate args and kwargs + all_args = list(args) + list(kwargs.values()) + for arg in [a for a in all_args if isinstance(a, tl.tensor)]: + if arg.type.scalar.name not in dtypes: + raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") + return fn(*args, **kwargs) + return check + + return wrapper + + +def umulhi(x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: + x, y = binary_op_type_checking_impl(x, y, builder) + # FIXME(Keren): not portable, should be fixed + from . import math + return math.mulhi(x, y, _builder=builder) + + +@_check_dtype(dtypes=["fp32", "fp64"]) +def floor(x: tl.tensor, builder: ir.builder) -> tl.tensor: + # FIXME(Keren): not portable, should be fixed + from . import math + return math.floor(x, _builder=builder) + + +@_check_dtype(dtypes=["fp32", "fp64"]) +def exp(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_exp(x.handle), x.type) + + +@_check_dtype(dtypes=["fp32", "fp64"]) +def log(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_log(x.handle), x.type) + + +@_check_dtype(dtypes=["fp32", "fp64"]) +def cos(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_cos(x.handle), x.type) + + +@_check_dtype(dtypes=["fp32", "fp64"]) +def sin(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sin(x.handle), x.type) + + +@_check_dtype(dtypes=["fp32", "fp64"]) +def sqrt(x: tl.tensor, builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_sqrt(x.handle), x.type) + + +def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor: + dtype = x.dtype + if dtype.is_floating(): + return tl.tensor(builder.create_fabs(x.handle), x.type) + elif dtype.is_int_signed(): + return tl.tensor(builder.create_iabs(x.handle), x.type) + elif dtype.is_int_unsigned(): + return x # no-op + else: + assert False, f"Unexpected dtype {dtype}" + + +## + + +def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to multiple_of does not match the length of values") + x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) + return x + + +def max_contiguous(x: tl.tensor, values: List[int]) -> tl.tensor: + if len(x.shape) != len(values): + raise ValueError("Shape of input to max_contiguous does not match the length of values") + x.handle.set_attr("tt.contiguity", ir.make_attr(values, x.handle.get_context())) + return x + + +def debug_barrier(builder: ir.builder) -> tl.tensor: + return tl.tensor(builder.create_barrier(), tl.void) + + +def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor: + new_args = [] + for arg in args: + new_args.append(arg.handle) + return tl.tensor(builder.create_print(prefix, new_args), tl.void) + + +def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1,)) + cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty) + return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) + + +def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, tl.constexpr): + return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value) + elif isinstance(elem, tl.tensor): + assert 
elem.numel.value == 1, "Expected a scalar in shape/strides/offsets" + assert elem.dtype.is_int(), "Expected an integer scalar type in shape/strides/offsets" + if elem.dtype != tl.int64 and require_i64: + return builder.create_int_cast(elem.handle, builder.get_int64_ty(), elem.dtype.is_int_signed()) + elif elem.dtype != tl.int32: + return builder.create_int_cast(elem.handle, builder.get_int32_ty(), elem.dtype.is_int_signed()) + return elem.handle + assert False, f"Unsupported element type in shape/strides/offsets: {type(elem)}" + + +def _convert_to_ir_values(builder, list_like, require_i64=True): + if hasattr(list_like, "__iter__"): + return [_convert_elem_to_ir_value(builder, elem, require_i64) for elem in list_like] + return [_convert_elem_to_ir_value(builder, list_like, require_i64)] + + +def make_block_ptr(base: tl.tensor, shape, strides, offsets, block_shape, order, builder: ir.builder) -> tl.tensor: + # Convert dynamic arguments to IR values + # NOTES(Chenggang): current `shape/strides` are `int64_t`, while `offsets/block_shape` are `int32_t` + shape = _convert_to_ir_values(builder, shape) + strides = _convert_to_ir_values(builder, strides) + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Check `base` type + if not base.type.is_ptr() or base.type.element_ty.is_block(): + raise ValueError("Expected `base` to be a pointer type (but not a block pointer type or others)") + + # Treat `pointer_type` as `pointer_type` + if base.type.element_ty == tl.int1: + base = cast(base, tl.pointer_type(tl.int8, base.type.address_space), builder) + + # Check whether `block_shape` is static + if not hasattr(block_shape, "__iter__"): + block_shape = [block_shape] + block_shape = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in block_shape] + assert all([isinstance(elem, int) and -2**31 <= elem < 2**31 for elem in block_shape]), \ + "Expected a list of constant integers (`int32_t` range) in `block_shape`" + + # Check `order` + if not hasattr(order, "__iter__"): + order = [order] + order = [elem.value if isinstance(elem, tl.constexpr) else elem for elem in order] + assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order" + + # Must have same length + assert all([len(block_shape) == len(list_like) for list_like in [shape, strides, offsets, order]]), \ + "Expected shape/strides/offsets/block_shape to have the same length" + + # Build value, the type is: + # `pointer_type>` in Python + # `tt.ptr>` in MLIR + handle = builder.create_make_block_ptr(base.handle, shape, strides, offsets, block_shape, order) + return tl.tensor(handle, tl.pointer_type(tl.block_type(base.type.element_ty, block_shape))) + + +def advance(base: tl.tensor, offsets, builder: ir.builder) -> tl.tensor: + # Convert dynamic offsets to IR values + offsets = _convert_to_ir_values(builder, offsets, require_i64=False) + + # Advanced block pointer type is the same as before + return tl.tensor(builder.create_advance(base.handle, offsets), base.type) diff --git a/pkgs/triton/language/standard.py b/pkgs/triton/language/standard.py new file mode 100644 index 0000000..b997674 --- /dev/null +++ b/pkgs/triton/language/standard.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from ..runtime.jit import jit +from . 
import core + +# ----------------------- +# Standard library +# ----------------------- + + +@jit +def cdiv(x, div): + """ + Computes the ceiling division of :code:`x` by :code:`div` + + :param x: the input number + :type input: Block + :param div: the divisor + :param div: Block + """ + return (x + div - 1) // div + + +@jit +@core._add_math_1arg_docstr("sigmoid") +def sigmoid(x): + return 1 / (1 + core.exp(-x)) + + +@jit +@core._add_math_1arg_docstr("softmax") +def softmax(x, ieee_rounding=False): + z = x - core.max(x, 0) + num = core.exp(z) + den = core.sum(num, 0) + return core.fdiv(num, den, ieee_rounding) + + +@jit +def ravel(x): + """ + Returns a contiguous flattened view of :code:`x`. + + :param x: the input tensor + :type x: Block + """ + return core.view(x, [x.numel]) + + +@jit +def swizzle2d(i, j, size_i, size_j, size_g): + """ + Transforms indices of a row-major size_i*size_j matrix into those + of one where indices are row major for each group of size_j rows. + For example, for size_i = size_j = 4 and size_g = 2, it will transform + [[0 , 1 , 2 , 3 ], + [4 , 5 , 6 , 7 ], + [8 , 9 , 10, 11], + [12, 13, 14, 15]] + into + [[0, 2, 4 , 6 ], + [1, 3, 5 , 7 ], + [8, 10, 12, 14], + [9, 11, 13, 15]] + """ + # "unrolled index in array" + ij = i * size_j + j + # number of elements in `size_g` groups + # of `size_j` columns + size_gj = size_g * size_j + # index of the group in which (i,j) is + group_id = ij // size_gj + # row-index of the first element of this group + off_i = group_id * size_g + # last group may have fewer rows + size_g = core.minimum(size_i - off_i, size_g) + # new row and column indices + new_i = off_i + (ij % size_g) + new_j = (ij % size_gj) // size_g + return new_i, new_j + + +@jit +def zeros(shape, dtype): + """ + Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`. + + :param shape: Shape of the new array, e.g., (8, 16) or (8, ) + :type shape: tuple of ints + :param dtype: Data-type of the new array, e.g., :code:`tl.float16` + :type dtype: DType + """ + return core.full(shape, 0, dtype) + + +@jit +def zeros_like(input): + return zeros(input.shape, input.dtype) diff --git a/pkgs/triton/ops/__init__.py b/pkgs/triton/ops/__init__.py new file mode 100644 index 0000000..3912a5b --- /dev/null +++ b/pkgs/triton/ops/__init__.py @@ -0,0 +1,17 @@ +# from .conv import _conv, conv +from . 
import blocksparse +from .cross_entropy import _cross_entropy, cross_entropy +from .flash_attention import attention +from .matmul import _matmul, matmul +from .bmm_matmul import _bmm, bmm + +__all__ = [ + "blocksparse", + "_cross_entropy", + "cross_entropy", + "_matmul", + "matmul", + "_bmm", + "bmm", + "attention", +] diff --git a/pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..ab6dfc1 Binary files /dev/null and b/pkgs/triton/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc b/pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc new file mode 100644 index 0000000..45c499f Binary files /dev/null and b/pkgs/triton/ops/__pycache__/bmm_matmul.cpython-310.pyc differ diff --git a/pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc b/pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc new file mode 100644 index 0000000..f6f6f3b Binary files /dev/null and b/pkgs/triton/ops/__pycache__/cross_entropy.cpython-310.pyc differ diff --git a/pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc b/pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc new file mode 100644 index 0000000..3721df5 Binary files /dev/null and b/pkgs/triton/ops/__pycache__/flash_attention.cpython-310.pyc differ diff --git a/pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc b/pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc new file mode 100644 index 0000000..6a3e670 Binary files /dev/null and b/pkgs/triton/ops/__pycache__/matmul.cpython-310.pyc differ diff --git a/pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc b/pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc new file mode 100644 index 0000000..cdabd97 Binary files /dev/null and b/pkgs/triton/ops/__pycache__/matmul_perf_model.cpython-310.pyc differ diff --git a/pkgs/triton/ops/blocksparse/__init__.py b/pkgs/triton/ops/blocksparse/__init__.py new file mode 100644 index 0000000..6b24b53 --- /dev/null +++ b/pkgs/triton/ops/blocksparse/__init__.py @@ -0,0 +1,7 @@ +from .matmul import matmul +from .softmax import softmax + +__all__ = [ + "matmul", + "softmax", +] diff --git a/pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6ff9c1a Binary files /dev/null and b/pkgs/triton/ops/blocksparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc b/pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc new file mode 100644 index 0000000..068f11c Binary files /dev/null and b/pkgs/triton/ops/blocksparse/__pycache__/matmul.cpython-310.pyc differ diff --git a/pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc b/pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc new file mode 100644 index 0000000..a636d67 Binary files /dev/null and b/pkgs/triton/ops/blocksparse/__pycache__/softmax.cpython-310.pyc differ diff --git a/pkgs/triton/ops/blocksparse/matmul.py b/pkgs/triton/ops/blocksparse/matmul.py new file mode 100644 index 0000000..c599af2 --- /dev/null +++ b/pkgs/triton/ops/blocksparse/matmul.py @@ -0,0 +1,437 @@ +import torch + +import triton +import triton.language as tl + +# ******************************************************** +# -------------------------------------------------------- +# Sparse = Dense x Dense (SDD) +# This operation uses super-blocking to 
make sure that +# it's done efficiently when small blocks can be grouped +# together +# -------------------------------------------------------- +# ******************************************************** + + +@triton.heuristics({ + 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, +}) +@triton.jit +def _sdd_kernel( + A, B, C, + stride_za, stride_ha, stride_ma, stride_ak, + stride_zb, stride_hb, stride_bk, stride_nb, + stride_zc, stride_hc, stride_mc, stride_nc, + K, grid_offset, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + BLOCK: tl.constexpr, EVEN_K: tl.constexpr +): + # ------------ # + # - Prologue - # + # ------------ # + block_id = tl.program_id(0) + grid_offset + lut += block_id * 3 + # offsets + off_z = tl.program_id(2) # batch + off_h = tl.load(lut + 0) # head + + # initialize pointers to A + start_am = tl.load(lut + 1) + offs_am = start_am * BLOCK + (tl.arange(0, TILE_M) % BLOCK) + offs_ak = tl.arange(0, TILE_K) + a_ptrs = A \ + + off_z * stride_za \ + + off_h * stride_ha \ + + offs_am[:, None] * stride_ma \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B + start_bn = tl.load(lut + 2) + offs_bn = start_bn * BLOCK + (tl.arange(0, TILE_N) % BLOCK) + offs_bk = tl.arange(0, TILE_K) + b_ptrs = B \ + + off_z * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_nb \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + for k in range(K, 0, -TILE_K): + if EVEN_K: + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + else: + a = tl.load(a_ptrs, mask=offs_ak[None, :] < k, other=0.) + b = tl.load(b_ptrs, mask=offs_bk[:, None] < k, other=0.) + acc += tl.dot(a, b, out_dtype=tl.float32) + a_ptrs += TILE_K * stride_ak + b_ptrs += TILE_K * stride_bk + c = acc.to(C.dtype.element_ty) + # ---------------- # + # Epilogue # + # ---------------- # + offs_cm = tl.arange(0, TILE_M) % BLOCK + offs_cn = tl.arange(0, TILE_N) % BLOCK + pc = C \ + + off_z * stride_zc \ + + block_id * stride_hc \ + + offs_cm[:, None] * stride_mc \ + + offs_cn[None, :] * stride_nc + tl.store(pc, c, mask=True) + + +def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # (A * B)^T = B^T * A^T + if trans_c: + a, b = b, a + trans_a, trans_b = not trans_b, not trans_a + # shape constraints + a_dim = -2 if trans_a else -1 + b_dim = -1 if trans_b else -2 + Ka, Kb = a.shape[a_dim], b.shape[b_dim] + if Ka != Kb: + raise ValueError(f"Inner dimension mismatch (A: {Ka} vs B: {Kb})") + # allocate output + if out is None: + c = torch.empty((a.shape[0], lut.shape[0], block, block), dtype=a.dtype, device=a.device) + else: + assert out.shape == (a.shape[0], lut.shape[0], block, block) + c = out + grid = [c.shape[1], 1, c.shape[0]] + _sdd_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), c.stride(2), c.stride(3), + Ka, 0, lut, + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, + num_warps=4, + ) + return c + + +def sdd_lut(layout, block, device): + lut = layout.nonzero(as_tuple=False).to(device).int() + lut = lut.contiguous() + return lut, None + +# ----------------------------- +# Dense = Sparse x Dense (DSD) +# This operation uses 
a look-up table that contains pre-computed pointer increments +# in order to minimize computations in the inner loop of the matmul kernel. +# ----------------------------- + + +@triton.jit +def _dsd_kernel( + A, B, C, + stride_az, stride_ha, stride_am, stride_ak, + stride_zb, stride_hb, stride_bk, stride_bn, + stride_zc, stride_hc, stride_cm, stride_cn, + DS0, DS1, lut, + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr +): + # ------------ # + # - Prologue - # + # ------------ # + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + num_pid_m = tl.num_programs(0) + num_pid_n = tl.num_programs(1) + pid_n, pid_m = tl.swizzle2d(pid_n, pid_m, num_pid_n, num_pid_m, GROUP_SIZE_M) + pidz = tl.program_id(2) + header = lut + pid_n * 4 + offset = tl.load(header + 0) + K = tl.load(header + 1) + column = tl.load(header + 2) + off_h = tl.load(header + 3) + pinc = lut + offset + # initialize pointers to A (sparse) + block_id = tl.load(pinc + 1) + block_id = tl.multiple_of(block_id, 8) # compiler hint + offs_am = tl.arange(0, TILE_M) + offs_ak = tl.arange(0, TILE_K) + pa = A + pidz * stride_az \ + + block_id * stride_ha \ + + offs_am[:, None] * stride_am \ + + offs_ak[None, :] * stride_ak + # initialize pointers to B (dense) + offs_bn = pid_m * TILE_N + tl.arange(0, TILE_N) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn % DS0, TILE_N), TILE_N) + start_bk = tl.load(pinc) + start_bk = tl.multiple_of(start_bk, 8) # compiler hint + offs_bk = start_bk + tl.arange(0, TILE_K) + pb = B + pidz * stride_zb \ + + off_h * stride_hb \ + + offs_bn[None, :] * stride_bn \ + + offs_bk[:, None] * stride_bk + # ---------------- # + # Inner Loop # + # ---------------- # + acc = tl.zeros((TILE_M, TILE_N), dtype=tl.float32) + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + for k in range(K, 0, -TILE_K): + a = tl.load(pa) + b = tl.load(pb) + acc += tl.dot(a, b, out_dtype=tl.float32) + pa += inc_a + pb += inc_b * stride_bk + pinc += 2 + inc_a = tl.load(pinc + 1) + inc_a = tl.multiple_of(inc_a, 8) + inc_b = tl.load(pinc) + inc_b = tl.multiple_of(inc_b, 8) + c = acc.to(C.dtype.element_ty) + # initialize pointers to C + offs_cm = column * TILE_M + tl.arange(0, TILE_M) + offs_cn = pid_m * TILE_N + tl.arange(0, TILE_N) + pc = C \ + + off_h * stride_hc \ + + pidz * stride_zc \ + + offs_cm[:, None] * stride_cm \ + + offs_cn[None, :] * stride_cn + tl.store(pc, c, mask=offs_cn[None, :] < DS0) + + +def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + if a.stride(2) != 1 and a.stride(3) != 1: + a = a.contiguous() + if b.stride(2) != 1 and b.stride(3) != 1: + b = b.contiguous() + # shapes / dtypes + AS1 = block * spdims[2 if trans_a else 1] + BS0 = b.size(0) + BS1 = b.size(1) + BS3 = b.size(2 if trans_b else 3) + dtype = a.dtype + # allocate output + CS0 = BS0 + CS1 = BS1 + CS2 = BS3 if trans_c else AS1 + CS3 = AS1 if trans_c else BS3 + if out is None: + c = torch.empty((CS0, CS1, CS2, CS3), dtype=dtype, device=a.device) + else: + assert out.shape == (CS0, CS1, CS2, CS3) + c = out + # meta-parameter heuristics + TILE_N = 128 + # compute output + grid = lambda meta: [triton.cdiv(BS3, meta['TILE_N']), width, BS0] + _dsd_kernel[grid]( + a, b, c, + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), + c.stride(0), c.stride(1), 
c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), + BS3, AS1, lut, + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, + num_warps=4, GROUP_SIZE_M=4, + ) + # exit() + return c + + +def dsd_lut(layout, block, step, trans, device): + """ + Generates the look-up table for incrementing pointers in the DSD/DDS matmul. + Example (BLOCK=32, STEP=16) + [[1, 0, 0, 1, 0], + [0, 1, 1, 0, 1], + [1, 0, 1, 0, 0]] + + Then the offsets for A are + [0 , 16, 32, 48] <- row 0 + \\----/ \\----/ + col=0 col=3 + [64, 80, 96, 112, 128, 144] <- row 1 + \\----/ \\----/ \\------/ + col=1 col=2 col=3 + [160, 176, 192, 208] + which leads to increments table + [0, 16, 16, 16, || 64, 16, 16, 16, 16, 16, || 160, 16, 16, 16] + + Because B is dense, the offsets are + [0, 16, 96, 112] <- row 0 + [32, 48, 64, 80] <- row 1 + [0, 16, 64, 80] <- row 2 + """ + sizes = torch.sum(layout, 2 if trans else 1) + head_id, col_id = torch.ones_like(sizes).nonzero(as_tuple=True) + sizes = sizes.flatten() + segments = sizes * step + # pointer increments + if trans: + nnz = layout.nonzero(as_tuple=False) + else: + nnz = layout.transpose(1, 2).nonzero(as_tuple=False) + num_blocks = nnz.size(0) + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + offsets = torch.min(offsets, (num_blocks - 1) * torch.ones_like(offsets)) + # ------------------------------- + # dense input pointer increments + # ------------------------------- + # Note that the inner loop matmul kernel may have a fixed step size (e.g., TILE_K) + # that is smaller than the block size, so we need to do a bit of extra work + # to handle this case + B_idx = nnz[:, 2] * block + B_incs = B_idx.clone() + B_incs[1:] -= B_idx[:-1] + div = block // step + B_incs = B_incs.view(-1, 1).repeat(1, div) + B_incs[:, 1:] = step + B_incs[:, 0] -= (div - 1) * step + # first increment for each reduction is actually the offset + B_incs[offsets[segments > 0], 0] = B_idx[offsets[segments > 0]] + B_incs = B_incs.view(-1) + # ------------------------------- + # sparse input pointer increments + # ------------------------------- + # same as above, except that the increments are in the sparse memory layout + if trans: + A_idx = torch.arange(num_blocks, device=layout.device) + else: + A_idx = torch.tensor([], dtype=torch.int64, device=layout.device) + current_offset = 0 + for z in range(layout.size(0)): + layoutw = layout[z, :, :].clone().long() + msum = layoutw.sum() + layoutw[layoutw > 0] = 1 + torch.arange(msum, device=layout.device) + A_idx = torch.cat((A_idx, current_offset + layoutw.T[layoutw.T > 0] - 1)) + current_offset += msum + A_incs = A_idx * block * block + A_incs[1:] -= A_idx[:-1] * block * block + A_incs = A_incs.view(-1, 1).repeat(1, div) + if trans: + A_incs[:, 1:] = step + A_incs[:, 0] -= (div - 1) * step + else: + A_incs[:, 1:] = step * block + A_incs[:, 0] -= (div - 1) * step * block + A_incs[offsets[segments > 0], 0] = A_idx[offsets[segments > 0]] + A_incs = A_incs.view(-1) + # create header + width = col_id.size(0) + offsets = offsets * 2 * div + 4 * width + segments = segments * div + header = torch.stack((offsets, segments, col_id, head_id), dim=1).view(-1).contiguous() + # create increments + incs = torch.stack((B_incs, A_incs), dim=1).view(-1).contiguous() + # pad by a factor 2*MAX_NUM_STAGES + # to accommodate pre-fetching inside the kernel + pad = torch.zeros(20, device=incs.device, dtype=incs.dtype) + incs = torch.cat((incs, pad)) + # create lut + lut = torch.cat((header, incs)) + lut = 
lut.type(torch.int32).to(device) + # create locks + return lut, width + +# ----------------------------- +# Dense = Dense x Sparse (DDS) +# ----------------------------- +# AB = (B^T A^T)^T + + +def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): + return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + +############## +# MAIN API # +############## + + +class _matmul(torch.autograd.Function): + + fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} + + @staticmethod + def forward( + ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, + c_lut, c_width, da_lut, da_width, db_lut, db_width, out + ): + c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) + # save for backward + ctx.save_for_backward(a, b) + ctx.da_lut = da_lut + ctx.da_width = da_width + ctx.db_lut = db_lut + ctx.db_width = db_width + ctx.mode = mode + ctx.spdims = spdims + ctx.block = block + ctx.trans_a = trans_a + ctx.trans_b = trans_b + ctx.trans_c = trans_c + ctx.has_out = out is not None + return c + + @staticmethod + def backward(ctx, dc): + # saved for backward + a, b = ctx.saved_tensors + da, db = None, None + mode = ctx.mode + # gradients w.r.t. a + if ctx.needs_input_grad[0]: + mode_da = mode[1] + mode[0] + mode[2] + da = _matmul.fn[mode_da]( + dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width, + ) + # gradients w.r.t. b + if ctx.needs_input_grad[1]: + mode_db = mode[2] + mode[1] + mode[0] + db = _matmul.fn[mode_db]( + a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width, + ) + dout = dc if ctx.has_out else None + return da, db, None, None, None,\ + None, None, None, None,\ + None, None, None, None, None, dout + + +class matmul: + + def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, trans_c=False): + if mode not in ['sdd', 'dsd', 'dds']: + raise NotImplementedError('Supported modes are: sdd, dsd, dds') + self.block = block + self.mode = mode + self.trans_a = trans_a + self.trans_b = trans_b + self.trans_c = trans_c + self.layout = layout + self.spdims = layout.shape + step = min(block, 32) + if self.mode == 'sdd': + self.c_lut, self.c_width = sdd_lut(layout, block, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, True, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, False, device) + if self.mode == 'dsd': + self.c_lut, self.c_width = dsd_lut(layout, block, step, not self.trans_a, device) + self.da_lut, self.da_width = sdd_lut(layout, block, device) + self.db_lut, self.db_width = dsd_lut(layout, block, step, self.trans_a, device) + if self.mode == 'dds': + self.c_lut, self.c_width = dsd_lut(layout, block, step, self.trans_b, device) + self.da_lut, self.da_width = dsd_lut(layout, block, step, not self.trans_b, device) + self.db_lut, self.db_width = sdd_lut(layout, block, device) + + def __call__(self, a, b, out=None): + c = _matmul.apply( + a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, + self.c_lut, self.c_width, + self.da_lut, self.da_width, + self.db_lut, self.db_width, + out + ) + return c diff --git a/pkgs/triton/ops/blocksparse/softmax.py b/pkgs/triton/ops/blocksparse/softmax.py new file mode 100644 index 0000000..ac2b7c3 --- /dev/null +++ b/pkgs/triton/ops/blocksparse/softmax.py @@ -0,0 +1,239 @@ +import torch + +import triton +import triton.language as tl + + +def num_warps(n): + if n 
<= 128: + return 1 + if n <= 256: + return 2 + if n <= 512: + return 4 + if n <= 4096: + return 8 + return 16 + + +@triton.jit +def _blocksparse_softmax_fwd( + Out, A, stride_xz, LUT, + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, + ROW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_DENSE: tl.constexpr, +): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # pointer offset + off_a = z * stride_xz + off_a += (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE # block indx + off_a += (m % BLOCK_SIZE) * BLOCK_SIZE # row indx + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=block_n < size, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load X + mask = block_n < size + a = tl.load(A + off_a + lane_n, mask=mask, other=-float("inf")) + a = a.to(tl.float32) + # compute + out = a + out *= scale + # apply relative attention + if R is not None: + R += z * stride_zr + R += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) + rel_logits = tl.load(R + m * extent + off_lo, mask=mask_lo, other=0.0) + out += rel_logits + out = out.to(tl.float32) + # apply causal mask + out = tl.where((ns > m) & is_causal, -float("inf"), out) + # computation + out = tl.softmax(out) + # write-back + tl.store(Out + off_a + lane_n, out, mask=mask) + + +@triton.jit +def _blocksparse_softmax_bwd( + DA, stride_zdx, + DOut, stride_zdout, + Out, stride_zout, + scale, + LUT, + DR, extent, stride_zr, stride_hr, stride_er, + is_causal, + ROW_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_DENSE: tl.constexpr, +): + h = tl.program_id(0) + m = tl.program_id(1) + z = tl.program_id(2) + # create index ranges + hm = h * tl.num_programs(1) + m + lane_n = tl.arange(0, ROW_SIZE) % BLOCK_SIZE + block_n = tl.arange(0, ROW_SIZE) // BLOCK_SIZE + # extract information from LUT + header = LUT + (hm // BLOCK_SIZE) * 2 + size = tl.load(header + 0) + offset = tl.load(header + 1) + # row-col offset + off_mn = (offset + block_n) * BLOCK_SIZE * BLOCK_SIZE + off_mn += (m % BLOCK_SIZE) * BLOCK_SIZE + mask = block_n < size + # pointers + As = Out + z * stride_zout + off_mn + DOuts = DOut + z * stride_zdout + off_mn + # do not need to read column indices in the dense case + if IS_DENSE: + ns = tl.arange(0, ROW_SIZE) + else: + off_lut = offset + 2 * tl.num_programs(0) * tl.num_programs(1) // BLOCK_SIZE + start_n = tl.load(LUT + off_lut + block_n, mask=mask, other=0) + ns = start_n * BLOCK_SIZE + lane_n + # load data + a = tl.load(As + lane_n, mask=mask, other=0.0) + a = a.to(tl.float32) + dout = tl.load(DOuts + lane_n, mask=mask, other=0.0) + dout = dout.to(tl.float32) + # compute + a = tl.where((ns > m) & is_causal & (a == a), 0., a) + da = a * (dout - tl.sum(a * dout, 0)) + # apply relative attention + if DR is not None: + DR += z * stride_zr + DR += h * stride_hr + off_lo = (extent - m - 1) + ns + mask_lo = (off_lo >= 0) & (off_lo < extent) & mask + tl.store(DR + m * extent + off_lo, da, mask=mask_lo) + da = da * scale + # convert da + # write-back + DAs = DA + z * stride_zdx + off_mn + 
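# DAs addresses this row's packed blocks in the input-gradient tensor; lanes past the row's valid blocks are masked out +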
tl.store(DAs + lane_n, da, mask=mask) + + +class _softmax(torch.autograd.Function): + @staticmethod + def make_lut(layout, block, device): + _empty = torch.tensor([], dtype=torch.int64, device=layout.device) + sizes = _empty.clone() + # sizes along rows + for h in range(layout.shape[0]): + sizes = torch.cat((sizes, layout[h, :, :].sum(-1))) + total_sizes = sizes * block + # offsets in block format + offsets = torch.zeros_like(sizes) + offsets[1:] = torch.cumsum(sizes[:-1], dim=0) + # block indices + columns = layout.nonzero(as_tuple=False)[:, 2] + header = torch.stack((sizes, offsets), dim=1).view(-1) + lut = torch.cat((header, columns)).type(torch.int32).to(device) + return lut, int(total_sizes.max()) + + @staticmethod + def forward( + ctx, a, scale, rel_logits, is_causal, + spdims, block, lut, maxlut, is_dense + ): + if scale is not None and isinstance(scale, torch.Tensor): + assert scale.device.type == "cpu" + scale = scale.item() + M = a.shape[0] + grid = [spdims[0], spdims[1] * block, M] + rel_shape = (1, 1, 1, 1) if rel_logits is None else rel_logits.shape + rel_strides = (1, 1, 1, 1) if rel_logits is None else rel_logits.stride() + # enqueue kernel + out = torch.empty_like(a) + _blocksparse_softmax_fwd[grid]( + out, a, a.stride(0), lut, + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn + scale, + is_causal, + BLOCK_SIZE=block, + ROW_SIZE=triton.next_power_of_2(maxlut), + IS_DENSE=is_dense, + num_warps=num_warps(maxlut) + ) + # save to context + # ctx.mark_dirty(x) + ctx.save_for_backward(out, lut) + ctx.spdims = spdims + ctx.block = block + ctx.maxlut = maxlut + ctx.scale = scale + ctx.rel_shape = rel_shape + ctx.rel_strides = rel_strides + ctx.rel_dtype = a.dtype + ctx.is_dense = is_dense + ctx.is_causal = is_causal + return out + + @staticmethod + def backward(ctx, dout): + # retrieve from context + out, lut = ctx.saved_tensors + # relative logits gradients + dr = None + if ctx.needs_input_grad[3]: + dr = torch.zeros(ctx.rel_shape, dtype=ctx.rel_dtype, device=out.device) + # run kernel + M = out.shape[0] + grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) + da = torch.empty_like(dout) + _blocksparse_softmax_bwd[grid]( + da, da.stride(0), + dout, dout.stride(0), + out, out.stride(0), + ctx.scale, + lut, + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], + ctx.is_causal, + BLOCK_SIZE=ctx.block, + ROW_SIZE=triton.next_power_of_2(ctx.maxlut), + IS_DENSE=ctx.is_dense, + num_warps=num_warps(ctx.maxlut) + ) + return (da, None, None, dr, None, + None, None, None, None, None, + None, + None, None, None, + None, + None, None, None + ) + + +class softmax: + def __init__(self, layout, block, device, is_dense=False): + self.spdims = layout.shape + self.layout = layout + self.block = block + self.lut, self.maxlut = _softmax.make_lut(self.layout, self.block, device) + self.is_dense = is_dense + + def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): + if rel_logits is not None and rel_logits.dtype != a.dtype: + raise ValueError(f"relative position embedding must be {a.dtype}") + a = _softmax.apply( + a, scale, rel_logits, is_causal, + self.spdims, self.block, self.lut, self.maxlut, self.is_dense, + ) + return a diff --git a/pkgs/triton/ops/bmm_matmul.py b/pkgs/triton/ops/bmm_matmul.py new file mode 100644 index 0000000..da00450 --- /dev/null +++ b/pkgs/triton/ops/bmm_matmul.py @@ -0,0 +1,163 @@ +import torch + +import triton +import triton.language as tl +from .matmul_perf_model import early_config_prune, 
estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + +def get_configs_io_bound(): + configs = [] + for num_stages in [1]: +# TODO support block size 16 for MFMA dot op + for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 4 if block_n <= 64 else 8 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + #for split_k in [2, 4, 8, 16]: + # configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + +def get_configs_compute_bound(): + configs = [] + for block_m in [64, 128, 256]: + for block_n in [64, 128, 256]: + for block_k in [32, 64, 128]: + num_warps = 8 if block_n <= 64 else 16 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=1, num_warps=num_warps)) + return configs + + +@triton.autotune( + configs=[ + ] + get_configs_compute_bound() + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_K'] == 0, +}) +@triton.jit +def _bmm_kernel(A, B, C, M, N, K, + stride_aq, stride_am, stride_ak, + stride_bq, stride_bk, stride_bn, + stride_cq, stride_cm, stride_cn, + dot_out_dtype: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ): + pid = tl.program_id(0) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + idx_q = tl.program_id(1) # batch dimension for BMM + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(K, 0, -BLOCK_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
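+ # accumulate this K-tile's partial product into acc (accumulator dtype given by dot_out_dtype)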
+ acc += tl.dot(a, b) + A += BLOCK_K * stride_ak + B += BLOCK_K * stride_bk + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_q = tl.program_id(1) # batch dimension for BMM + idx_m = rm[:, None] + idx_n = rn[None, :] + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn + idx_q * stride_cq) + mask = (idx_m < M) & (idx_n < N) + # handles write-back with reduction-splitting + tl.store(C, acc, mask=mask) + +class _bmm(torch.autograd.Function): + kernel = _bmm_kernel + + _locks = {} + + @staticmethod + def _call(a, b, dot_out_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + + #only MR support Trans layout + if hasattr(torch, "corex"): + capability = torch.cuda.get_device_capability(device) + capability = capability[0] * 10 + capability[1] + if (capability < 71): + if a.stride(0) >= 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) >= 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[0] == b.shape[0], "incompatible dimensions" + assert a.shape[2] == b.shape[1], "incompatible dimensions" + B, M, K = a.shape + _, _, N = b.shape + # allocates output + c = torch.empty((B, M, N), device=device, dtype=a.dtype) + if dot_out_dtype is None: + if a.dtype in [torch.float16, torch.float32, torch.bfloat16]: + dot_out_dtype = tl.float32 + else: + dot_out_dtype = tl.int32 + else: + assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype" + if dot_out_dtype == torch.float16: + dot_out_dtype = tl.float16 + elif dot_out_dtype in [torch.float32, torch.bfloat16]: + dot_out_dtype = tl.float32 + else: + dot_out_dtype = tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), B, 1) + _bmm_kernel[grid](a, b, c, M, N, K, + a.stride(0), a.stride(1), a.stride(2), + b.stride(0), b.stride(1), b.stride(2), + c.stride(0), c.stride(1), c.stride(2), + dot_out_dtype=dot_out_dtype, + GROUP_M=8) + return c + + @staticmethod + def forward(ctx, a, b, dot_out_dtype=None): + return _bmm._call(a, b, dot_out_dtype=dot_out_dtype) + +bmm = _bmm.apply diff --git a/pkgs/triton/ops/cross_entropy.py b/pkgs/triton/ops/cross_entropy.py new file mode 100644 index 0000000..f66cddf --- /dev/null +++ b/pkgs/triton/ops/cross_entropy.py @@ -0,0 +1,94 @@ +import torch + +import triton +import triton.language as tl + + +def num_warps(N): + if N < 2048: + return 4 + elif N < 8192: + return 8 + return 16 + + +@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])}) +@triton.jit +def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to logit and probs + LOGITS = LOGITS + row * N + cols + WRIT_PROBS = PROBS + row * N + cols + READ_PROBS = PROBS + row * N + idx + # write-back negative log-probs + logits = tl.load(LOGITS, mask=cols < N, other=-float('inf')) + logits = logits.to(tl.float32) + logits = logits - tl.max(logits, 0) + probs = tl.log(tl.sum(tl.exp(logits), 0)) - logits + tl.store(WRIT_PROBS, probs, mask=cols < N) + # There is a bug in the compiler, which fails to insert a barrier here. + # We add it explicitly for now. Will be fixed soon. 
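+ # WRIT_PROBS and READ_PROBS alias the same PROBS buffer, so the store above must complete before the load below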
+ tl.debug_barrier() + # write-back loss + probs = tl.load(READ_PROBS) + tl.store(LOSS + row, probs) + + +@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])}) +@triton.jit +def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): + row = tl.program_id(0) + cols = tl.arange(0, BLOCK) + idx = tl.load(IDX + row) + # pointers to probs + PROBS = PROBS + row * N + cols + # We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + # and we have -log(p[k]) stored in PROBS, so this is easy + probs = -tl.load(PROBS, mask=cols < N, other=float('inf')) + probs = tl.exp(probs.to(tl.float32)) + delta = cols == idx + # write result in-place in PROBS + dout = tl.load(DPROBS + row) + din = (probs - delta) * dout + tl.store(PROBS, din.to(PROBS.dtype.element_ty), mask=cols < N) + + +class _cross_entropy(torch.autograd.Function): + @classmethod + def forward(cls, ctx, logits, indices): + # make sure we can use triton + assert (indices.dtype == torch.int64), "Indices are expected to be of type long." + # make kernel + device, dtype = logits.device, logits.dtype + n_cols = logits.shape[-1] + # run the kernel + result = torch.empty_like(indices, dtype=dtype, device=device) + neg_logprobs = torch.empty_like(logits, dtype=dtype, device=device) + grid = lambda opt: (logits.numel() // n_cols, ) + _forward[grid](logits, neg_logprobs, indices, result, n_cols) + # save for backward + ctx.save_for_backward(neg_logprobs, indices) + return result + + @classmethod + def backward(cls, ctx, dneg_logprobs): + """We know d(-log(p[i])/dlogit[k] = -id_mat[i,k] + p[k] + so we initialize the gradient as neg_logprobs, so we can just exponentiate + to get p[k], which is most of what we need... neg_logprobs will be + modified in place to become the gradient we want + """ + # load saved tensors + neg_logprobs, indices = ctx.saved_tensors + # run the kernel + # neg_logprobs will be modified in place to become our gradient: + n_cols = neg_logprobs.shape[-1] + grid = lambda opt: (neg_logprobs.numel() // n_cols, ) + _backward[grid](neg_logprobs, indices, dneg_logprobs, n_cols) + return neg_logprobs, None + + +cross_entropy = _cross_entropy.apply diff --git a/pkgs/triton/ops/flash_attention.py b/pkgs/triton/ops/flash_attention.py new file mode 100644 index 0000000..d561567 --- /dev/null +++ b/pkgs/triton/ops/flash_attention.py @@ -0,0 +1,271 @@ +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +import torch + +import triton +import triton.language as tl +from triton.common.build import is_corex + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + L, M, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * 
stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_prev = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + # -- compute qk ---- + k = tl.load(k_ptrs) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + # compute new m + m_curr = tl.maximum(tl.max(qk, 1), m_prev) + # correct old l + l_prev *= tl.exp(m_prev - m_curr) + # attention weights + p = tl.exp(qk - m_curr[:, None]) + l_curr = tl.sum(p, 1) + l_prev + # rescale operands of matmuls + l_rcp = 1. / l_curr + p *= l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + # update acc + p = p.to(Q.dtype.element_ty) + v = tl.load(v_ptrs) + acc += tl.dot(p, v) + # update m_i and l_i + l_prev = l_curr + m_prev = m_curr + # update pointers + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_prev) + tl.store(m_ptrs, m_prev) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] 
* stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, tl.trans(k)) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(Q.dtype.element_ty), k) + tl.store(dq_ptrs, dq) + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, sm_scale): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if not is_corex(): + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK = 128 + else: + BLOCK = 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases. + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 + + _fwd_kernel[grid]( + q, k, v, sm_scale, + L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, num_warps=num_warps, + num_stages=2 if not is_corex() else 1, + ) + + ctx.save_for_backward(q, k, v, o, L, m) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + BLOCK = 128 if not is_corex() else 64 # FIXME: currently BLOCK=128 has issues, BLOCK=64 works for common cases. 
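+ # the backward kernels below recompute the attention probabilities from q, k and the saved row maxima (m) rather than saving them in the forward pass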
+ num_warps = 16 if is_corex() and ctx.BLOCK_DMODEL > 64 else 8 + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, + num_stages=1, + ) + return dq, dk, dv, None + + +attention = _attention.apply diff --git a/pkgs/triton/ops/matmul.py b/pkgs/triton/ops/matmul.py new file mode 100644 index 0000000..5bda4ab --- /dev/null +++ b/pkgs/triton/ops/matmul.py @@ -0,0 +1,184 @@ +import torch + +import triton +import triton.language as tl +from .matmul_perf_model import early_config_prune, estimate_matmul_time + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + +def get_configs_io_bound(): + configs = [] + if hasattr(torch, "corex"): + return configs + for num_stages in [1, 2]: +# TODO support block size 16 for MFMA dot op + for block_m in [16, 32] if torch.version.hip is None and not hasattr(torch, "corex") else [32, 64]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append(triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + return configs + +def get_configs_compute_bound(): + configs = [] + if hasattr(torch, "corex"): + for block_m in [32, 64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for block_k in [32, 64, 128, 256]: + for num_stages in [1, 2]: + num_warps = 16 if block_m >= 128 or block_n >=128 or block_k >= 128 else 8 + configs.append( + triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1}, + num_stages=num_stages, num_warps=num_warps)) + return configs + +def get_nv_config(): + configs = [] + if hasattr(torch, "corex"): + return configs + configs = [# basic configs for compute-bound matmuls + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + 
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + # good for int8 + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2), + ] + return configs + +@triton.autotune( + configs=get_nv_config() + get_configs_compute_bound() + get_configs_io_bound(), + key=['M', 'N', 'K'], + prune_configs_by={ + 'early_config_prune': early_config_prune, + 'perf_model': estimate_matmul_time, + 'top_k': 10 + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _kernel(A, B, C, M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + dot_out_dtype: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, + ): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.) 
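+ # out-of-range K elements were loaded as 0 above, so they contribute nothing to the accumulation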
+ acc += tl.dot(a, b, out_dtype=dot_out_dtype) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +class _matmul(torch.autograd.Function): + kernel = _kernel + + _locks = {} + + @staticmethod + def _call(a, b, dot_out_dtype): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + # allocates output + c = torch.empty((M, N), device=device, dtype=a.dtype) + if dot_out_dtype is None: + if a.dtype in [torch.float16, torch.float32, torch.bfloat16]: + dot_out_dtype = tl.float32 + else: + dot_out_dtype = tl.int32 + else: + assert isinstance(dot_out_dtype, torch.dtype), "dot_out_dtype must be a torch.dtype" + if dot_out_dtype == torch.float16: + dot_out_dtype = tl.float16 + elif dot_out_dtype in [torch.float32, torch.bfloat16]: + dot_out_dtype = tl.float32 + else: + dot_out_dtype = tl.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + _kernel[grid](a, b, c, M, N, K, + a.stride(0), a.stride(1), + b.stride(0), b.stride(1), + c.stride(0), c.stride(1), + dot_out_dtype=dot_out_dtype, + GROUP_M=8) + return c + + @staticmethod + def forward(ctx, a, b, dot_out_dtype=None): + return _matmul._call(a, b, dot_out_dtype=dot_out_dtype) + + +matmul = _matmul.apply diff --git a/pkgs/triton/ops/matmul_perf_model.py b/pkgs/triton/ops/matmul_perf_model.py new file mode 100644 index 0000000..dea85c6 --- /dev/null +++ b/pkgs/triton/ops/matmul_perf_model.py @@ -0,0 +1,164 @@ +import heapq + +import torch + +import triton +import triton._C.libtriton.triton as _triton +from triton.runtime import driver +from triton.testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops + + +def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device) + return tflops + + +def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): + ''' return compute throughput in TOPS ''' + total_warps = num_ctas * min(num_warps, 4) + num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device) + return tflops + + +def get_tflops(backend, device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(backend, device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype) + + +def estimate_matmul_time( + # backend, device, + num_warps, num_stages, 
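+ # num_warps and num_stages are passed per-config by the autotuner's perf-model call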
+ A, B, C, + M, N, K, + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, + debug=False, **kwargs +): + ''' return estimated running time in ms + = max(compute, loading) + store ''' + backend = _triton.runtime.backend.CUDA + device = torch.cuda.current_device() + dtype = A.dtype + dtsize = A.element_size() + + num_cta_m = triton.cdiv(M, BLOCK_M) + num_cta_n = triton.cdiv(N, BLOCK_N) + num_cta_k = SPLIT_K + num_ctas = num_cta_m * num_cta_n * num_cta_k + + # If the input is smaller than the block size + M, N = max(M, BLOCK_M), max(N, BLOCK_N) + + # time to compute + total_ops = 2 * M * N * K / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(backend, device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(backend, device) * (active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) + # assume 80% of (following) loads are in L2 cache + load_a_dram = M * K * dtsize * (1 + 0.2 * (num_cta_n - 1)) + load_a_l2 = M * K * dtsize * 0.8 * (num_cta_n - 1) + load_b_dram = N * K * dtsize * (1 + 0.2 * (num_cta_m - 1)) + load_b_l2 = N * K * dtsize * 0.8 * (num_cta_m - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.6 # :o + store_c_dram = M * N * dtsize * SPLIT_K / (1024 * 1024) # MB + if SPLIT_K == 1: + store_ms = store_c_dram / store_bw + else: + reduce_bw = store_bw + store_ms = store_c_dram / reduce_bw + # c.zero_() + zero_ms = M * N * 2 / (1024 * 1024) / store_bw + store_ms += zero_ms + + total_time_ms = compute_ms + load_ms + store_ms + if debug: + print(f'Total time: {total_time_ms}ms, compute time: {compute_ms}ms, ' + f'loading time: {load_ms}ms, store time: {store_ms}ms, ' + f'Activate CTAs: {active_cta_ratio*100}%') + return total_time_ms + + +def early_config_prune(configs, named_args): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['A'].element_size() + dtype = named_args['A'].dtype + + # 1. 
make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], config.num_stages + + max_shared_memory = driver.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + + # Some dtypes do not allow atomic_add + if dtype not in [torch.float16, torch.float32]: + configs = [config for config in configs if config.kwargs['SPLIT_K'] == 1] + + # group configs by (BLOCK_M,_N,_K, SPLIT_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages = \ + kw['BLOCK_M'], kw['BLOCK_N'], kw['BLOCK_K'], kw['SPLIT_K'], config.num_warps, config.num_stages + + key = (BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_M * BLOCK_N * BLOCK_K / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? + optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + if hasattr(torch, "corex"): + for stage in range(len(v)): + random_config = v[stage][0] + random_config.num_stages = v[stage][1] + pruned_configs.append(random_config) + else: + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs diff --git a/pkgs/triton/runtime/__init__.py b/pkgs/triton/runtime/__init__.py new file mode 100644 index 0000000..a4291ab --- /dev/null +++ b/pkgs/triton/runtime/__init__.py @@ -0,0 +1,21 @@ +from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, + heuristics) +from .driver import driver +from .jit import (JITFunction, KernelInterface, MockTensor, TensorWrapper, reinterpret, + version_key) + +__all__ = [ + "driver", + "Config", + "Heuristics", + "autotune", + "heuristics", + "JITFunction", + "KernelInterface", + "version_key", + "reinterpret", + "TensorWrapper", + "OutOfResources", + "MockTensor", + "Autotuner", +] diff --git a/pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..d9fa312 Binary files /dev/null and b/pkgs/triton/runtime/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc b/pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc new file mode 100644 index 0000000..d1d5a24 Binary files /dev/null and b/pkgs/triton/runtime/__pycache__/autotuner.cpython-310.pyc differ diff --git a/pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc b/pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc new file mode 100644 index 0000000..37abc27 Binary files /dev/null and b/pkgs/triton/runtime/__pycache__/cache.cpython-310.pyc differ diff --git 
a/pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc b/pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc new file mode 100644 index 0000000..94bfad6 Binary files /dev/null and b/pkgs/triton/runtime/__pycache__/driver.cpython-310.pyc differ diff --git a/pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc b/pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc new file mode 100644 index 0000000..04f30b2 Binary files /dev/null and b/pkgs/triton/runtime/__pycache__/errors.cpython-310.pyc differ diff --git a/pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc b/pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc new file mode 100644 index 0000000..9ddb22d Binary files /dev/null and b/pkgs/triton/runtime/__pycache__/jit.cpython-310.pyc differ diff --git a/pkgs/triton/runtime/autotuner.py b/pkgs/triton/runtime/autotuner.py new file mode 100644 index 0000000..c44e266 --- /dev/null +++ b/pkgs/triton/runtime/autotuner.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import builtins +import time +from typing import Dict + +import json +import os +import hashlib + +from ..testing import do_bench +from .jit import KernelInterface +from .cache import default_cache_dir + + +def build_best_config_hash(args_names, key): + cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + hasher = hashlib.sha256() + hasher.update(f"{'_'.join(args_names) + str(key)}\n".encode()) + cfg_hash = hasher.hexdigest() + cfg_hash_dir = os.path.join(cache_dir, cfg_hash) + cfg_hash_file = os.path.splitext(cfg_hash)[0] + ".best_config" + cfg_hash_file = os.path.join(cfg_hash_dir, cfg_hash_file) + return cfg_hash_dir, cfg_hash_file + + +def load_best_config(args_names, key): + _, cfg_hash_file = build_best_config_hash(args_names, key) + if os.path.exists(cfg_hash_file): + with open(cfg_hash_file) as fd: + best_config = json.loads(fd.read()) + num_warps = best_config.pop('num_warps') if 'num_warps' in best_config else 4 + num_stages = best_config.pop('num_stages') if 'num_stages' in best_config else 1 + return best_config, num_warps, num_stages + return None + + +def save_best_config(cfg, args_names, key): + cfg_hash_dir, cfg_hash_file = build_best_config_hash(args_names, key) + if os.path.exists(cfg_hash_dir): + return + os.makedirs(cfg_hash_dir, exist_ok=True) + with open(cfg_hash_file, "w") as fd: + fd.write( + json.dumps( + { + **cfg.kwargs, + "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages, + } + ) + ) + + +class OutOfResources(Exception): + def __init__(self, required, limit, name): + self.message = f'out of resource: {name}, '\ + f'Required: {required}, '\ + f'Hardware limit: {limit}' + self.message += '. Reducing block sizes or `num_stages` may help.' + self.required = required + self.limit = limit + self.name = name + super().__init__(self.message) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) + + +class Autotuner(KernelInterface): + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
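+ (note: this implementation looks the pruning hook up under the key 'early_config_prune'; see the check below)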
+ ''' + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + try: + return do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except OutOfResources: + return [float('inf'), float('inf'), float('inf')] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + divisibility = 16 + for arg in args: + if hasattr(arg, "data_ptr"): + key.append(arg.dtype) + key.append(arg.data_ptr() % divisibility == 0) + elif isinstance(arg, int): + key.append(arg) + key = tuple(key) + if key not in self.cache: + load_config = load_best_config(self.arg_names, key) + if load_config: + best_config, num_warps, num_stages = load_config + config = Config(best_config, num_warps, num_stages) + self.cache[key] = config + self.hook(args) + else: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + save_best_config(self.cache[key], self.arg_names, key) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + self.nargs = None + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model(**self.nargs, **kwargs, 
**config.kwargs, num_stages=config.num_stages, + num_warps=config.num_warps) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_stages: int + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_stages = num_stages + self.pre_hook = pre_hook + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f'{k}: {v}') + res.append(f'num_warps: {self.num_warps}') + res.append(f'num_stages: {self.num_stages}') + return ', '.join(res) + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. 
+ :type reset_to_zero: list[str] + """ + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. + :type values: dict[str, Callable[[list[Any]], Any]] + """ + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/pkgs/triton/runtime/backends/cuda.c b/pkgs/triton/runtime/backends/cuda.c new file mode 100644 index 0000000..009d526 --- /dev/null +++ b/pkgs/triton/runtime/backends/cuda.c @@ -0,0 +1,131 @@ +#include "cuda.h" +#define PY_SSIZE_T_CLEAN +#include + +static inline void gpuAssert(CUresult code, const char *file, int line) { + if (code != CUDA_SUCCESS) { + const char *prefix = "Triton Error [CUDA]: "; + const char *str; + cuGetErrorString(code, &str); + char err[1024] = {0}; + strcat(err, prefix); + strcat(err, str); + PyErr_SetString(PyExc_RuntimeError, err); + } +} + +#define CUDA_CHECK(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + if (PyErr_Occurred()) \ + return NULL; \ + } + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + // Get device handle + CUdevice device; + cuDeviceGet(&device, device_id); + + // create a struct to hold device properties + int max_shared_mem; + int multiprocessor_count; + int sm_clock_rate; + int mem_clock_rate; + int mem_bus_width; + CUDA_CHECK(cuDeviceGetAttribute( + &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + CUDA_CHECK(cuDeviceGetAttribute( + &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device)); + CUDA_CHECK(cuDeviceGetAttribute(&sm_clock_rate, + CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device)); + CUDA_CHECK(cuDeviceGetAttribute( + &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device)); + CUDA_CHECK(cuDeviceGetAttribute( + &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device)); + + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + max_shared_mem, "multiprocessor_count", + multiprocessor_count, "sm_clock_rate", sm_clock_rate, + "mem_clock_rate", mem_clock_rate, "mem_bus_width", + mem_bus_width); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return NULL; + } + CUfunction fun; + CUmodule mod; + int32_t n_regs = 0; + int32_t n_spills = 0; + 
// create driver handles + CUcontext pctx = 0; + CUDA_CHECK(cuCtxGetCurrent(&pctx)); + if (!pctx) { + CUDA_CHECK(cuDevicePrimaryCtxRetain(&pctx, device)); + CUDA_CHECK(cuCtxSetCurrent(pctx)); + } + + CUDA_CHECK(cuModuleLoadData(&mod, data)); + CUDA_CHECK(cuModuleGetFunction(&fun, mod, name)); + // get allocated registers and spilled registers from the function + CUDA_CHECK(cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun)); + CUDA_CHECK( + cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun)); + n_spills /= 4; + // set dynamic shared memory if necessary + int shared_optin; + CUDA_CHECK(cuDeviceGetAttribute( + &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, + device)); + if (shared > 49152 && shared_optin > 49152) { + CUDA_CHECK(cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED)); + int shared_total, shared_static; + CUDA_CHECK(cuDeviceGetAttribute( + &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, + device)); + CUDA_CHECK(cuFuncGetAttribute(&shared_static, + CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun)); + CUDA_CHECK( + cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, + shared_optin - shared_static)); + } + + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided cubin into CUDA driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_cuda_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/pkgs/triton/runtime/backends/hip.c b/pkgs/triton/runtime/backends/hip.c new file mode 100644 index 0000000..5ed5f19 --- /dev/null +++ b/pkgs/triton/runtime/backends/hip.c @@ -0,0 +1,120 @@ +#define __HIP_PLATFORM_AMD__ +#include +#define PY_SSIZE_T_CLEAN +#include +#include +#include + +static inline void gpuAssert(hipError_t code, const char *file, int line) { + { + if (code != HIP_SUCCESS) { + { + const char *prefix = "Triton Error [HIP]: "; + const char *str = hipGetErrorString(code); + char err[1024] = {0}; + snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str); + PyErr_SetString(PyExc_RuntimeError, err); + } + } + } +} + +#define HIP_CHECK(ans) \ + { \ + gpuAssert((ans), __FILE__, __LINE__); \ + if (PyErr_Occurred()) \ + return NULL; \ + } + +static PyObject *getDeviceProperties(PyObject *self, PyObject *args) { + int device_id; + if (!PyArg_ParseTuple(args, "i", &device_id)) + return NULL; + + hipDeviceProp_t props; + HIP_CHECK(hipGetDeviceProperties(&props, device_id)); + + // create a struct to hold device properties + return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i}", "max_shared_mem", + props.sharedMemPerBlock, "multiprocessor_count", + props.multiProcessorCount, "sm_clock_rate", + props.clockRate, "mem_clock_rate", props.memoryClockRate, + "mem_bus_width", props.memoryBusWidth); +} + +static PyObject *loadBinary(PyObject *self, PyObject *args) { + const char *name; + const char *data; + Py_ssize_t data_size; + int shared; + int device; + if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared, + &device)) { + return 
NULL; + } + + // Open HSACO file + FILE *hsaco_file; + if ((hsaco_file = fopen(data, "rb")) == NULL) { + return NULL; + } + + // Read HSCAO file into Buffer + fseek(hsaco_file, 0L, SEEK_END); + size_t hsaco_file_size = ftell(hsaco_file); + unsigned char *hsaco = + (unsigned char *)malloc(hsaco_file_size * sizeof(unsigned char)); + rewind(hsaco_file); + fread(hsaco, sizeof(unsigned char), hsaco_file_size, hsaco_file); + fclose(hsaco_file); + + // set HIP options + hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes, + hipJitOptionErrorLogBuffer, + hipJitOptionInfoLogBufferSizeBytes, + hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose}; + const unsigned int errbufsize = 8192; + const unsigned int logbufsize = 8192; + char _err[errbufsize]; + char _log[logbufsize]; + void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err, + (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1}; + + // launch HIP Binary + hipModule_t mod; + hipFunction_t fun; + hipModuleLoadDataEx(&mod, hsaco, 5, opt, optval); + hipModuleGetFunction(&fun, mod, name); + free(hsaco); + + // get allocated registers and spilled registers from the function + int n_regs = 0; + int n_spills = 0; + if (PyErr_Occurred()) { + return NULL; + } + return Py_BuildValue("(KKii)", (uint64_t)mod, (uint64_t)fun, n_regs, + n_spills); +} + +static PyMethodDef ModuleMethods[] = { + {"load_binary", loadBinary, METH_VARARGS, + "Load provided hsaco into HIP driver"}, + {"get_device_properties", getDeviceProperties, METH_VARARGS, + "Get the properties for a given device"}, + {NULL, NULL, 0, NULL} // sentinel +}; + +static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils", + NULL, // documentation + -1, // size + ModuleMethods}; + +PyMODINIT_FUNC PyInit_hip_utils(void) { + PyObject *m = PyModule_Create(&ModuleDef); + if (m == NULL) { + return NULL; + } + PyModule_AddFunctions(m, ModuleMethods); + return m; +} diff --git a/pkgs/triton/runtime/cache.py b/pkgs/triton/runtime/cache.py new file mode 100644 index 0000000..43e6660 --- /dev/null +++ b/pkgs/triton/runtime/cache.py @@ -0,0 +1,131 @@ +import json +import os +import random +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, Optional + + +def default_cache_dir(): + return os.path.join(Path.home(), ".triton", "cache") + + +class CacheManager(ABC): + def __init__(self, key): + pass + + @abstractmethod + def get_file(self, filename) -> Optional[str]: + pass + + @abstractmethod + def has_file(self, filename) -> bool: + pass + + @abstractmethod + def put(self, data, filename, binary=True) -> str: + pass + + @abstractmethod + def get_group(self, filename: str) -> Optional[Dict[str, str]]: + pass + + @abstractmethod + def put_group(self, filename: str, group: Dict[str, str]): + pass + + +class FileCacheManager(CacheManager): + def __init__(self, key): + self.key = key + self.lock_path = None + # create cache directory if it doesn't exist + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', default_cache_dir()) + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + + def _make_path(self, filename) -> str: + return os.path.join(self.cache_dir, filename) + + def has_file(self, filename): + if not self.cache_dir: + return False + return os.path.exists(self._make_path(filename)) + + def get_file(self, filename) -> Optional[str]: + if self.has_file(filename): + return self._make_path(filename) + else: + return None + 
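+    # A "group" ties several cached artifacts to one logical compilation. It is
+    # stored as a JSON side-car named ``__grp__<filename>`` whose "child_paths"
+    # entry lists the member filenames; ``get_group`` (below) resolves those
+    # names back to absolute paths and ``put_group`` writes the side-car.
+    # A purely illustrative layout (file names are made up):
+    #
+    #   add_kernel.cubin
+    #   add_kernel.ptx
+    #   __grp__add_kernel.json  ->  {"child_paths": ["add_kernel.cubin", "add_kernel.ptx"]}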
+ def get_group(self, filename: str) -> Optional[Dict[str, str]]: + grp_filename = f"__grp__{filename}" + if not self.has_file(grp_filename): + return None + grp_filepath = self._make_path(grp_filename) + with open(grp_filepath) as f: + grp_data = json.load(f) + child_paths = grp_data.get("child_paths", None) + # Invalid group data. + if child_paths is None: + return None + result = {} + for c in child_paths: + p = self._make_path(c) + if not os.path.exists(p): + raise Exception(f"Group file {p} does not exist from group {grp_filename} ") + result[c] = p + return result + + # Note a group of pushed files as being part of a group + def put_group(self, filename: str, group: Dict[str, str]): + if not self.cache_dir: + return + grp_contents = json.dumps({"child_paths": sorted(list(group.keys()))}) + grp_filename = f"__grp__{filename}" + return self.put(grp_contents, grp_filename, binary=False) + + def put(self, data, filename, binary=True) -> str: + if not self.cache_dir: + return + binary = isinstance(data, bytes) + if not binary: + data = str(data) + assert self.lock_path is not None + filepath = self._make_path(filename) + # Random ID to avoid any collisions + rnd_id = random.randint(0, 1000000) + # we use the PID incase a bunch of these around so we can see what PID made it + pid = os.getpid() + # use tempfile to be robust against program interruptions + temp_path = f"{filepath}.tmp.pid_{pid}_{rnd_id}" + mode = "wb" if binary else "w" + with open(temp_path, mode) as f: + f.write(data) + # Replace is guaranteed to be atomic on POSIX systems if it succeeds + # so filepath cannot see a partial write + os.replace(temp_path, filepath) + return filepath + + +__cache_cls = FileCacheManager +__cache_cls_nme = "DEFAULT" + + +def get_cache_manager(key) -> CacheManager: + import os + + user_cache_manager = os.environ.get("TRITON_CACHE_MANAGER", None) + global __cache_cls + global __cache_cls_nme + + if user_cache_manager is not None and user_cache_manager != __cache_cls_nme: + import importlib + module_path, clz_nme = user_cache_manager.split(":") + module = importlib.import_module(module_path) + __cache_cls = getattr(module, clz_nme) + __cache_cls_nme = user_cache_manager + + return __cache_cls(key) diff --git a/pkgs/triton/runtime/driver.py b/pkgs/triton/runtime/driver.py new file mode 100644 index 0000000..5936599 --- /dev/null +++ b/pkgs/triton/runtime/driver.py @@ -0,0 +1,175 @@ +import abc +import hashlib +import os +import tempfile +from pathlib import Path + +from ..common.build import _build +from .cache import get_cache_manager + + +class DriverBase(metaclass=abc.ABCMeta): + + CUDA = 0 + HIP = 1 + + @staticmethod + def third_party_dir(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "third_party") + + def __init__(self) -> None: + pass +# ----------------------------- +# CUDA +# ----------------------------- + + +class CudaUtils(object): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(CudaUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + dirname = os.path.dirname(os.path.realpath(__file__)) + src = Path(os.path.join(dirname, "backends", "cuda.c")).read_text() + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + fname = "cuda_utils.so" + cache_path = cache.get_file(fname) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build("cuda_utils", src_path, 
tmpdir) + cache.put(src, "main.c", binary=False) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), fname, binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location("cuda_utils", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +class CudaDriver(DriverBase): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(CudaDriver, cls).__new__(cls) + return cls.instance + + def __init__(self): + self.utils = CudaUtils() + self.backend = self.CUDA + +# ----------------------------- +# HIP +# ----------------------------- + + +class HIPUtils(object): + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(HIPUtils, cls).__new__(cls) + return cls.instance + + def __init__(self): + dirname = os.path.dirname(os.path.realpath(__file__)) + src = Path(os.path.join(dirname, "backends", "hip.c")).read_text() + key = hashlib.md5(src.encode("utf-8")).hexdigest() + cache = get_cache_manager(key) + fname = "hip_utils.so" + cache_path = cache.get_file(fname) + if cache_path is None: + with tempfile.TemporaryDirectory() as tmpdir: + src_path = os.path.join(tmpdir, "main.c") + with open(src_path, "w") as f: + f.write(src) + so = _build("hip_utils", src_path, tmpdir) + with open(so, "rb") as f: + cache_path = cache.put(f.read(), fname, binary=True) + import importlib.util + spec = importlib.util.spec_from_file_location("hip_utils", cache_path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + self.load_binary = mod.load_binary + self.get_device_properties = mod.get_device_properties + + +class HIPDriver(DriverBase): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(HIPDriver, cls).__new__(cls) + return cls.instance + + def __init__(self): + self.utils = HIPUtils() + self.backend = self.HIP + + +class UnsupportedDriver(DriverBase): + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(UnsupportedDriver, cls).__new__(cls) + return cls.instance + + def __init__(self): + self.utils = None + self.backend = None + +# ----------------------------- +# Driver +# ----------------------------- + + +class LazyProxy: + def __init__(self, init_fn): + self._init_fn = init_fn + self._obj = None + + def _initialize_obj(self): + if self._obj is None: + self._obj = self._init_fn() + + def __getattr__(self, name): + self._initialize_obj() + return getattr(self._obj, name) + + def __setattr__(self, name, value): + if name in ['_init_fn', '_obj']: + super().__setattr__(name, value) + else: + self._initialize_obj() + setattr(self._obj, name, value) + + def __delattr__(self, name): + self._initialize_obj() + delattr(self._obj, name) + + def __repr__(self): + if self._obj is None: + return f"<{self.__class__.__name__} for {self._init_fn} not yet initialized>" + return repr(self._obj) + + def __str__(self): + self._initialize_obj() + return str(self._obj) + + +def initialize_driver(): + import torch + if torch.version.hip is not None: + return HIPDriver() + elif torch.cuda.is_available(): + return CudaDriver() + else: + return UnsupportedDriver() + + +driver = LazyProxy(initialize_driver) diff --git a/pkgs/triton/runtime/errors.py b/pkgs/triton/runtime/errors.py new file mode 100644 index 0000000..4ff9005 --- /dev/null +++ b/pkgs/triton/runtime/errors.py @@ -0,0 +1,15 @@ + +class OutOfResources(Exception): + def __init__(self, required, 
limit, name): + self.message = f'out of resource: {name}, '\ + f'Required: {required}, '\ + f'Hardware limit: {limit}' + self.message += '. Reducing block sizes or `num_stages` may help.' + self.required = required + self.limit = limit + self.name = name + super().__init__(self.message) + + def __reduce__(self): + # this is necessary to make CompilationError picklable + return (type(self), (self.required, self.limit, self.name)) diff --git a/pkgs/triton/runtime/jit.py b/pkgs/triton/runtime/jit.py new file mode 100644 index 0000000..f469781 --- /dev/null +++ b/pkgs/triton/runtime/jit.py @@ -0,0 +1,573 @@ +from __future__ import annotations, division + +import ast +import functools +import hashlib +import inspect +import os +import subprocess +import textwrap +from collections import defaultdict, namedtuple +from typing import Callable, Generic, Iterable, Optional, TypeVar, Union, cast, overload + +import torch +import triton + + +def get_disable_sme(): + disable_sme = os.getenv("TRITON_DISABLE_SME", default="0") + cc = torch.cuda.get_device_capability() + cc = cc[0] * 10 + cc[1] + if cc == 70: # for ivcore10 + disable_sme = "1" + + return disable_sme + + +def get_corex_sme(args, tl_args, enable_sme=True): + can_use_sme = 0 + if not enable_sme: + return can_use_sme + import torch + if not (hasattr(torch, "corex") and torch.corex == True): + return can_use_sme + close_sme = get_disable_sme() + if close_sme == "1": + return can_use_sme + index = 0 + for i, arg_name in enumerate(args): + arg = args.get(arg_name) + if (i in tl_args): + continue + if (isinstance(arg, int) and arg == 1): + continue + if torch.is_tensor(arg) and arg.dtype in [torch.float16, torch.float32, torch.bfloat16, torch.int8] and arg.dim() >= 2: + dim_M = arg.shape[-2] + dim_K = arg.shape[-1] + sme_dim = 64 / arg.element_size() + if (arg.is_contiguous() and dim_K % sme_dim == 0) or \ + (not arg.is_contiguous() and dim_M % sme_dim == 0): + can_use_sme = (1 << index) | can_use_sme + index += 1 + return can_use_sme + + +def get_cuda_stream(idx=None): + if idx is None: + idx = get_current_device() + try: + from torch._C import _cuda_getCurrentRawStream + return _cuda_getCurrentRawStream(idx) + except ImportError: + import torch + return torch.cuda.current_stream(idx).cuda_stream + + +def get_current_device(): + import torch + return torch.cuda.current_device() + + +def set_current_device(idx): + import torch + torch.cuda.set_device(idx) + + +def get_device_capability(idx): + import torch + return torch.cuda.get_device_capability(idx) + + +T = TypeVar('T') + +# ----------------------------------------------------------------------------- +# Dependencies Finder +# ----------------------------------------------------------------------------- + + +class DependenciesFinder(ast.NodeVisitor): + """ + This AST visitor is used to find dependencies of a JITFunction. This can + be used to invalidate a JITFunction's hash when its source code -- or + that of its dependencies -- changes. 
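+
+    The resulting hash folds together the kernel's own source and the hashes of
+    every @triton.jit function it calls, so editing any dependency changes the
+    cache key.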
+ """ + + def __init__(self, globals, src) -> None: + super().__init__() + self.ret = hashlib.md5(src.encode("utf-8")).hexdigest() + self.globals = globals + + def visit_Name(self, node): + return self.globals.get(node.id, None) + + def visit_Attribute(self, node): + lhs = self.visit(node.value) + while isinstance(lhs, ast.Attribute): + lhs = self.visit(lhs.value) + if lhs is None or lhs is triton: + return None + return getattr(lhs, node.attr) + + def visit_Call(self, node): + func = self.visit(node.func) + if func is None: + return + if inspect.isbuiltin(func): + return + if func.__module__ and func.__module__.startswith('triton.'): + return + assert isinstance(func, JITFunction), f"Function \"{func.__name__}\" is being called from a Triton function but is not a Triton function itself. Decorate it with @triton.jit to fix this" + if func.hash is None: + tree = ast.parse(func.src) + finder = DependenciesFinder(func.__globals__, func.src) + finder.visit(tree) + func.hash = finder.ret + noinline = str(getattr(func, 'noinline', False)) + self.ret = (self.ret + func.hash + noinline).encode("utf-8") + self.ret = hashlib.md5(self.ret).hexdigest() + +# ----------------------------------------------------------------------------- +# JITFunction +# ----------------------------------------------------------------------------- + + +@functools.lru_cache() +def version_key(): + import pkgutil + contents = [] + # frontend + with open(__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # compiler + compiler_path = os.path.join(*triton.__path__, 'compiler') + for lib in pkgutil.iter_modules([compiler_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # backend + with open(triton._C.libtriton.__file__, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # language + language_path = os.path.join(*triton.__path__, 'language') + for lib in pkgutil.iter_modules([language_path]): + with open(lib.module_finder.find_spec(lib.name).origin, "rb") as f: + contents += [hashlib.md5(f.read()).hexdigest()] + # ptxas version + try: + ptxas_version = hashlib.md5(subprocess.check_output(["ptxas", "--version"])).hexdigest() + except Exception: + ptxas_version = '' + return '-'.join(triton.__version__) + '-' + ptxas_version + '-' + '-'.join(contents) + + +class KernelInterface(Generic[T]): + run: T + + def __getitem__(self, grid) -> T: + """ + A JIT function is launched with: fn[grid](*args, **kwargs). + Hence JITFunction.__getitem__ returns a callable proxy that + memorizes the grid. 
+ """ + return cast(T, functools.partial(cast(Callable, self.run), grid=grid)) + + +class JITFunction(KernelInterface[T]): + + # Hook for inspecting compiled functions and modules + cache_hook = None + divisibility = 16 + + @staticmethod + def _key_of(arg): + if hasattr(arg, "dtype"): + return arg.dtype + elif isinstance(arg, bool): + return "i1" + elif isinstance(arg, int): + if -2**31 <= arg and arg <= 2**31 - 1: + return "i32" + elif 2**63 <= arg and arg <= 2**64 - 1: + return "u64" + else: + return "i64" + elif isinstance(arg, float): + return 'fp32' + elif arg is None: + return None + else: + raise TypeError(f'Unsupported type {type(arg)} for {arg}') + + @staticmethod + def _spec_of(arg): + if hasattr(arg, "data_ptr"): + return (arg.data_ptr() % JITFunction.divisibility == 0) + elif isinstance(arg, int): + return (arg % 16 == 0, arg == 1) + return (arg is None, ) + + def _get_config(self, *args): + def is_divisible_by_16(x): + if hasattr(x, "data_ptr"): + return x.data_ptr() % JITFunction.divisibility == 0 + elif isinstance(x, int): + return x % JITFunction.divisibility == 0 + if x is None: + return True + return False + divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} + equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) + # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1) + + @staticmethod + def _type_of(key): + # None are nullptr -- implicitly converted to *i8 + if key is None: + return '*i8' + dtype_str = str(key).split(".")[-1] + tys = { + "bool": "i1", + "float8e5": "fp8e5", + "float8e4": "fp8e4", + "float16": "fp16", + "bfloat16": "bf16", + "float32": "fp32", + "float64": "fp64", + "int8": "i8", + "int16": "i16", + "int32": "i32", + "int64": "i64", + "uint8": "u8", + "uint16": "u16", + "uint32": "u32", + "uint64": "u64", + } + # reinterpret can create triton type + for v in list(tys.values()): + tys[v] = v + return key if isinstance(key, str) else f"*{tys[dtype_str]}" + + def _make_signature(self, sig_key): + signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) + return signature + + def _make_constants(self, constexpr_key): + constants = dict(zip(self.constexprs, constexpr_key)) + return constants + + def _call_hook(self, key, signature, device, constants, num_warps, num_stages, extern_libs, configs): + if JITFunction.cache_hook is None: + return False + name = self.fn.__name__ + module = self.fn.__module__ + arg_reprs = ', '.join([f'{name}: {ty}' for name, ty in zip(self.arg_names, key[1])]) + repr = f"{name}[num_warps={num_warps}, num_stages={num_stages}]({arg_reprs})" + key = str(key) + + class LegacyCompiler: + def __init__(self, module, name): + self.module = module + self.name = name + pass + + kwargs = dict(signature=signature, device=device, constants=constants, + num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, + configs=configs) + + return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={"key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) + + def _get_arg_specialization_key(self, arg) -> str: + arg_annotation = self.__annotations__.get(arg, '') + if arg_annotation == '': + return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \ + else 
({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) \ + else (False,)' + elif 'Tensor' in arg_annotation: + return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)' + elif arg_annotation == 'int': + return f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)' + else: + return '(False,)' + + def _get_arg_sig_key(self, arg) -> str: + arg_annotation = self.__annotations__.get(arg, '') + if 'Tensor' in arg_annotation: + return f'{arg}.dtype' + elif arg_annotation == 'bool': + return "i1" + elif arg_annotation == 'float': + return 'fp32' + else: + return f'_key_of({arg})' + + def _make_launcher(self): + regular_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i not in self.constexprs] + constexpr_args = [f'{arg}' for i, arg in enumerate(self.arg_names) if i in self.constexprs] + args = ', '.join(regular_args) + # cache key for regular argument type + sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args]) + # cache key for constexpr argument values + constexpr_keys = ', '.join(constexpr_args) + # cache key for argument specialization + specializations = [] + for i, arg in enumerate(regular_args): + if i in self.do_not_specialize: + continue + specializations += [self._get_arg_specialization_key(arg)] + + spec_keys = ', '.join(specializations) + grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) + + src = f""" +def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, enable_sme=True, extern_libs=None, stream=None, warmup=False, device=None): + sig_key = {sig_keys}, + constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()} + spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()} + use_sme = get_corex_sme({{{grid_args}}}, self.constexprs, enable_sme) + key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug, use_sme) + if not extern_libs is None: + key = (key, tuple(extern_libs.items())) + assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2" + if callable(grid): + grid = grid({{{grid_args}}}) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + if device is None: + device = get_current_device() + set_current_device(device) + if stream is None and not warmup: + stream = get_cuda_stream(device) + bin = cache[device].get(key, None) + if bin is not None: + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}) + return bin + # kernel not cached -- compile + else: + # build dict of constant values + args = [{args}] + all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, + configs = self._get_config(*all_args), + constants = self._make_constants(constexpr_key) + constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}}) + constants.update({{i: 1 for i in configs[0].equal_to_1}}) + # build kernel signature -- doesn't include specialized arguments + signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }} + # build stub signature -- includes arguments that are specialized + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {{i}} is not supported") + if not self._call_hook(key, signature, device, constants, num_warps, num_stages, 
extern_libs, configs): + bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug, use_sme=use_sme) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args) + self.cache[device][key] = bin + return bin + return None +""" + scope = {"version_key": version_key(), "get_cuda_stream": get_cuda_stream, + "self": self, "_spec_of": self._spec_of, "_key_of": self._key_of, + "cache": self.cache, "triton": triton, + "get_current_device": get_current_device, + "set_current_device": set_current_device, + "get_corex_sme": get_corex_sme} + exec(src, scope) + return scope[self.fn.__name__] + + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None): + self.fn = fn + self.module = fn.__module__ + self.version = version + # function signature information + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + self.has_defaults = any(v.default != inspect._empty for v in signature.parameters.values()) + # specialization hints + self.do_not_specialize = [] if do_not_specialize is None else do_not_specialize + self.do_not_specialize = {self.arg_names.index(arg) if isinstance(arg, str) else arg for arg in self.do_not_specialize} + # function source code (without decorators) + self.src = textwrap.dedent(inspect.getsource(fn)) + self.src = self.src[self.src.find("def"):] + # cache of just-in-time compiled kernels + self.cache = defaultdict(dict) + self.hash = None + # JITFunction can be instantiated as kernel + # when called with a grid using __getitem__ + self.kernel_decorators = [] + self.kernel = None + self.debug = True if os.environ.get("TRITON_DEBUG", "0") == "1" else debug + self.noinline = noinline + # annotations + normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty + self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()} + # index of constexprs + self.constexprs = [self.arg_names.index(name) for name, ty in self.__annotations__.items() if 'constexpr' in ty] + # launcher + self.run = self._make_launcher() + # re-use docs of wrapped function + self.__doc__ = fn.__doc__ + self.__name__ = fn.__name__ + self.__globals__ = fn.__globals__ + self.__module__ = fn.__module__ + + @property + def cache_key(self): + # TODO : hash should be attribute of `self` + if self.hash is None: + dependencies_finder = DependenciesFinder(globals=self.__globals__, src=self.src) + dependencies_finder.visit(self.parse()) + self.hash = dependencies_finder.ret + version_key() + return self.hash + + def warmup(self, *args, **kwargs): + return self.run(*map(MockTensor.wrap_dtype, args), **kwargs, warmup=True) + + # we do not parse `src` in the constructor because + # the user might want to monkey-patch self.src dynamically. + # Our unit tests do this, for example. 
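+    # (``__setattr__`` below clears the memoized hash whenever ``src`` is reassigned,
+    # so ``cache_key`` is recomputed from the patched source.)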
+ def parse(self): + tree = ast.parse(self.src) + assert isinstance(tree, ast.Module) + assert len(tree.body) == 1 + assert isinstance(tree.body[0], ast.FunctionDef) + return tree + + def __call__(self, *args, **kwargs): + raise RuntimeError("Cannot call @triton.jit'd outside of the scope of a kernel") + + def __setattr__(self, name, value): + # - when kernel decorators change, cached kernel + # needs to be cleared + if name == 'kernel_decorators': + self.kernel = None + super(JITFunction, self).__setattr__(name, value) + # - when `.src` attribute is set, cache path needs + # to be reinitialized + if name == 'src': + self.hash = None + + def __repr__(self): + return f"JITFunction({self.module}:{self.fn.__name__})" + + +# ----------------------------------------------------------------------------- +# `jit` decorator +# ----------------------------------------------------------------------------- + + +@overload +def jit(fn: T) -> JITFunction[T]: + ... + + +@overload +def jit( + *, + version=None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, +) -> Callable[[T], JITFunction[T]]: + ... + + +def jit( + fn: Optional[T] = None, + *, + version=None, + do_not_specialize: Optional[Iterable[int]] = None, + debug: Optional[bool] = None, + noinline: Optional[bool] = None, + interpret: Optional[bool] = None, +) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: + """ + Decorator for JIT-compiling a function using the Triton compiler. + + :note: When a jit'd function is called, arguments are + implicitly converted to pointers if they have a :code:`.data_ptr()` method + and a `.dtype` attribute. + + :note: This function will be compiled and run on the GPU. It will only have access to: + + * python primitives, + * builtins within the triton package, + * arguments to this function, + * other jit'd functions + + :param fn: the function to be jit-compiled + :type fn: Callable + """ + + def decorator(fn: T) -> JITFunction[T]: + assert callable(fn) + if interpret: + from ..debugger.debugger import GridSelector + return GridSelector(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + ) + if fn is not None: + return decorator(fn) + + else: + return decorator + +# ----------------------------------------------------------------------------- +# Utilities for mocking tensors +# ----------------------------------------------------------------------------- + + +class MockTensor: + """ + Can be used in place of real tensors when calling: + kernel.warmup(MockTensor(torch.float32), ...) + """ + @staticmethod + def wrap_dtype(arg): + if arg.__class__.__name__ == "dtype" and\ + arg.__module__ == "torch": + return MockTensor(arg) + return arg + + def __init__(self, dtype): + self.dtype = dtype + + @staticmethod + def data_ptr(): + return 0 # optimistically assumes multiple of 16 + + +class TensorWrapper: + def __init__(self, base, dtype): + self.dtype = dtype + self.base = base + self.is_cuda = base.is_cuda + self.device = base.device + + def data_ptr(self): + return self.base.data_ptr() + + def __str__(self) -> str: + return f'TensorWrapper[{self.dtype}]({self.base})' + + +def reinterpret(tensor, dtype): + if isinstance(tensor, TensorWrapper): + if dtype == tensor.base.dtype: + # Reinterpreting to the original interpretation; return the base. + return tensor.base + else: + # Reinterpreting a wrapped tensor to a different type. 
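+            # (only the advertised dtype changes; the wrapper still reports the
+            # base tensor's data_ptr, so no data is copied)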
+ return TensorWrapper(tensor.base, dtype) + elif hasattr(tensor, "data_ptr"): + # A new wrapper is needed around an unwrapped tensor. + return TensorWrapper(tensor, dtype) + else: + raise TypeError(f'Cannot reinterpret a {type(tensor)}.') diff --git a/pkgs/triton/testing.py b/pkgs/triton/testing.py new file mode 100644 index 0000000..8ca1624 --- /dev/null +++ b/pkgs/triton/testing.py @@ -0,0 +1,424 @@ +import functools +import os +import subprocess +import sys +from contextlib import contextmanager + +import triton._C.libtriton.triton as _triton +from triton.common.build import is_corex + + +def nvsmi(attrs): + attrs = ','.join(attrs) + cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(',') + ret = [int(x) for x in ret] + return ret + + +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, + quantiles=None, + fast_flush=True, + return_mode="mean"): + assert return_mode in ["min", "max", "mean", "median"] + import torch + """ + Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with + the 20-th and 80-th performance percentile. + + :param fn: Function to benchmark + :type fn: Callable + :param warmup: Warmup time (in ms) + :type warmup: int + :param rep: Repetition time (in ms) + :type rep: int + :param grad_to_none: Reset the gradient of the provided tensor to None + :type grad_to_none: torch.tensor, optional + :param quantiles: Performance percentile to return in addition to the median. + :type quantiles: list[float] + :param fast_flush: Use faster kernel to flush L2 between measurements + :type fast_flush: bool + """ + + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda') + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. 
So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() + + +def assert_close(x, y, atol=None, rtol=None, err_msg=''): + import numpy as np + import torch + + # canonicalize arguments to be tensors + if not isinstance(x, torch.Tensor): + x = torch.tensor(x) + if not isinstance(y, torch.Tensor): + y = torch.tensor(y) + # absolute tolerance + if atol is None: + atol = 1e-2 + atol = atol(x.dtype) if callable(atol) else atol + # relative tolerance hook + if rtol is None: + rtol = 0. + rtol = rtol(x.dtype) if callable(rtol) else rtol + # we use numpy instead of pytorch + # as it seems more memory efficient + # pytorch tends to oom on large tensors + if isinstance(x, torch.Tensor): + if x.dtype == torch.bfloat16: + x = x.float() + x = x.cpu().detach().numpy() + if isinstance(y, torch.Tensor): + if y.dtype == torch.bfloat16: + y = y.float() + y = y.cpu().detach().numpy() + # we handle size==1 case separately as we can + # provide better error message there + if x.size > 1 or y.size > 1: + np.testing.assert_allclose(x, y, atol=atol, rtol=rtol, equal_nan=True) + return + if not np.allclose(x, y, atol=atol, rtol=rtol): + raise AssertionError(f'{err_msg} {x} is not close to {y} (atol={atol}, rtol={rtol})') + + +class Benchmark: + """ + This class is used by the :code:`perf_report` function to generate line plots with a concise API. + """ + + def __init__( + self, + x_names, + x_vals, + line_arg, + line_vals, + line_names, + plot_name, + args, + xlabel='', + ylabel='', + x_log=False, + y_log=False, + color=None, + styles=None, + ): + """ + Constructor + + :param x_names: Name of the arguments that should appear on the x axis of the plot. If the list contains more than one element, all the arguments are assumed to have the same value. + :type x_names: List[str] + :param x_vals: List of values to use for the arguments in :code:`x_names`. + :type x_vals: List[Any] + :param line_arg: Argument name for which different values correspond to different lines in the plot. + :type line_arg: str + :param line_vals: List of values to use for the arguments in :code:`line_arg`. + :type line_vals: List[str] + :param line_names: Label names for the different lines. + :type line_names: List[str] + :param plot_name: Name of the plot. + :type plot_name: str + :param args: List of arguments to remain fixed throughout the benchmark. + :type args: List[str] + :param xlabel: Label for the x axis of the plot. + :type xlabel: str, optional + :param ylabel: Label for the y axis of the plot. + :type ylabel: str, optional + :param x_log: Whether the x axis should be log scale. + :type x_log: bool, optional + :param y_log: Whether the y axis should be log scale. 
+ :type y_log: bool, optional + """ + self.x_names = x_names + self.x_vals = x_vals + self.x_log = x_log + self.line_arg = line_arg + self.line_vals = line_vals + self.line_names = line_names + self.y_log = y_log + self.styles = styles + # plot info + self.xlabel = xlabel + self.ylabel = ylabel + self.plot_name = plot_name + self.args = args + + +class Mark: + def __init__(self, fn, benchmarks): + self.fn = fn + self.benchmarks = benchmarks + + def _run(self, bench, save_path, show_plots, print_data): + import os + + import matplotlib.pyplot as plt + import pandas as pd + y_mean = bench.line_names + y_min = [f'{x}-min' for x in bench.line_names] + y_max = [f'{x}-max' for x in bench.line_names] + df = pd.DataFrame(columns=[bench.x_names[0]] + y_mean + y_min + y_max) + for x in bench.x_vals: + x_args = {x_name: x for x_name in bench.x_names} + row_mean, row_min, row_max = [], [], [] + for y in bench.line_vals: + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args) + try: + y_mean, y_min, y_max = ret + except TypeError: + y_mean, y_min, y_max = ret, None, None + row_mean += [y_mean] + row_min += [y_min] + row_max += [y_max] + df.loc[len(df)] = [x] + row_mean + row_min + row_max + if bench.plot_name: + plt.figure() + ax = plt.subplot() + x = bench.x_names[0] + for i, y in enumerate(bench.line_names): + y_min, y_max = df[y + '-min'], df[y + '-max'] + col = bench.styles[i][0] if bench.styles else None + sty = bench.styles[i][1] if bench.styles else None + ax.plot(df[x], df[y], label=y, color=col, ls=sty) + if y_min is not None and y_max is not None: + ax.fill_between(df[x], y_min, y_max, alpha=0.15, color=col) + ax.legend() + xlabel = bench.xlabel if bench.xlabel else " = ".join(bench.x_names) + ax.set_xlabel(xlabel) + ax.set_ylabel(bench.ylabel) + # ax.set_title(bench.plot_name) + ax.set_xscale("log" if bench.x_log else "linear") + ax.set_yscale("log" if bench.y_log else "linear") + if show_plots: + plt.show() + if save_path: + plt.savefig(os.path.join(save_path, f"{bench.plot_name}.png")) + df = df[[bench.x_names[0]] + bench.line_names] + if print_data: + print(bench.plot_name + ':') + print(df) + if save_path: + df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False) + + def run(self, show_plots=False, print_data=False, save_path=''): + has_single_bench = isinstance(self.benchmarks, Benchmark) + benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks + if save_path: + html = open(os.path.join(save_path, "results.html"), "w") + html.write("\n") + for bench in benchmarks: + self._run(bench, save_path, show_plots, print_data) + if save_path: + html.write(f"\n") + if save_path: + html.write("\n") + + +def perf_report(benchmarks): + """ + Mark a function for benchmarking. The benchmark can then be executed by using the :code:`.run` method on the return value. + + :param benchmarks: Benchmarking configurations. 
+ :type benchmarks: List of :class:`Benchmark` + """ + wrapper = lambda fn: Mark(fn, benchmarks) + return wrapper + + +def get_dram_gbps(backend=None, device=None): + ''' return DRAM bandwidth in GB/s ''' + import torch + + from .runtime import driver + if not backend: + backend = _triton.runtime.backend.CUDA + if not device: + device = torch.cuda.current_device() + mem_clock_khz = driver.utils.get_device_properties(device)["mem_clock_rate"] # in kHz + bus_width = driver.utils.get_device_properties(device)["mem_bus_width"] + bw_gbps = mem_clock_khz * bus_width * 2 / 1e6 / 8 # In GB/s + return bw_gbps + + +def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None): + import torch + + from .runtime import driver + if not backend: + backend = _triton.runtime.backend.CUDA + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 + if not clock_rate: + clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8: + assert dtype == torch.float16 or is_corex() + ops_per_sub_core = 256 # 2 4x4x4 Tensor Cores + else: + if dtype == torch.float32: + ops_per_sub_core = 256 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 512 + elif dtype == torch.int8: + ops_per_sub_core = 1024 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops + +# create decorator that wraps test function into +# a cuda-memcheck system call + + +def cuda_memcheck(**target_kwargs): + def decorator(test_fn): + @functools.wraps(test_fn) + def wrapper(*args, **kwargs): + import psutil + ppid_name = psutil.Process(os.getppid()).name() + run_cuda_memcheck = target_kwargs.items() <= kwargs.items() + if run_cuda_memcheck and ppid_name != "cuda-memcheck": + path = os.path.realpath(test_fn.__globals__["__file__"]) + # get path of current file + env = {"PATH": os.environ["PATH"], "PYTORCH_NO_CUDA_MEMORY_CACHING": "1"} + assert 'request' in kwargs, "memcheck'ed test must have a (possibly unused) `request` fixture" + test_id = kwargs['request'].node.callspec.id + cmd = f"{path}::{test_fn.__name__}[{test_id}]" + out = subprocess.run(["cuda-memcheck", "pytest", "-vs", cmd], capture_output=True, env=env) + assert out.returncode == 0, "cuda-memcheck returned an error: bounds checking failed" + assert "ERROR SUMMARY: 0 errors" in str(out.stdout) + else: + test_fn(*args, **kwargs) + return wrapper + return decorator + + +def nvsmi_attr(attrs): + attrs = ",".join(attrs) + cmd = [ + "nvidia-smi", + "-i", + "0", + "--query-gpu=" + attrs, + "--format=csv,noheader,nounits", + ] + out = subprocess.check_output(cmd) + ret = out.decode(sys.stdout.encoding).split(",") + ret = [int(x) for x in ret] + return ret + + +@contextmanager +def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): + try: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) + subprocess.check_output( + [ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ] + ) + subprocess.check_output( + [ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ] + ) + cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0] + assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" + assert abs(cur_mem_clock - ref_mem_clock) < 10, 
f"GPU SMs must run at {ref_mem_clock} MHz" + tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock + gbps = 640 * 2 * ref_mem_clock * 1e-3 + yield tflops, gbps + finally: + subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "0"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rgc"]) + subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) + + +def get_max_simd_tflops(dtype, backend=None, device=None): + import torch + + from .runtime import driver + if not backend: + backend = _triton.runtime.backend.CUDA + if not device: + device = torch.cuda.current_device() + + num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 + clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + if dtype == torch.float32: + ops_per_sub_core = 32 # 2*16 + elif dtype == torch.float16: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + else: + if dtype == torch.float32: + ops_per_sub_core = 32 + elif dtype in [torch.float16, torch.bfloat16]: + ops_per_sub_core = 64 + else: + raise RuntimeError("dtype not supported") + tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 + return tflops diff --git a/pkgs/triton/third_party/cuda/bin/ptxas b/pkgs/triton/third_party/cuda/bin/ptxas new file mode 100755 index 0000000..8b47936 Binary files /dev/null and b/pkgs/triton/third_party/cuda/bin/ptxas differ diff --git a/pkgs/triton/third_party/cuda/include/cuda.h b/pkgs/triton/third_party/cuda/include/cuda.h new file mode 100755 index 0000000..c713bf3 --- /dev/null +++ b/pkgs/triton/third_party/cuda/include/cuda.h @@ -0,0 +1,19348 @@ +/* + * Copyright 1993-2018 NVIDIA Corporation. All rights reserved. + * + * NOTICE TO LICENSEE: + * + * This source code and/or documentation ("Licensed Deliverables") are + * subject to NVIDIA intellectual property rights under U.S. and + * international Copyright laws. + * + * These Licensed Deliverables contained herein is PROPRIETARY and + * CONFIDENTIAL to NVIDIA and is being provided under the terms and + * conditions of a form of NVIDIA software license agreement by and + * between NVIDIA and Licensee ("License Agreement") or electronically + * accepted by Licensee. Notwithstanding any terms or conditions to + * the contrary in the License Agreement, reproduction or disclosure + * of the Licensed Deliverables to any third party without the express + * written consent of NVIDIA is prohibited. + * + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE + * SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS + * PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND. + * NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED + * DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY, + * NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE. + * NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE + * LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY + * SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY + * DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, + * WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS + * ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE + * OF THESE LICENSED DELIVERABLES. + * + * U.S. Government End Users. 
These Licensed Deliverables are a + * "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT + * 1995), consisting of "commercial computer software" and "commercial + * computer software documentation" as such terms are used in 48 + * C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government + * only as a commercial end item. Consistent with 48 C.F.R.12.212 and + * 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all + * U.S. Government End Users acquire the Licensed Deliverables with + * only those rights set forth herein. + * + * Any use of the Licensed Deliverables in individual and commercial + * software must include, in the user documentation and internal + * comments to the code, the above Disclaimer and U.S. Government End + * Users Notice. + */ + +#ifndef __cuda_cuda_h__ +#define __cuda_cuda_h__ + + + +#include +#ifdef _MSC_VER +typedef unsigned __int32 cuuint32_t; +typedef unsigned __int64 cuuint64_t; +#else +#include +typedef uint32_t cuuint32_t; +typedef uint64_t cuuint64_t; +#endif + +#if defined(__CUDA_API_VERSION_INTERNAL) || defined(__DOXYGEN_ONLY__) || defined(CUDA_ENABLE_DEPRECATED) +#define __CUDA_DEPRECATED +#elif defined(_MSC_VER) +#define __CUDA_DEPRECATED __declspec(deprecated) +#elif defined(__GNUC__) +#define __CUDA_DEPRECATED __attribute__((deprecated)) +#else +#define __CUDA_DEPRECATED +#endif + +#if defined(CUDA_FORCE_API_VERSION) +#error "CUDA_FORCE_API_VERSION is no longer supported." +#endif + +#if defined(__CUDA_API_VERSION_INTERNAL) || defined(CUDA_API_PER_THREAD_DEFAULT_STREAM) + #define __CUDA_API_PER_THREAD_DEFAULT_STREAM + #define __CUDA_API_PTDS(api) api ## _ptds + #define __CUDA_API_PTSZ(api) api ## _ptsz +#else + #define __CUDA_API_PTDS(api) api + #define __CUDA_API_PTSZ(api) api +#endif + +#define cuDeviceTotalMem cuDeviceTotalMem_v2 +#define cuCtxCreate cuCtxCreate_v2 +#define cuCtxCreate_v3 cuCtxCreate_v3 +#define cuModuleGetGlobal cuModuleGetGlobal_v2 +#define cuMemGetInfo cuMemGetInfo_v2 +#define cuMemAlloc cuMemAlloc_v2 +#define cuMemAllocPitch cuMemAllocPitch_v2 +#define cuMemFree cuMemFree_v2 +#define cuMemGetAddressRange cuMemGetAddressRange_v2 +#define cuMemAllocHost cuMemAllocHost_v2 +#define cuMemHostGetDevicePointer cuMemHostGetDevicePointer_v2 +#define cuMemcpyHtoD __CUDA_API_PTDS(cuMemcpyHtoD_v2) +#define cuMemcpyDtoH __CUDA_API_PTDS(cuMemcpyDtoH_v2) +#define cuMemcpyDtoD __CUDA_API_PTDS(cuMemcpyDtoD_v2) +#define cuMemcpyDtoA __CUDA_API_PTDS(cuMemcpyDtoA_v2) +#define cuMemcpyAtoD __CUDA_API_PTDS(cuMemcpyAtoD_v2) +#define cuMemcpyHtoA __CUDA_API_PTDS(cuMemcpyHtoA_v2) +#define cuMemcpyAtoH __CUDA_API_PTDS(cuMemcpyAtoH_v2) +#define cuMemcpyAtoA __CUDA_API_PTDS(cuMemcpyAtoA_v2) +#define cuMemcpyHtoAAsync __CUDA_API_PTSZ(cuMemcpyHtoAAsync_v2) +#define cuMemcpyAtoHAsync __CUDA_API_PTSZ(cuMemcpyAtoHAsync_v2) +#define cuMemcpy2D __CUDA_API_PTDS(cuMemcpy2D_v2) +#define cuMemcpy2DUnaligned __CUDA_API_PTDS(cuMemcpy2DUnaligned_v2) +#define cuMemcpy3D __CUDA_API_PTDS(cuMemcpy3D_v2) +#define cuMemcpyHtoDAsync __CUDA_API_PTSZ(cuMemcpyHtoDAsync_v2) +#define cuMemcpyDtoHAsync __CUDA_API_PTSZ(cuMemcpyDtoHAsync_v2) +#define cuMemcpyDtoDAsync __CUDA_API_PTSZ(cuMemcpyDtoDAsync_v2) +#define cuMemcpy2DAsync __CUDA_API_PTSZ(cuMemcpy2DAsync_v2) +#define cuMemcpy3DAsync __CUDA_API_PTSZ(cuMemcpy3DAsync_v2) +#define cuMemsetD8 __CUDA_API_PTDS(cuMemsetD8_v2) +#define cuMemsetD16 __CUDA_API_PTDS(cuMemsetD16_v2) +#define cuMemsetD32 __CUDA_API_PTDS(cuMemsetD32_v2) +#define cuMemsetD2D8 __CUDA_API_PTDS(cuMemsetD2D8_v2) +#define cuMemsetD2D16 
__CUDA_API_PTDS(cuMemsetD2D16_v2) +#define cuMemsetD2D32 __CUDA_API_PTDS(cuMemsetD2D32_v2) +#define cuArrayCreate cuArrayCreate_v2 +#define cuArrayGetDescriptor cuArrayGetDescriptor_v2 +#define cuArray3DCreate cuArray3DCreate_v2 +#define cuArray3DGetDescriptor cuArray3DGetDescriptor_v2 +#define cuTexRefSetAddress cuTexRefSetAddress_v2 +#define cuTexRefGetAddress cuTexRefGetAddress_v2 +#define cuGraphicsResourceGetMappedPointer cuGraphicsResourceGetMappedPointer_v2 +#define cuCtxDestroy cuCtxDestroy_v2 +#define cuCtxPopCurrent cuCtxPopCurrent_v2 +#define cuCtxPushCurrent cuCtxPushCurrent_v2 +#define cuStreamDestroy cuStreamDestroy_v2 +#define cuEventDestroy cuEventDestroy_v2 +#define cuTexRefSetAddress2D cuTexRefSetAddress2D_v3 +#define cuLinkCreate cuLinkCreate_v2 +#define cuLinkAddData cuLinkAddData_v2 +#define cuLinkAddFile cuLinkAddFile_v2 +#define cuMemHostRegister cuMemHostRegister_v2 +#define cuGraphicsResourceSetMapFlags cuGraphicsResourceSetMapFlags_v2 +#define cuStreamBeginCapture __CUDA_API_PTSZ(cuStreamBeginCapture_v2) +#define cuDevicePrimaryCtxRelease cuDevicePrimaryCtxRelease_v2 +#define cuDevicePrimaryCtxReset cuDevicePrimaryCtxReset_v2 +#define cuDevicePrimaryCtxSetFlags cuDevicePrimaryCtxSetFlags_v2 +#define cuDeviceGetUuid_v2 cuDeviceGetUuid_v2 +#define cuIpcOpenMemHandle cuIpcOpenMemHandle_v2 +#define cuGraphInstantiate cuGraphInstantiate_v2 + +#if defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) + #define cuMemcpy __CUDA_API_PTDS(cuMemcpy) + #define cuMemcpyAsync __CUDA_API_PTSZ(cuMemcpyAsync) + #define cuMemcpyPeer __CUDA_API_PTDS(cuMemcpyPeer) + #define cuMemcpyPeerAsync __CUDA_API_PTSZ(cuMemcpyPeerAsync) + #define cuMemcpy3DPeer __CUDA_API_PTDS(cuMemcpy3DPeer) + #define cuMemcpy3DPeerAsync __CUDA_API_PTSZ(cuMemcpy3DPeerAsync) + #define cuMemPrefetchAsync __CUDA_API_PTSZ(cuMemPrefetchAsync) + + #define cuMemsetD8Async __CUDA_API_PTSZ(cuMemsetD8Async) + #define cuMemsetD16Async __CUDA_API_PTSZ(cuMemsetD16Async) + #define cuMemsetD32Async __CUDA_API_PTSZ(cuMemsetD32Async) + #define cuMemsetD2D8Async __CUDA_API_PTSZ(cuMemsetD2D8Async) + #define cuMemsetD2D16Async __CUDA_API_PTSZ(cuMemsetD2D16Async) + #define cuMemsetD2D32Async __CUDA_API_PTSZ(cuMemsetD2D32Async) + + #define cuStreamGetPriority __CUDA_API_PTSZ(cuStreamGetPriority) + #define cuStreamGetFlags __CUDA_API_PTSZ(cuStreamGetFlags) + #define cuStreamGetCtx __CUDA_API_PTSZ(cuStreamGetCtx) + #define cuStreamWaitEvent __CUDA_API_PTSZ(cuStreamWaitEvent) + #define cuStreamEndCapture __CUDA_API_PTSZ(cuStreamEndCapture) + #define cuStreamIsCapturing __CUDA_API_PTSZ(cuStreamIsCapturing) + #define cuStreamGetCaptureInfo __CUDA_API_PTSZ(cuStreamGetCaptureInfo) + #define cuStreamGetCaptureInfo_v2 __CUDA_API_PTSZ(cuStreamGetCaptureInfo_v2) + #define cuStreamUpdateCaptureDependencies __CUDA_API_PTSZ(cuStreamUpdateCaptureDependencies) + #define cuStreamAddCallback __CUDA_API_PTSZ(cuStreamAddCallback) + #define cuStreamAttachMemAsync __CUDA_API_PTSZ(cuStreamAttachMemAsync) + #define cuStreamQuery __CUDA_API_PTSZ(cuStreamQuery) + #define cuStreamSynchronize __CUDA_API_PTSZ(cuStreamSynchronize) + #define cuEventRecord __CUDA_API_PTSZ(cuEventRecord) + #define cuEventRecordWithFlags __CUDA_API_PTSZ(cuEventRecordWithFlags) + #define cuLaunchKernel __CUDA_API_PTSZ(cuLaunchKernel) + + + + #define cuLaunchHostFunc __CUDA_API_PTSZ(cuLaunchHostFunc) + #define cuGraphicsMapResources __CUDA_API_PTSZ(cuGraphicsMapResources) + #define cuGraphicsUnmapResources __CUDA_API_PTSZ(cuGraphicsUnmapResources) + + #define cuStreamWriteValue32 
__CUDA_API_PTSZ(cuStreamWriteValue32) + #define cuStreamWaitValue32 __CUDA_API_PTSZ(cuStreamWaitValue32) + #define cuStreamWriteValue64 __CUDA_API_PTSZ(cuStreamWriteValue64) + #define cuStreamWaitValue64 __CUDA_API_PTSZ(cuStreamWaitValue64) + #define cuStreamBatchMemOp __CUDA_API_PTSZ(cuStreamBatchMemOp) + + #define cuLaunchCooperativeKernel __CUDA_API_PTSZ(cuLaunchCooperativeKernel) + + #define cuSignalExternalSemaphoresAsync __CUDA_API_PTSZ(cuSignalExternalSemaphoresAsync) + #define cuWaitExternalSemaphoresAsync __CUDA_API_PTSZ(cuWaitExternalSemaphoresAsync) + + #define cuGraphUpload __CUDA_API_PTSZ(cuGraphUpload) + #define cuGraphLaunch __CUDA_API_PTSZ(cuGraphLaunch) + #define cuStreamCopyAttributes __CUDA_API_PTSZ(cuStreamCopyAttributes) + #define cuStreamGetAttribute __CUDA_API_PTSZ(cuStreamGetAttribute) + #define cuStreamSetAttribute __CUDA_API_PTSZ(cuStreamSetAttribute) + #define cuMemMapArrayAsync __CUDA_API_PTSZ(cuMemMapArrayAsync) + + #define cuMemFreeAsync __CUDA_API_PTSZ(cuMemFreeAsync) + #define cuMemAllocAsync __CUDA_API_PTSZ(cuMemAllocAsync) + #define cuMemAllocFromPoolAsync __CUDA_API_PTSZ(cuMemAllocFromPoolAsync) +#endif + +/** + * \file cuda.h + * \brief Header file for the CUDA Toolkit application programming interface. + * + * \file cudaGL.h + * \brief Header file for the OpenGL interoperability functions of the + * low-level CUDA driver application programming interface. + * + * \file cudaD3D9.h + * \brief Header file for the Direct3D 9 interoperability functions of the + * low-level CUDA driver application programming interface. + */ + +/** + * \defgroup CUDA_TYPES Data types used by CUDA driver + * @{ + */ + +/** + * CUDA API version number + */ +#define CUDA_VERSION 11060 + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * CUDA device pointer + * CUdeviceptr is defined as an unsigned integer type whose size matches the size of a pointer on the target platform. 
+ */ +#if defined(_WIN64) || defined(__LP64__) +typedef unsigned long long CUdeviceptr_v2; +#else +typedef unsigned int CUdeviceptr_v2; +#endif +typedef CUdeviceptr_v2 CUdeviceptr; /**< CUDA device pointer */ + +typedef int CUdevice_v1; /**< CUDA device */ +typedef CUdevice_v1 CUdevice; /**< CUDA device */ +typedef struct CUctx_st *CUcontext; /**< CUDA context */ +typedef struct CUmod_st *CUmodule; /**< CUDA module */ +typedef struct CUfunc_st *CUfunction; /**< CUDA function */ +typedef struct CUarray_st *CUarray; /**< CUDA array */ +typedef struct CUmipmappedArray_st *CUmipmappedArray; /**< CUDA mipmapped array */ +typedef struct CUtexref_st *CUtexref; /**< CUDA texture reference */ +typedef struct CUsurfref_st *CUsurfref; /**< CUDA surface reference */ +typedef struct CUevent_st *CUevent; /**< CUDA event */ +typedef struct CUstream_st *CUstream; /**< CUDA stream */ +typedef struct CUgraphicsResource_st *CUgraphicsResource; /**< CUDA graphics interop resource */ +typedef unsigned long long CUtexObject_v1; /**< An opaque value that represents a CUDA texture object */ +typedef CUtexObject_v1 CUtexObject; /**< An opaque value that represents a CUDA texture object */ +typedef unsigned long long CUsurfObject_v1; /**< An opaque value that represents a CUDA surface object */ +typedef CUsurfObject_v1 CUsurfObject; /**< An opaque value that represents a CUDA surface object */ +typedef struct CUextMemory_st *CUexternalMemory; /**< CUDA external memory */ +typedef struct CUextSemaphore_st *CUexternalSemaphore; /**< CUDA external semaphore */ +typedef struct CUgraph_st *CUgraph; /**< CUDA graph */ +typedef struct CUgraphNode_st *CUgraphNode; /**< CUDA graph node */ +typedef struct CUgraphExec_st *CUgraphExec; /**< CUDA executable graph */ +typedef struct CUmemPoolHandle_st *CUmemoryPool; /**< CUDA memory pool */ +typedef struct CUuserObject_st *CUuserObject; /**< CUDA user object for graphs */ + +#ifndef CU_UUID_HAS_BEEN_DEFINED +#define CU_UUID_HAS_BEEN_DEFINED +typedef struct CUuuid_st { /**< CUDA definition of UUID */ + char bytes[16]; +} CUuuid; +#endif + +/** + * CUDA IPC handle size + */ +#define CU_IPC_HANDLE_SIZE 64 + +/** + * CUDA IPC event handle + */ +typedef struct CUipcEventHandle_st { + char reserved[CU_IPC_HANDLE_SIZE]; +} CUipcEventHandle_v1; +typedef CUipcEventHandle_v1 CUipcEventHandle; + +/** + * CUDA IPC mem handle + */ +typedef struct CUipcMemHandle_st { + char reserved[CU_IPC_HANDLE_SIZE]; +} CUipcMemHandle_v1; +typedef CUipcMemHandle_v1 CUipcMemHandle; + +/** + * CUDA Ipc Mem Flags + */ +typedef enum CUipcMem_flags_enum { + CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS = 0x1 /**< Automatically enable peer access between remote devices as needed */ +} CUipcMem_flags; + + +/** + * CUDA Mem Attach Flags + */ +typedef enum CUmemAttach_flags_enum { + CU_MEM_ATTACH_GLOBAL = 0x1, /**< Memory can be accessed by any stream on any device */ + CU_MEM_ATTACH_HOST = 0x2, /**< Memory cannot be accessed by any stream on any device */ + CU_MEM_ATTACH_SINGLE = 0x4 /**< Memory can only be accessed by a single stream on the associated device */ +} CUmemAttach_flags; + +/** + * Context creation flags + */ +typedef enum CUctx_flags_enum { + CU_CTX_SCHED_AUTO = 0x00, /**< Automatic scheduling */ + CU_CTX_SCHED_SPIN = 0x01, /**< Set spin as default scheduling */ + CU_CTX_SCHED_YIELD = 0x02, /**< Set yield as default scheduling */ + CU_CTX_SCHED_BLOCKING_SYNC = 0x04, /**< Set blocking synchronization as default scheduling */ + CU_CTX_BLOCKING_SYNC = 0x04, /**< Set blocking synchronization as default scheduling + 
* \deprecated This flag was deprecated as of CUDA 4.0 + * and was replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. */ + CU_CTX_SCHED_MASK = 0x07, + CU_CTX_MAP_HOST = 0x08, /**< \deprecated This flag was deprecated as of CUDA 11.0 + * and it no longer has any effect. All contexts + * as of CUDA 3.2 behave as though the flag is enabled. */ + CU_CTX_LMEM_RESIZE_TO_MAX = 0x10, /**< Keep local memory allocation after launch */ + CU_CTX_FLAGS_MASK = 0x1f +} CUctx_flags; + +/** + * Stream creation flags + */ +typedef enum CUstream_flags_enum { + CU_STREAM_DEFAULT = 0x0, /**< Default stream flag */ + CU_STREAM_NON_BLOCKING = 0x1 /**< Stream does not synchronize with stream 0 (the NULL stream) */ +} CUstream_flags; + +/** + * Legacy stream handle + * + * Stream handle that can be passed as a CUstream to use an implicit stream + * with legacy synchronization behavior. + * + * See details of the \link_sync_behavior + */ +#define CU_STREAM_LEGACY ((CUstream)0x1) + +/** + * Per-thread stream handle + * + * Stream handle that can be passed as a CUstream to use an implicit stream + * with per-thread synchronization behavior. + * + * See details of the \link_sync_behavior + */ +#define CU_STREAM_PER_THREAD ((CUstream)0x2) + +/** + * Event creation flags + */ +typedef enum CUevent_flags_enum { + CU_EVENT_DEFAULT = 0x0, /**< Default event flag */ + CU_EVENT_BLOCKING_SYNC = 0x1, /**< Event uses blocking synchronization */ + CU_EVENT_DISABLE_TIMING = 0x2, /**< Event will not record timing data */ + CU_EVENT_INTERPROCESS = 0x4 /**< Event is suitable for interprocess use. CU_EVENT_DISABLE_TIMING must be set */ +} CUevent_flags; + +/** + * Event record flags + */ +typedef enum CUevent_record_flags_enum { + CU_EVENT_RECORD_DEFAULT = 0x0, /**< Default event record flag */ + CU_EVENT_RECORD_EXTERNAL = 0x1 /**< When using stream capture, create an event record node + * instead of the default behavior. This flag is invalid + * when used outside of capture. */ +} CUevent_record_flags; + +/** + * Event wait flags + */ +typedef enum CUevent_wait_flags_enum { + CU_EVENT_WAIT_DEFAULT = 0x0, /**< Default event wait flag */ + CU_EVENT_WAIT_EXTERNAL = 0x1 /**< When using stream capture, create an event wait node + * instead of the default behavior. This flag is invalid + * when used outside of capture.*/ +} CUevent_wait_flags; + +/** + * Flags for ::cuStreamWaitValue32 and ::cuStreamWaitValue64 + */ +typedef enum CUstreamWaitValue_flags_enum { + CU_STREAM_WAIT_VALUE_GEQ = 0x0, /**< Wait until (int32_t)(*addr - value) >= 0 (or int64_t for 64 bit + values). Note this is a cyclic comparison which ignores wraparound. + (Default behavior.) */ + CU_STREAM_WAIT_VALUE_EQ = 0x1, /**< Wait until *addr == value. */ + CU_STREAM_WAIT_VALUE_AND = 0x2, /**< Wait until (*addr & value) != 0. */ + CU_STREAM_WAIT_VALUE_NOR = 0x3, /**< Wait until ~(*addr | value) != 0. Support for this operation can be + queried with ::cuDeviceGetAttribute() and + ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR.*/ + CU_STREAM_WAIT_VALUE_FLUSH = 1<<30 /**< Follow the wait operation with a flush of outstanding remote writes. This + means that, if a remote write operation is guaranteed to have reached the + device before the wait can be satisfied, that write is guaranteed to be + visible to downstream device work. The device is permitted to reorder + remote writes internally. For example, this flag would be required if + two remote writes arrive in a defined order, the wait is satisfied by the + second write, and downstream work needs to observe the first write. 
+ Support for this operation is restricted to selected platforms and can be + queried with ::CU_DEVICE_ATTRIBUTE_CAN_USE_WAIT_VALUE_FLUSH.*/ +} CUstreamWaitValue_flags; + +/** + * Flags for ::cuStreamWriteValue32 + */ +typedef enum CUstreamWriteValue_flags_enum { + CU_STREAM_WRITE_VALUE_DEFAULT = 0x0, /**< Default behavior */ + CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER = 0x1 /**< Permits the write to be reordered with writes which were issued + before it, as a performance optimization. Normally, + ::cuStreamWriteValue32 will provide a memory fence before the + write, which has similar semantics to + __threadfence_system() but is scoped to the stream + rather than a CUDA thread. */ +} CUstreamWriteValue_flags; + +/** + * Operations for ::cuStreamBatchMemOp + */ +typedef enum CUstreamBatchMemOpType_enum { + CU_STREAM_MEM_OP_WAIT_VALUE_32 = 1, /**< Represents a ::cuStreamWaitValue32 operation */ + CU_STREAM_MEM_OP_WRITE_VALUE_32 = 2, /**< Represents a ::cuStreamWriteValue32 operation */ + CU_STREAM_MEM_OP_WAIT_VALUE_64 = 4, /**< Represents a ::cuStreamWaitValue64 operation */ + CU_STREAM_MEM_OP_WRITE_VALUE_64 = 5, /**< Represents a ::cuStreamWriteValue64 operation */ + CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES = 3 /**< This has the same effect as ::CU_STREAM_WAIT_VALUE_FLUSH, but as a + standalone operation. */ +} CUstreamBatchMemOpType; + +/** + * Per-operation parameters for ::cuStreamBatchMemOp + */ +typedef union CUstreamBatchMemOpParams_union { + CUstreamBatchMemOpType operation; + struct CUstreamMemOpWaitValueParams_st { + CUstreamBatchMemOpType operation; + CUdeviceptr address; + union { + cuuint32_t value; + cuuint64_t value64; + }; + unsigned int flags; + CUdeviceptr alias; /**< For driver internal use. Initial value is unimportant. */ + } waitValue; + struct CUstreamMemOpWriteValueParams_st { + CUstreamBatchMemOpType operation; + CUdeviceptr address; + union { + cuuint32_t value; + cuuint64_t value64; + }; + unsigned int flags; + CUdeviceptr alias; /**< For driver internal use. Initial value is unimportant. 
*/ + } writeValue; + struct CUstreamMemOpFlushRemoteWritesParams_st { + CUstreamBatchMemOpType operation; + unsigned int flags; + } flushRemoteWrites; + cuuint64_t pad[6]; +} CUstreamBatchMemOpParams_v1; +typedef CUstreamBatchMemOpParams_v1 CUstreamBatchMemOpParams; + +/** + * Occupancy calculator flag + */ +typedef enum CUoccupancy_flags_enum { + CU_OCCUPANCY_DEFAULT = 0x0, /**< Default behavior */ + CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE = 0x1 /**< Assume global caching is enabled and cannot be automatically turned off */ +} CUoccupancy_flags; + +/** + * Flags for ::cuStreamUpdateCaptureDependencies + */ +typedef enum CUstreamUpdateCaptureDependencies_flags_enum { + CU_STREAM_ADD_CAPTURE_DEPENDENCIES = 0x0, /**< Add new nodes to the dependency set */ + CU_STREAM_SET_CAPTURE_DEPENDENCIES = 0x1 /**< Replace the dependency set with the new nodes */ +} CUstreamUpdateCaptureDependencies_flags; + +/** + * Array formats + */ +typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, /**< Unsigned 8-bit integers */ + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, /**< Unsigned 16-bit integers */ + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, /**< Unsigned 32-bit integers */ + CU_AD_FORMAT_SIGNED_INT8 = 0x08, /**< Signed 8-bit integers */ + CU_AD_FORMAT_SIGNED_INT16 = 0x09, /**< Signed 16-bit integers */ + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, /**< Signed 32-bit integers */ + CU_AD_FORMAT_HALF = 0x10, /**< 16-bit floating point */ + CU_AD_FORMAT_FLOAT = 0x20, /**< 32-bit floating point */ + CU_AD_FORMAT_NV12 = 0xb0, /**< 8-bit YUV planar format, with 4:2:0 sampling */ + CU_AD_FORMAT_UNORM_INT8X1 = 0xc0, /**< 1 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT8X2 = 0xc1, /**< 2 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT8X4 = 0xc2, /**< 4 channel unsigned 8-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X1 = 0xc3, /**< 1 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X2 = 0xc4, /**< 2 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_UNORM_INT16X4 = 0xc5, /**< 4 channel unsigned 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X1 = 0xc6, /**< 1 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X2 = 0xc7, /**< 2 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT8X4 = 0xc8, /**< 4 channel signed 8-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X1 = 0xc9, /**< 1 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X2 = 0xca, /**< 2 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_SNORM_INT16X4 = 0xcb, /**< 4 channel signed 16-bit normalized integer */ + CU_AD_FORMAT_BC1_UNORM = 0x91, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format */ + CU_AD_FORMAT_BC1_UNORM_SRGB = 0x92, /**< 4 channel unsigned normalized block-compressed (BC1 compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC2_UNORM = 0x93, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format */ + CU_AD_FORMAT_BC2_UNORM_SRGB = 0x94, /**< 4 channel unsigned normalized block-compressed (BC2 compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC3_UNORM = 0x95, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format */ + CU_AD_FORMAT_BC3_UNORM_SRGB = 0x96, /**< 4 channel unsigned normalized block-compressed (BC3 compression) format with sRGB encoding*/ + CU_AD_FORMAT_BC4_UNORM = 0x97, /**< 1 channel unsigned normalized block-compressed (BC4 compression) format */ + CU_AD_FORMAT_BC4_SNORM = 0x98, /**< 1 channel signed 
normalized block-compressed (BC4 compression) format */ + CU_AD_FORMAT_BC5_UNORM = 0x99, /**< 2 channel unsigned normalized block-compressed (BC5 compression) format */ + CU_AD_FORMAT_BC5_SNORM = 0x9a, /**< 2 channel signed normalized block-compressed (BC5 compression) format */ + CU_AD_FORMAT_BC6H_UF16 = 0x9b, /**< 3 channel unsigned half-float block-compressed (BC6H compression) format */ + CU_AD_FORMAT_BC6H_SF16 = 0x9c, /**< 3 channel signed half-float block-compressed (BC6H compression) format */ + CU_AD_FORMAT_BC7_UNORM = 0x9d, /**< 4 channel unsigned normalized block-compressed (BC7 compression) format */ + CU_AD_FORMAT_BC7_UNORM_SRGB = 0x9e /**< 4 channel unsigned normalized block-compressed (BC7 compression) format with sRGB encoding */ +} CUarray_format; + +/** + * Texture reference addressing modes + */ +typedef enum CUaddress_mode_enum { + CU_TR_ADDRESS_MODE_WRAP = 0, /**< Wrapping address mode */ + CU_TR_ADDRESS_MODE_CLAMP = 1, /**< Clamp to edge address mode */ + CU_TR_ADDRESS_MODE_MIRROR = 2, /**< Mirror address mode */ + CU_TR_ADDRESS_MODE_BORDER = 3 /**< Border address mode */ +} CUaddress_mode; + +/** + * Texture reference filtering modes + */ +typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, /**< Point filter mode */ + CU_TR_FILTER_MODE_LINEAR = 1 /**< Linear filter mode */ +} CUfilter_mode; + +/** + * Device properties + */ +typedef enum CUdevice_attribute_enum { + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 1, /**< Maximum number of threads per block */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X = 2, /**< Maximum block dimension X */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y = 3, /**< Maximum block dimension Y */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z = 4, /**< Maximum block dimension Z */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X = 5, /**< Maximum grid dimension X */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y = 6, /**< Maximum grid dimension Y */ + CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z = 7, /**< Maximum grid dimension Z */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8, /**< Maximum shared memory available per block in bytes */ + CU_DEVICE_ATTRIBUTE_SHARED_MEMORY_PER_BLOCK = 8, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK */ + CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY = 9, /**< Memory available on device for __constant__ variables in a CUDA C kernel in bytes */ + CU_DEVICE_ATTRIBUTE_WARP_SIZE = 10, /**< Warp size in threads */ + CU_DEVICE_ATTRIBUTE_MAX_PITCH = 11, /**< Maximum pitch in bytes allowed by memory copies */ + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK = 12, /**< Maximum number of 32-bit registers available per block */ + CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK = 12, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK */ + CU_DEVICE_ATTRIBUTE_CLOCK_RATE = 13, /**< Typical clock frequency in kilohertz */ + CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT = 14, /**< Alignment requirement for textures */ + CU_DEVICE_ATTRIBUTE_GPU_OVERLAP = 15, /**< Device can possibly copy memory and execute a kernel concurrently. Deprecated. Use instead CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT. 
*/ + CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT = 16, /**< Number of multiprocessors on device */ + CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT = 17, /**< Specifies whether there is a run time limit on kernels */ + CU_DEVICE_ATTRIBUTE_INTEGRATED = 18, /**< Device is integrated with host memory */ + CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY = 19, /**< Device can map host memory into CUDA address space */ + CU_DEVICE_ATTRIBUTE_COMPUTE_MODE = 20, /**< Compute mode (See ::CUcomputemode for details) */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH = 21, /**< Maximum 1D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH = 22, /**< Maximum 2D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT = 23, /**< Maximum 2D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH = 24, /**< Maximum 3D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT = 25, /**< Maximum 3D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH = 26, /**< Maximum 3D texture depth */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH = 27, /**< Maximum 2D layered texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT = 28, /**< Maximum 2D layered texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS = 29, /**< Maximum layers in a 2D layered texture */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_WIDTH = 27, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_HEIGHT = 28, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_ARRAY_NUMSLICES = 29, /**< Deprecated, use CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS */ + CU_DEVICE_ATTRIBUTE_SURFACE_ALIGNMENT = 30, /**< Alignment requirement for surfaces */ + CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS = 31, /**< Device can possibly execute multiple kernels concurrently */ + CU_DEVICE_ATTRIBUTE_ECC_ENABLED = 32, /**< Device has ECC support enabled */ + CU_DEVICE_ATTRIBUTE_PCI_BUS_ID = 33, /**< PCI bus ID of the device */ + CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID = 34, /**< PCI device ID of the device */ + CU_DEVICE_ATTRIBUTE_TCC_DRIVER = 35, /**< Device is using TCC driver model */ + CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE = 36, /**< Peak memory clock frequency in kilohertz */ + CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH = 37, /**< Global memory bus width in bits */ + CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE = 38, /**< Size of L2 cache in bytes */ + CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR = 39, /**< Maximum resident threads per multiprocessor */ + CU_DEVICE_ATTRIBUTE_ASYNC_ENGINE_COUNT = 40, /**< Number of asynchronous engines */ + CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING = 41, /**< Device shares a unified address space with the host */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH = 42, /**< Maximum 1D layered texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS = 43, /**< Maximum layers in a 1D layered texture */ + CU_DEVICE_ATTRIBUTE_CAN_TEX2D_GATHER = 44, /**< Deprecated, do not use. 
*/ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH = 45, /**< Maximum 2D texture width if CUDA_ARRAY3D_TEXTURE_GATHER is set */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT = 46, /**< Maximum 2D texture height if CUDA_ARRAY3D_TEXTURE_GATHER is set */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE = 47, /**< Alternate maximum 3D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE = 48, /**< Alternate maximum 3D texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE = 49, /**< Alternate maximum 3D texture depth */ + CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID = 50, /**< PCI domain ID of the device */ + CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT = 51, /**< Pitch alignment requirement for textures */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH = 52, /**< Maximum cubemap texture width/height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH = 53, /**< Maximum cubemap layered texture width/height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS = 54, /**< Maximum layers in a cubemap layered texture */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH = 55, /**< Maximum 1D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH = 56, /**< Maximum 2D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT = 57, /**< Maximum 2D surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH = 58, /**< Maximum 3D surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT = 59, /**< Maximum 3D surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH = 60, /**< Maximum 3D surface depth */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH = 61, /**< Maximum 1D layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS = 62, /**< Maximum layers in a 1D layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH = 63, /**< Maximum 2D layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT = 64, /**< Maximum 2D layered surface height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS = 65, /**< Maximum layers in a 2D layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH = 66, /**< Maximum cubemap surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH = 67, /**< Maximum cubemap layered surface width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS = 68, /**< Maximum layers in a cubemap layered surface */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH = 69, /**< Deprecated, do not use. Use cudaDeviceGetTexture1DLinearMaxWidth() or cuDeviceGetTexture1DLinearMaxWidth() instead. 
*/ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH = 70, /**< Maximum 2D linear texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT = 71, /**< Maximum 2D linear texture height */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH = 72, /**< Maximum 2D linear texture pitch in bytes */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH = 73, /**< Maximum mipmapped 2D texture width */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT = 74, /**< Maximum mipmapped 2D texture height */ + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR = 75, /**< Major compute capability version number */ + CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR = 76, /**< Minor compute capability version number */ + CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH = 77, /**< Maximum mipmapped 1D texture width */ + CU_DEVICE_ATTRIBUTE_STREAM_PRIORITIES_SUPPORTED = 78, /**< Device supports stream priorities */ + CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED = 79, /**< Device supports caching globals in L1 */ + CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED = 80, /**< Device supports caching locals in L1 */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR = 81, /**< Maximum shared memory available per multiprocessor in bytes */ + CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR = 82, /**< Maximum number of 32-bit registers available per multiprocessor */ + CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY = 83, /**< Device can allocate managed memory on this system */ + CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD = 84, /**< Device is on a multi-GPU board */ + CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID = 85, /**< Unique id for a group of devices on the same multi-GPU board */ + CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED = 86, /**< Link between the device and the host supports native atomic operations (this is a placeholder attribute, and is not supported on any current hardware)*/ + CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO = 87, /**< Ratio of single precision performance (in floating-point operations per second) to double precision performance */ + CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS = 88, /**< Device supports coherently accessing pageable memory without calling cudaHostRegister on it */ + CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS = 89, /**< Device can coherently access managed memory concurrently with the CPU */ + CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED = 90, /**< Device supports compute preemption. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM = 91, /**< Device can access host registered memory at the same virtual address as the CPU */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS = 92, /**< ::cuStreamBatchMemOp and related APIs are supported. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS = 93, /**< 64-bit operations are supported in ::cuStreamBatchMemOp and related APIs. */ + CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR = 94, /**< ::CU_STREAM_WAIT_VALUE_NOR is supported. */ + CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH = 95, /**< Device supports launching cooperative kernels via ::cuLaunchCooperativeKernel */ + CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH = 96, /**< Deprecated, ::cuLaunchCooperativeKernelMultiDevice is deprecated. */ + CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = 97, /**< Maximum optin shared memory per block */ + CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES = 98, /**< The ::CU_STREAM_WAIT_VALUE_FLUSH flag and the ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on the device. 
See \ref CUDA_MEMOP for additional details. */ + CU_DEVICE_ATTRIBUTE_HOST_REGISTER_SUPPORTED = 99, /**< Device supports host memory registration via ::cudaHostRegister. */ + CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES = 100, /**< Device accesses pageable memory via the host's page tables. */ + CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST = 101, /**< The host can directly access managed memory on the device without migration. */ + CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED = 102, /**< Deprecated, Use CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED*/ + CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED = 102, /**< Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED = 103, /**< Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED = 104, /**< Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED = 105, /**< Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR = 106, /**< Maximum number of blocks per multiprocessor */ + CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED = 107, /**< Device supports compression of memory */ + CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE = 108, /**< Maximum L2 persisting lines capacity setting in bytes. */ + CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE = 109, /**< Maximum value of CUaccessPolicyWindow::num_bytes. */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED = 110, /**< Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate */ + CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK = 111, /**< Shared memory reserved by CUDA driver per block in bytes */ + CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED = 112, /**< Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays */ + CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED = 113, /**< Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU */ + CU_DEVICE_ATTRIBUTE_TIMELINE_SEMAPHORE_INTEROP_SUPPORTED = 114, /**< External timeline semaphore interop is supported on the device */ + CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED = 115, /**< Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED = 116, /**< Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS = 117, /**< The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum */ + CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING = 118, /**< GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. 
*/ + CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES = 119, /**< Handle types supported with mempool based IPC */ + + + + + CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED = 121, /**< Device supports deferred mapping CUDA arrays and CUDA mipmapped arrays */ + + CU_DEVICE_ATTRIBUTE_MAX +} CUdevice_attribute; + +/** + * Legacy device properties + */ +typedef struct CUdevprop_st { + int maxThreadsPerBlock; /**< Maximum number of threads per block */ + int maxThreadsDim[3]; /**< Maximum size of each dimension of a block */ + int maxGridSize[3]; /**< Maximum size of each dimension of a grid */ + int sharedMemPerBlock; /**< Shared memory available per block in bytes */ + int totalConstantMemory; /**< Constant memory available on device in bytes */ + int SIMDWidth; /**< Warp size in threads */ + int memPitch; /**< Maximum pitch in bytes allowed by memory copies */ + int regsPerBlock; /**< 32-bit registers available per block */ + int clockRate; /**< Clock frequency in kilohertz */ + int textureAlign; /**< Alignment requirement for textures */ +} CUdevprop_v1; +typedef CUdevprop_v1 CUdevprop; + +/** + * Pointer information + */ +typedef enum CUpointer_attribute_enum { + CU_POINTER_ATTRIBUTE_CONTEXT = 1, /**< The ::CUcontext on which a pointer was allocated or registered */ + CU_POINTER_ATTRIBUTE_MEMORY_TYPE = 2, /**< The ::CUmemorytype describing the physical location of a pointer */ + CU_POINTER_ATTRIBUTE_DEVICE_POINTER = 3, /**< The address at which a pointer's memory may be accessed on the device */ + CU_POINTER_ATTRIBUTE_HOST_POINTER = 4, /**< The address at which a pointer's memory may be accessed on the host */ + CU_POINTER_ATTRIBUTE_P2P_TOKENS = 5, /**< A pair of tokens for use with the nv-p2p.h Linux kernel interface */ + CU_POINTER_ATTRIBUTE_SYNC_MEMOPS = 6, /**< Synchronize every synchronous memory operation initiated on this region */ + CU_POINTER_ATTRIBUTE_BUFFER_ID = 7, /**< A process-wide unique ID for an allocated memory region*/ + CU_POINTER_ATTRIBUTE_IS_MANAGED = 8, /**< Indicates if the pointer points to managed memory */ + CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = 9, /**< A device ordinal of a device on which a pointer was allocated or registered */ + CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE = 10, /**< 1 if this pointer maps to an allocation that is suitable for ::cudaIpcGetMemHandle, 0 otherwise **/ + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR = 11, /**< Starting address for this requested pointer */ + CU_POINTER_ATTRIBUTE_RANGE_SIZE = 12, /**< Size of the address range for this requested pointer */ + CU_POINTER_ATTRIBUTE_MAPPED = 13, /**< 1 if this pointer is in a valid address range that is mapped to a backing allocation, 0 otherwise **/ + CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES = 14, /**< Bitmask of allowed ::CUmemAllocationHandleType for this allocation **/ + CU_POINTER_ATTRIBUTE_IS_GPU_DIRECT_RDMA_CAPABLE = 15, /**< 1 if the memory this pointer is referencing can be used with the GPUDirect RDMA API **/ + CU_POINTER_ATTRIBUTE_ACCESS_FLAGS = 16, /**< Returns the access flags the device associated with the current context has on the corresponding memory referenced by the pointer given */ + CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE = 17 /**< Returns the mempool handle for the allocation if it was allocated from a mempool. Otherwise returns NULL. **/ +} CUpointer_attribute; + +/** + * Function properties + */ +typedef enum CUfunction_attribute_enum { + /** + * The maximum number of threads per block, beyond which a launch of the + * function would fail. 
This number depends on both the function and the + * device on which the function is currently loaded. + */ + CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK = 0, + + /** + * The size in bytes of statically-allocated shared memory required by + * this function. This does not include dynamically-allocated shared + * memory requested by the user at runtime. + */ + CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES = 1, + + /** + * The size in bytes of user-allocated constant memory required by this + * function. + */ + CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES = 2, + + /** + * The size in bytes of local memory used by each thread of this function. + */ + CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES = 3, + + /** + * The number of registers used by each thread of this function. + */ + CU_FUNC_ATTRIBUTE_NUM_REGS = 4, + + /** + * The PTX virtual architecture version for which the function was + * compiled. This value is the major PTX version * 10 + the minor PTX + * version, so a PTX version 1.3 function would return the value 13. + * Note that this may return the undefined value of 0 for cubins + * compiled prior to CUDA 3.0. + */ + CU_FUNC_ATTRIBUTE_PTX_VERSION = 5, + + /** + * The binary architecture version for which the function was compiled. + * This value is the major binary version * 10 + the minor binary version, + * so a binary version 1.3 function would return the value 13. Note that + * this will return a value of 10 for legacy cubins that do not have a + * properly-encoded binary architecture version. + */ + CU_FUNC_ATTRIBUTE_BINARY_VERSION = 6, + + /** + * The attribute to indicate whether the function has been compiled with + * user specified option "-Xptxas --dlcm=ca" set . + */ + CU_FUNC_ATTRIBUTE_CACHE_MODE_CA = 7, + + /** + * The maximum size in bytes of dynamically-allocated shared memory that can be used by + * this function. If the user-specified dynamic shared memory size is larger than this + * value, the launch will fail. + * See ::cuFuncSetAttribute + */ + CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES = 8, + + /** + * On devices where the L1 cache and shared memory use the same hardware resources, + * this sets the shared memory carveout preference, in percent of the total shared memory. + * Refer to ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR. + * This is only a hint, and the driver can choose a different ratio if required to execute the function. 
+ * See ::cuFuncSetAttribute + */ + CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT = 9, + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + CU_FUNC_ATTRIBUTE_MAX +} CUfunction_attribute; + +/** + * Function cache configurations + */ +typedef enum CUfunc_cache_enum { + CU_FUNC_CACHE_PREFER_NONE = 0x00, /**< no preference for shared memory or L1 (default) */ + CU_FUNC_CACHE_PREFER_SHARED = 0x01, /**< prefer larger shared memory and smaller L1 cache */ + CU_FUNC_CACHE_PREFER_L1 = 0x02, /**< prefer larger L1 cache and smaller shared memory */ + CU_FUNC_CACHE_PREFER_EQUAL = 0x03 /**< prefer equal sized L1 cache and shared memory */ +} CUfunc_cache; + +/** + * Shared memory configurations + */ +typedef enum CUsharedconfig_enum { + CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE = 0x00, /**< set default shared memory bank size */ + CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE = 0x01, /**< set shared memory bank width to four bytes */ + CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE = 0x02 /**< set shared memory bank width to eight bytes */ +} CUsharedconfig; + +/** + * Shared memory carveout configurations. These may be passed to ::cuFuncSetAttribute + */ +typedef enum CUshared_carveout_enum { + CU_SHAREDMEM_CARVEOUT_DEFAULT = -1, /**< No preference for shared memory or L1 (default) */ + CU_SHAREDMEM_CARVEOUT_MAX_SHARED = 100, /**< Prefer maximum available shared memory, minimum L1 cache */ + CU_SHAREDMEM_CARVEOUT_MAX_L1 = 0 /**< Prefer maximum available L1 cache, minimum shared memory */ +} CUshared_carveout; + +/** + * Memory types + */ +typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, /**< Host memory */ + CU_MEMORYTYPE_DEVICE = 0x02, /**< Device memory */ + CU_MEMORYTYPE_ARRAY = 0x03, /**< Array memory */ + CU_MEMORYTYPE_UNIFIED = 0x04 /**< Unified device or host memory */ +} CUmemorytype; + +/** + * Compute Modes + */ +typedef enum CUcomputemode_enum { + CU_COMPUTEMODE_DEFAULT = 0, /**< Default compute mode (Multiple contexts allowed per device) */ + CU_COMPUTEMODE_PROHIBITED = 2, /**< Compute-prohibited mode (No contexts can be created on this device at this time) */ + CU_COMPUTEMODE_EXCLUSIVE_PROCESS = 3 /**< Compute-exclusive-process mode (Only one context used by a single process can be present on this device at a time) */ +} CUcomputemode; + +/** + * Memory advise values + */ +typedef enum CUmem_advise_enum { + CU_MEM_ADVISE_SET_READ_MOSTLY = 1, /**< Data will mostly be read and only occassionally be written to */ + CU_MEM_ADVISE_UNSET_READ_MOSTLY = 2, /**< Undo the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY */ + CU_MEM_ADVISE_SET_PREFERRED_LOCATION = 3, /**< Set the preferred location for the data as the specified device */ + CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION = 4, /**< Clear the preferred location for the data */ + CU_MEM_ADVISE_SET_ACCESSED_BY = 5, /**< Data will be accessed by the specified device, so prevent page faults as much as possible */ + CU_MEM_ADVISE_UNSET_ACCESSED_BY = 6 /**< Let the Unified Memory subsystem decide on the page faulting policy for the specified device */ +} CUmem_advise; + +typedef enum CUmem_range_attribute_enum { + CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY = 1, /**< Whether the range will mostly be read and only occassionally be written to */ + CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION = 2, /**< The preferred location of the range */ + CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY = 3, /**< Memory range has ::CU_MEM_ADVISE_SET_ACCESSED_BY set for specified device */ + 
CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION = 4 /**< The last location to which the range was prefetched */ +} CUmem_range_attribute; + +/** + * Online compiler and linker options + */ +typedef enum CUjit_option_enum +{ + /** + * Max number of registers that a thread may use.\n + * Option type: unsigned int\n + * Applies to: compiler only + */ + CU_JIT_MAX_REGISTERS = 0, + + /** + * IN: Specifies minimum number of threads per block to target compilation + * for\n + * OUT: Returns the number of threads the compiler actually targeted. + * This restricts the resource utilization fo the compiler (e.g. max + * registers) such that a block with the given number of threads should be + * able to launch based on register limitations. Note, this option does not + * currently take into account any other resource limitations, such as + * shared memory utilization.\n + * Cannot be combined with ::CU_JIT_TARGET.\n + * Option type: unsigned int\n + * Applies to: compiler only + */ + CU_JIT_THREADS_PER_BLOCK, + + /** + * Overwrites the option value with the total wall clock time, in + * milliseconds, spent in the compiler and linker\n + * Option type: float\n + * Applies to: compiler and linker + */ + CU_JIT_WALL_TIME, + + /** + * Pointer to a buffer in which to print any log messages + * that are informational in nature (the buffer size is specified via + * option ::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES)\n + * Option type: char *\n + * Applies to: compiler and linker + */ + CU_JIT_INFO_LOG_BUFFER, + + /** + * IN: Log buffer size in bytes. Log messages will be capped at this size + * (including null terminator)\n + * OUT: Amount of log buffer filled with messages\n + * Option type: unsigned int\n + * Applies to: compiler and linker + */ + CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, + + /** + * Pointer to a buffer in which to print any log messages that + * reflect errors (the buffer size is specified via option + * ::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES)\n + * Option type: char *\n + * Applies to: compiler and linker + */ + CU_JIT_ERROR_LOG_BUFFER, + + /** + * IN: Log buffer size in bytes. Log messages will be capped at this size + * (including null terminator)\n + * OUT: Amount of log buffer filled with messages\n + * Option type: unsigned int\n + * Applies to: compiler and linker + */ + CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, + + /** + * Level of optimizations to apply to generated code (0 - 4), with 4 + * being the default and highest level of optimizations.\n + * Option type: unsigned int\n + * Applies to: compiler only + */ + CU_JIT_OPTIMIZATION_LEVEL, + + /** + * No option value required. Determines the target based on the current + * attached context (default)\n + * Option type: No option value needed\n + * Applies to: compiler and linker + */ + CU_JIT_TARGET_FROM_CUCONTEXT, + + /** + * Target is chosen based on supplied ::CUjit_target. Cannot be + * combined with ::CU_JIT_THREADS_PER_BLOCK.\n + * Option type: unsigned int for enumerated type ::CUjit_target\n + * Applies to: compiler and linker + */ + CU_JIT_TARGET, + + /** + * Specifies choice of fallback strategy if matching cubin is not found. + * Choice is based on supplied ::CUjit_fallback. 
This option cannot be + * used with cuLink* APIs as the linker requires exact matches.\n + * Option type: unsigned int for enumerated type ::CUjit_fallback\n + * Applies to: compiler only + */ + CU_JIT_FALLBACK_STRATEGY, + + /** + * Specifies whether to create debug information in output (-g) + * (0: false, default)\n + * Option type: int\n + * Applies to: compiler and linker + */ + CU_JIT_GENERATE_DEBUG_INFO, + + /** + * Generate verbose log messages (0: false, default)\n + * Option type: int\n + * Applies to: compiler and linker + */ + CU_JIT_LOG_VERBOSE, + + /** + * Generate line number information (-lineinfo) (0: false, default)\n + * Option type: int\n + * Applies to: compiler only + */ + CU_JIT_GENERATE_LINE_INFO, + + /** + * Specifies whether to enable caching explicitly (-dlcm) \n + * Choice is based on supplied ::CUjit_cacheMode_enum.\n + * Option type: unsigned int for enumerated type ::CUjit_cacheMode_enum\n + * Applies to: compiler only + */ + CU_JIT_CACHE_MODE, + + /** + * The below jit options are used for internal purposes only, in this version of CUDA + */ + CU_JIT_NEW_SM3X_OPT, + CU_JIT_FAST_COMPILE, + + /** + * Array of device symbol names that will be relocated to the corresponing + * host addresses stored in ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES.\n + * Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n + * When loding a device module, driver will relocate all encountered + * unresolved symbols to the host addresses.\n + * It is only allowed to register symbols that correspond to unresolved + * global variables.\n + * It is illegal to register the same device symbol at multiple addresses.\n + * Option type: const char **\n + * Applies to: dynamic linker only + */ + CU_JIT_GLOBAL_SYMBOL_NAMES, + + /** + * Array of host addresses that will be used to relocate corresponding + * device symbols stored in ::CU_JIT_GLOBAL_SYMBOL_NAMES.\n + * Must contain ::CU_JIT_GLOBAL_SYMBOL_COUNT entries.\n + * Option type: void **\n + * Applies to: dynamic linker only + */ + CU_JIT_GLOBAL_SYMBOL_ADDRESSES, + + /** + * Number of entries in ::CU_JIT_GLOBAL_SYMBOL_NAMES and + * ::CU_JIT_GLOBAL_SYMBOL_ADDRESSES arrays.\n + * Option type: unsigned int\n + * Applies to: dynamic linker only + */ + CU_JIT_GLOBAL_SYMBOL_COUNT, + + /** + * Enable link-time optimization (-dlto) for device code (0: false, default).\n + * This option is not supported on 32-bit platforms.\n + * Option type: int\n + * Applies to: compiler and linker + */ + CU_JIT_LTO, + + /** + * Control single-precision denormals (-ftz) support (0: false, default). + * 1 : flushes denormal values to zero + * 0 : preserves denormal values + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_FTZ, + + /** + * Control single-precision floating-point division and reciprocals + * (-prec-div) support (1: true, default). + * 1 : Enables the IEEE round-to-nearest mode + * 0 : Enables the fast approximation mode + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_PREC_DIV, + + /** + * Control single-precision floating-point square root + * (-prec-sqrt) support (1: true, default). 
+ * 1 : Enables the IEEE round-to-nearest mode + * 0 : Enables the fast approximation mode + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_PREC_SQRT, + + /** + * Enable/Disable the contraction of floating-point multiplies + * and adds/subtracts into floating-point multiply-add (-fma) + * operations (1: Enable, default; 0: Disable). + * Option type: int\n + * Applies to: link-time optimization specified with CU_JIT_LTO + */ + CU_JIT_FMA, + + CU_JIT_NUM_OPTIONS + +} CUjit_option; + +/** + * Online compilation targets + */ +typedef enum CUjit_target_enum +{ + + CU_TARGET_COMPUTE_20 = 20, /**< Compute device class 2.0 */ + CU_TARGET_COMPUTE_21 = 21, /**< Compute device class 2.1 */ + + + CU_TARGET_COMPUTE_30 = 30, /**< Compute device class 3.0 */ + CU_TARGET_COMPUTE_32 = 32, /**< Compute device class 3.2 */ + CU_TARGET_COMPUTE_35 = 35, /**< Compute device class 3.5 */ + CU_TARGET_COMPUTE_37 = 37, /**< Compute device class 3.7 */ + + + CU_TARGET_COMPUTE_50 = 50, /**< Compute device class 5.0 */ + CU_TARGET_COMPUTE_52 = 52, /**< Compute device class 5.2 */ + CU_TARGET_COMPUTE_53 = 53, /**< Compute device class 5.3 */ + + + CU_TARGET_COMPUTE_60 = 60, /**< Compute device class 6.0.*/ + CU_TARGET_COMPUTE_61 = 61, /**< Compute device class 6.1.*/ + CU_TARGET_COMPUTE_62 = 62, /**< Compute device class 6.2.*/ + + + CU_TARGET_COMPUTE_70 = 70, /**< Compute device class 7.0.*/ + CU_TARGET_COMPUTE_72 = 72, /**< Compute device class 7.2.*/ + + CU_TARGET_COMPUTE_75 = 75, /**< Compute device class 7.5.*/ + + CU_TARGET_COMPUTE_80 = 80, /**< Compute device class 8.0.*/ + CU_TARGET_COMPUTE_86 = 86 /**< Compute device class 8.6.*/ + +} CUjit_target; + +/** + * Cubin matching fallback strategies + */ +typedef enum CUjit_fallback_enum +{ + CU_PREFER_PTX = 0, /**< Prefer to compile ptx if exact binary match not found */ + + CU_PREFER_BINARY /**< Prefer to fall back to compatible binary code if exact match not found */ + +} CUjit_fallback; + +/** + * Caching modes for dlcm + */ +typedef enum CUjit_cacheMode_enum +{ + CU_JIT_CACHE_OPTION_NONE = 0, /**< Compile with no -dlcm flag specified */ + CU_JIT_CACHE_OPTION_CG, /**< Compile with L1 cache disabled */ + CU_JIT_CACHE_OPTION_CA /**< Compile with L1 cache enabled */ +} CUjit_cacheMode; + +/** + * Device code formats + */ +typedef enum CUjitInputType_enum +{ + /** + * Compiled device-class-specific device code\n + * Applicable options: none + */ + CU_JIT_INPUT_CUBIN = 0, + + /** + * PTX source code\n + * Applicable options: PTX compiler options + */ + CU_JIT_INPUT_PTX, + + /** + * Bundle of multiple cubins and/or PTX of some device code\n + * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY + */ + CU_JIT_INPUT_FATBINARY, + + /** + * Host object with embedded device code\n + * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY + */ + CU_JIT_INPUT_OBJECT, + + /** + * Archive of host objects with embedded device code\n + * Applicable options: PTX compiler options, ::CU_JIT_FALLBACK_STRATEGY + */ + CU_JIT_INPUT_LIBRARY, + + /** + * High-level intermediate code for link-time optimization\n + * Applicable options: NVVM compiler options, PTX compiler options + */ + CU_JIT_INPUT_NVVM, + + CU_JIT_NUM_INPUT_TYPES +} CUjitInputType; + +typedef struct CUlinkState_st *CUlinkState; + +/** + * Flags to register a graphics resource + */ +typedef enum CUgraphicsRegisterFlags_enum { + CU_GRAPHICS_REGISTER_FLAGS_NONE = 0x00, + CU_GRAPHICS_REGISTER_FLAGS_READ_ONLY = 0x01, + 
CU_GRAPHICS_REGISTER_FLAGS_WRITE_DISCARD = 0x02, + CU_GRAPHICS_REGISTER_FLAGS_SURFACE_LDST = 0x04, + CU_GRAPHICS_REGISTER_FLAGS_TEXTURE_GATHER = 0x08 +} CUgraphicsRegisterFlags; + +/** + * Flags for mapping and unmapping interop resources + */ +typedef enum CUgraphicsMapResourceFlags_enum { + CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE = 0x00, + CU_GRAPHICS_MAP_RESOURCE_FLAGS_READ_ONLY = 0x01, + CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITE_DISCARD = 0x02 +} CUgraphicsMapResourceFlags; + +/** + * Array indices for cube faces + */ +typedef enum CUarray_cubemap_face_enum { + CU_CUBEMAP_FACE_POSITIVE_X = 0x00, /**< Positive X face of cubemap */ + CU_CUBEMAP_FACE_NEGATIVE_X = 0x01, /**< Negative X face of cubemap */ + CU_CUBEMAP_FACE_POSITIVE_Y = 0x02, /**< Positive Y face of cubemap */ + CU_CUBEMAP_FACE_NEGATIVE_Y = 0x03, /**< Negative Y face of cubemap */ + CU_CUBEMAP_FACE_POSITIVE_Z = 0x04, /**< Positive Z face of cubemap */ + CU_CUBEMAP_FACE_NEGATIVE_Z = 0x05 /**< Negative Z face of cubemap */ +} CUarray_cubemap_face; + +/** + * Limits + */ +typedef enum CUlimit_enum { + CU_LIMIT_STACK_SIZE = 0x00, /**< GPU thread stack size */ + CU_LIMIT_PRINTF_FIFO_SIZE = 0x01, /**< GPU printf FIFO size */ + CU_LIMIT_MALLOC_HEAP_SIZE = 0x02, /**< GPU malloc heap size */ + CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH = 0x03, /**< GPU device runtime launch synchronize depth */ + CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT = 0x04, /**< GPU device runtime pending launch count */ + CU_LIMIT_MAX_L2_FETCH_GRANULARITY = 0x05, /**< A value between 0 and 128 that indicates the maximum fetch granularity of L2 (in Bytes). This is a hint */ + CU_LIMIT_PERSISTING_L2_CACHE_SIZE = 0x06, /**< A size in bytes for L2 persisting lines cache size */ + CU_LIMIT_MAX +} CUlimit; + +/** + * Resource types + */ +typedef enum CUresourcetype_enum { + CU_RESOURCE_TYPE_ARRAY = 0x00, /**< Array resoure */ + CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, /**< Mipmapped array resource */ + CU_RESOURCE_TYPE_LINEAR = 0x02, /**< Linear resource */ + CU_RESOURCE_TYPE_PITCH2D = 0x03 /**< Pitch 2D resource */ +} CUresourcetype; + +#ifdef _WIN32 +#define CUDA_CB __stdcall +#else +#define CUDA_CB +#endif + +/** + * CUDA host function + * \param userData Argument value passed to the function + */ +typedef void (CUDA_CB *CUhostFn)(void *userData); + +/** + * Specifies performance hint with ::CUaccessPolicyWindow for hitProp and missProp members. + */ +typedef enum CUaccessProperty_enum { + CU_ACCESS_PROPERTY_NORMAL = 0, /**< Normal cache persistence. */ + CU_ACCESS_PROPERTY_STREAMING = 1, /**< Streaming access is less likely to persit from cache. */ + CU_ACCESS_PROPERTY_PERSISTING = 2 /**< Persisting access is more likely to persist in cache.*/ +} CUaccessProperty; + +/** + * Specifies an access policy for a window, a contiguous extent of memory + * beginning at base_ptr and ending at base_ptr + num_bytes. + * num_bytes is limited by CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE. + * Partition into many segments and assign segments such that: + * sum of "hit segments" / window == approx. ratio. + * sum of "miss segments" / window == approx 1-ratio. + * Segments and ratio specifications are fitted to the capabilities of + * the architecture. + * Accesses in a hit segment apply the hitProp access policy. + * Accesses in a miss segment apply the missProp access policy. + */ +typedef struct CUaccessPolicyWindow_st { + void *base_ptr; /**< Starting address of the access policy window. CUDA driver may align it. */ + size_t num_bytes; /**< Size in bytes of the window policy. 
CUDA driver may restrict the maximum size and alignment. */ + float hitRatio; /**< hitRatio specifies percentage of lines assigned hitProp, rest are assigned missProp. */ + CUaccessProperty hitProp; /**< ::CUaccessProperty set for hit. */ + CUaccessProperty missProp; /**< ::CUaccessProperty set for miss. Must be either NORMAL or STREAMING */ +} CUaccessPolicyWindow_v1; +typedef CUaccessPolicyWindow_v1 CUaccessPolicyWindow; + +/** + * GPU kernel node parameters + */ +typedef struct CUDA_KERNEL_NODE_PARAMS_st { + CUfunction func; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block in bytes */ + void **kernelParams; /**< Array of pointers to kernel parameters */ + void **extra; /**< Extra options */ +} CUDA_KERNEL_NODE_PARAMS_v1; +typedef CUDA_KERNEL_NODE_PARAMS_v1 CUDA_KERNEL_NODE_PARAMS; + +/** + * Memset node parameters + */ +typedef struct CUDA_MEMSET_NODE_PARAMS_st { + CUdeviceptr dst; /**< Destination device pointer */ + size_t pitch; /**< Pitch of destination device pointer. Unused if height is 1 */ + unsigned int value; /**< Value to be set */ + unsigned int elementSize; /**< Size of each element in bytes. Must be 1, 2, or 4. */ + size_t width; /**< Width of the row in elements */ + size_t height; /**< Number of rows */ +} CUDA_MEMSET_NODE_PARAMS_v1; +typedef CUDA_MEMSET_NODE_PARAMS_v1 CUDA_MEMSET_NODE_PARAMS; + +/** + * Host node parameters + */ +typedef struct CUDA_HOST_NODE_PARAMS_st { + CUhostFn fn; /**< The function to call when the node executes */ + void* userData; /**< Argument to pass to the function */ +} CUDA_HOST_NODE_PARAMS_v1; +typedef CUDA_HOST_NODE_PARAMS_v1 CUDA_HOST_NODE_PARAMS; + +/** + * Graph node types + */ +typedef enum CUgraphNodeType_enum { + CU_GRAPH_NODE_TYPE_KERNEL = 0, /**< GPU kernel node */ + CU_GRAPH_NODE_TYPE_MEMCPY = 1, /**< Memcpy node */ + CU_GRAPH_NODE_TYPE_MEMSET = 2, /**< Memset node */ + CU_GRAPH_NODE_TYPE_HOST = 3, /**< Host (executable) node */ + CU_GRAPH_NODE_TYPE_GRAPH = 4, /**< Node which executes an embedded graph */ + CU_GRAPH_NODE_TYPE_EMPTY = 5, /**< Empty (no-op) node */ + CU_GRAPH_NODE_TYPE_WAIT_EVENT = 6, /**< External event wait node */ + CU_GRAPH_NODE_TYPE_EVENT_RECORD = 7, /**< External event record node */ + CU_GRAPH_NODE_TYPE_EXT_SEMAS_SIGNAL = 8, /**< External semaphore signal node */ + CU_GRAPH_NODE_TYPE_EXT_SEMAS_WAIT = 9, /**< External semaphore wait node */ + CU_GRAPH_NODE_TYPE_MEM_ALLOC = 10,/**< Memory Allocation Node */ + CU_GRAPH_NODE_TYPE_MEM_FREE = 11 /**< Memory Free Node */ +} CUgraphNodeType; + +typedef enum CUsynchronizationPolicy_enum { + CU_SYNC_POLICY_AUTO = 1, + CU_SYNC_POLICY_SPIN = 2, + CU_SYNC_POLICY_YIELD = 3, + CU_SYNC_POLICY_BLOCKING_SYNC = 4 +} CUsynchronizationPolicy; + +/** + * Graph kernel node Attributes + */ +typedef enum CUkernelNodeAttrID_enum { + CU_KERNEL_NODE_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, /**< Identifier for ::CUkernelNodeAttrValue::accessPolicyWindow. */ + CU_KERNEL_NODE_ATTRIBUTE_COOPERATIVE = 2 /**< Allows a kernel node to be cooperative (see ::cuLaunchCooperativeKernel). 
*/ +} CUkernelNodeAttrID; + +/** + * Graph kernel node attributes union, used with ::cuKernelNodeSetAttribute/::cuKernelNodeGetAttribute + */ +typedef union CUkernelNodeAttrValue_union { + CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. */ + int cooperative; /**< Nonzero indicates a cooperative kernel (see ::cuLaunchCooperativeKernel). */ +} CUkernelNodeAttrValue_v1; +typedef CUkernelNodeAttrValue_v1 CUkernelNodeAttrValue; + +/** + * Possible stream capture statuses returned by ::cuStreamIsCapturing + */ +typedef enum CUstreamCaptureStatus_enum { + CU_STREAM_CAPTURE_STATUS_NONE = 0, /**< Stream is not capturing */ + CU_STREAM_CAPTURE_STATUS_ACTIVE = 1, /**< Stream is actively capturing */ + CU_STREAM_CAPTURE_STATUS_INVALIDATED = 2 /**< Stream is part of a capture sequence that + has been invalidated, but not terminated */ +} CUstreamCaptureStatus; + +/** + * Possible modes for stream capture thread interactions. For more details see + * ::cuStreamBeginCapture and ::cuThreadExchangeStreamCaptureMode + */ +typedef enum CUstreamCaptureMode_enum { + CU_STREAM_CAPTURE_MODE_GLOBAL = 0, + CU_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, + CU_STREAM_CAPTURE_MODE_RELAXED = 2 +} CUstreamCaptureMode; + +/** + * Stream Attributes + */ +typedef enum CUstreamAttrID_enum { + CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW = 1, /**< Identifier for ::CUstreamAttrValue::accessPolicyWindow. */ + CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY = 3 /**< ::CUsynchronizationPolicy for work queued up in this stream */ +} CUstreamAttrID; + +/** + * Stream attributes union, used with ::cuStreamSetAttribute/::cuStreamGetAttribute + */ +typedef union CUstreamAttrValue_union { + CUaccessPolicyWindow accessPolicyWindow; /**< Attribute ::CUaccessPolicyWindow. */ + CUsynchronizationPolicy syncPolicy; /**< Value for ::CU_STREAM_ATTRIBUTE_SYNCHRONIZATION_POLICY. */ +} CUstreamAttrValue_v1; +typedef CUstreamAttrValue_v1 CUstreamAttrValue; + +/** + * Flags to specify search options. For more details see ::cuGetProcAddress + */ +typedef enum CUdriverProcAddress_flags_enum { + CU_GET_PROC_ADDRESS_DEFAULT = 0, /**< Default search mode for driver symbols. */ + CU_GET_PROC_ADDRESS_LEGACY_STREAM = 1 << 0, /**< Search for legacy versions of driver symbols. */ + CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM = 1 << 1 /**< Search for per-thread versions of driver symbols. */ +} CUdriverProcAddress_flags; + +/** + * Execution Affinity Types + */ +typedef enum CUexecAffinityType_enum { + CU_EXEC_AFFINITY_TYPE_SM_COUNT = 0, /**< Create a context with limited SMs. */ + CU_EXEC_AFFINITY_TYPE_MAX +} CUexecAffinityType; + +/** + * Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT + */ +typedef struct CUexecAffinitySmCount_st { + unsigned int val; /**< The number of SMs the context is limited to use. */ +} CUexecAffinitySmCount_v1; +typedef CUexecAffinitySmCount_v1 CUexecAffinitySmCount; + +/** + * Execution Affinity Parameters + */ +typedef struct CUexecAffinityParam_st { + CUexecAffinityType type; + union { + CUexecAffinitySmCount smCount; /** Value for ::CU_EXEC_AFFINITY_TYPE_SM_COUNT */ + } param; +} CUexecAffinityParam_v1; +typedef CUexecAffinityParam_v1 CUexecAffinityParam; + +/** + * Error codes + */ +typedef enum cudaError_enum { + /** + * The API call returned with no errors. In the case of query calls, this + * also means that the operation being queried is complete (see + * ::cuEventQuery() and ::cuStreamQuery()). 
+ */ + CUDA_SUCCESS = 0, + + /** + * This indicates that one or more of the parameters passed to the API call + * is not within an acceptable range of values. + */ + CUDA_ERROR_INVALID_VALUE = 1, + + /** + * The API call failed because it was unable to allocate enough memory to + * perform the requested operation. + */ + CUDA_ERROR_OUT_OF_MEMORY = 2, + + /** + * This indicates that the CUDA driver has not been initialized with + * ::cuInit() or that initialization has failed. + */ + CUDA_ERROR_NOT_INITIALIZED = 3, + + /** + * This indicates that the CUDA driver is in the process of shutting down. + */ + CUDA_ERROR_DEINITIALIZED = 4, + + /** + * This indicates profiler is not initialized for this run. This can + * happen when the application is running with external profiling tools + * like visual profiler. + */ + CUDA_ERROR_PROFILER_DISABLED = 5, + + /** + * \deprecated + * This error return is deprecated as of CUDA 5.0. It is no longer an error + * to attempt to enable/disable the profiling via ::cuProfilerStart or + * ::cuProfilerStop without initialization. + */ + CUDA_ERROR_PROFILER_NOT_INITIALIZED = 6, + + /** + * \deprecated + * This error return is deprecated as of CUDA 5.0. It is no longer an error + * to call cuProfilerStart() when profiling is already enabled. + */ + CUDA_ERROR_PROFILER_ALREADY_STARTED = 7, + + /** + * \deprecated + * This error return is deprecated as of CUDA 5.0. It is no longer an error + * to call cuProfilerStop() when profiling is already disabled. + */ + CUDA_ERROR_PROFILER_ALREADY_STOPPED = 8, + + /** + * This indicates that the CUDA driver that the application has loaded is a + * stub library. Applications that run with the stub rather than a real + * driver loaded will result in CUDA API returning this error. + */ + CUDA_ERROR_STUB_LIBRARY = 34, + + /** + * This indicates that no CUDA-capable devices were detected by the installed + * CUDA driver. + */ + CUDA_ERROR_NO_DEVICE = 100, + + /** + * This indicates that the device ordinal supplied by the user does not + * correspond to a valid CUDA device or that the action requested is + * invalid for the specified device. + */ + CUDA_ERROR_INVALID_DEVICE = 101, + + /** + * This error indicates that the Grid license is not applied. + */ + CUDA_ERROR_DEVICE_NOT_LICENSED = 102, + + /** + * This indicates that the device kernel image is invalid. This can also + * indicate an invalid CUDA module. + */ + CUDA_ERROR_INVALID_IMAGE = 200, + + /** + * This most frequently indicates that there is no context bound to the + * current thread. This can also be returned if the context passed to an + * API call is not a valid handle (such as a context that has had + * ::cuCtxDestroy() invoked on it). This can also be returned if a user + * mixes different API versions (i.e. 3010 context with 3020 API calls). + * See ::cuCtxGetApiVersion() for more details. + */ + CUDA_ERROR_INVALID_CONTEXT = 201, + + /** + * This indicated that the context being supplied as a parameter to the + * API call was already the active context. + * \deprecated + * This error return is deprecated as of CUDA 3.2. It is no longer an + * error to attempt to push the active context via ::cuCtxPushCurrent(). + */ + CUDA_ERROR_CONTEXT_ALREADY_CURRENT = 202, + + /** + * This indicates that a map or register operation has failed. + */ + CUDA_ERROR_MAP_FAILED = 205, + + /** + * This indicates that an unmap or unregister operation has failed. 
+ */ + CUDA_ERROR_UNMAP_FAILED = 206, + + /** + * This indicates that the specified array is currently mapped and thus + * cannot be destroyed. + */ + CUDA_ERROR_ARRAY_IS_MAPPED = 207, + + /** + * This indicates that the resource is already mapped. + */ + CUDA_ERROR_ALREADY_MAPPED = 208, + + /** + * This indicates that there is no kernel image available that is suitable + * for the device. This can occur when a user specifies code generation + * options for a particular CUDA source file that do not include the + * corresponding device configuration. + */ + CUDA_ERROR_NO_BINARY_FOR_GPU = 209, + + /** + * This indicates that a resource has already been acquired. + */ + CUDA_ERROR_ALREADY_ACQUIRED = 210, + + /** + * This indicates that a resource is not mapped. + */ + CUDA_ERROR_NOT_MAPPED = 211, + + /** + * This indicates that a mapped resource is not available for access as an + * array. + */ + CUDA_ERROR_NOT_MAPPED_AS_ARRAY = 212, + + /** + * This indicates that a mapped resource is not available for access as a + * pointer. + */ + CUDA_ERROR_NOT_MAPPED_AS_POINTER = 213, + + /** + * This indicates that an uncorrectable ECC error was detected during + * execution. + */ + CUDA_ERROR_ECC_UNCORRECTABLE = 214, + + /** + * This indicates that the ::CUlimit passed to the API call is not + * supported by the active device. + */ + CUDA_ERROR_UNSUPPORTED_LIMIT = 215, + + /** + * This indicates that the ::CUcontext passed to the API call can + * only be bound to a single CPU thread at a time but is already + * bound to a CPU thread. + */ + CUDA_ERROR_CONTEXT_ALREADY_IN_USE = 216, + + /** + * This indicates that peer access is not supported across the given + * devices. + */ + CUDA_ERROR_PEER_ACCESS_UNSUPPORTED = 217, + + /** + * This indicates that a PTX JIT compilation failed. + */ + CUDA_ERROR_INVALID_PTX = 218, + + /** + * This indicates an error with OpenGL or DirectX context. + */ + CUDA_ERROR_INVALID_GRAPHICS_CONTEXT = 219, + + /** + * This indicates that an uncorrectable NVLink error was detected during the + * execution. + */ + CUDA_ERROR_NVLINK_UNCORRECTABLE = 220, + + /** + * This indicates that the PTX JIT compiler library was not found. + */ + CUDA_ERROR_JIT_COMPILER_NOT_FOUND = 221, + + /** + * This indicates that the provided PTX was compiled with an unsupported toolchain. + */ + + CUDA_ERROR_UNSUPPORTED_PTX_VERSION = 222, + + /** + * This indicates that the PTX JIT compilation was disabled. + */ + CUDA_ERROR_JIT_COMPILATION_DISABLED = 223, + + /** + * This indicates that the ::CUexecAffinityType passed to the API call is not + * supported by the active device. + */ + CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY = 224, + + /** + * This indicates that the device kernel source is invalid. This includes + * compilation/linker errors encountered in device code or user error. + */ + CUDA_ERROR_INVALID_SOURCE = 300, + + /** + * This indicates that the file specified was not found. + */ + CUDA_ERROR_FILE_NOT_FOUND = 301, + + /** + * This indicates that a link to a shared object failed to resolve. + */ + CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND = 302, + + /** + * This indicates that initialization of a shared object failed. + */ + CUDA_ERROR_SHARED_OBJECT_INIT_FAILED = 303, + + /** + * This indicates that an OS call failed. + */ + CUDA_ERROR_OPERATING_SYSTEM = 304, + + /** + * This indicates that a resource handle passed to the API call was not + * valid. Resource handles are opaque types like ::CUstream and ::CUevent. 
+ */ + CUDA_ERROR_INVALID_HANDLE = 400, + + /** + * This indicates that a resource required by the API call is not in a + * valid state to perform the requested operation. + */ + CUDA_ERROR_ILLEGAL_STATE = 401, + + /** + * This indicates that a named symbol was not found. Examples of symbols + * are global/constant variable names, driver function names, texture names, + * and surface names. + */ + CUDA_ERROR_NOT_FOUND = 500, + + /** + * This indicates that asynchronous operations issued previously have not + * completed yet. This result is not actually an error, but must be indicated + * differently than ::CUDA_SUCCESS (which indicates completion). Calls that + * may return this value include ::cuEventQuery() and ::cuStreamQuery(). + */ + CUDA_ERROR_NOT_READY = 600, + + /** + * While executing a kernel, the device encountered a + * load or store instruction on an invalid memory address. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_ILLEGAL_ADDRESS = 700, + + /** + * This indicates that a launch did not occur because it did not have + * appropriate resources. This error usually indicates that the user has + * attempted to pass too many arguments to the device kernel, or the + * kernel launch specifies too many threads for the kernel's register + * count. Passing arguments of the wrong size (i.e. a 64-bit pointer + * when a 32-bit int is expected) is equivalent to passing too many + * arguments and can also result in this error. + */ + CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES = 701, + + /** + * This indicates that the device kernel took too long to execute. This can + * only occur if timeouts are enabled - see the device attribute + * ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT for more information. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_LAUNCH_TIMEOUT = 702, + + /** + * This error indicates a kernel launch that uses an incompatible texturing + * mode. + */ + CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING = 703, + + /** + * This error indicates that a call to ::cuCtxEnablePeerAccess() is + * trying to re-enable peer access to a context which has already + * had peer access to it enabled. + */ + CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED = 704, + + /** + * This error indicates that ::cuCtxDisablePeerAccess() is + * trying to disable peer access which has not been enabled yet + * via ::cuCtxEnablePeerAccess(). + */ + CUDA_ERROR_PEER_ACCESS_NOT_ENABLED = 705, + + /** + * This error indicates that the primary context for the specified device + * has already been initialized. + */ + CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE = 708, + + /** + * This error indicates that the context current to the calling thread + * has been destroyed using ::cuCtxDestroy, or is a primary context which + * has not yet been initialized. + */ + CUDA_ERROR_CONTEXT_IS_DESTROYED = 709, + + /** + * A device-side assert triggered during kernel execution. The context + * cannot be used anymore, and must be destroyed. All existing device + * memory allocations from this context are invalid and must be + * reconstructed if the program is to continue using CUDA. 
+ */ + CUDA_ERROR_ASSERT = 710, + + /** + * This error indicates that the hardware resources required to enable + * peer access have been exhausted for one or more of the devices + * passed to ::cuCtxEnablePeerAccess(). + */ + CUDA_ERROR_TOO_MANY_PEERS = 711, + + /** + * This error indicates that the memory range passed to ::cuMemHostRegister() + * has already been registered. + */ + CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED = 712, + + /** + * This error indicates that the pointer passed to ::cuMemHostUnregister() + * does not correspond to any currently registered memory region. + */ + CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED = 713, + + /** + * While executing a kernel, the device encountered a stack error. + * This can be due to stack corruption or exceeding the stack size limit. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_HARDWARE_STACK_ERROR = 714, + + /** + * While executing a kernel, the device encountered an illegal instruction. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_ILLEGAL_INSTRUCTION = 715, + + /** + * While executing a kernel, the device encountered a load or store instruction + * on a memory address which is not aligned. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_MISALIGNED_ADDRESS = 716, + + /** + * While executing a kernel, the device encountered an instruction + * which can only operate on memory locations in certain address spaces + * (global, shared, or local), but was supplied a memory address not + * belonging to an allowed address space. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_INVALID_ADDRESS_SPACE = 717, + + /** + * While executing a kernel, the device program counter wrapped its address space. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_INVALID_PC = 718, + + /** + * An exception occurred on the device while executing a kernel. Common + * causes include dereferencing an invalid device pointer and accessing + * out of bounds shared memory. Less common cases can be system specific - more + * information about these cases can be found in the system specific user guide. + * This leaves the process in an inconsistent state and any further CUDA work + * will return the same error. To continue using CUDA, the process must be terminated + * and relaunched. + */ + CUDA_ERROR_LAUNCH_FAILED = 719, + + /** + * This error indicates that the number of blocks launched per grid for a kernel that was + * launched via either ::cuLaunchCooperativeKernel or ::cuLaunchCooperativeKernelMultiDevice + * exceeds the maximum number of blocks as allowed by ::cuOccupancyMaxActiveBlocksPerMultiprocessor + * or ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags times the number of multiprocessors + * as specified by the device attribute ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT. 
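 *
 * Editor's illustration (not part of the original header): sizing a
 * cooperative launch so it stays within this limit. ::cuModuleGetFunction
 * and ::cuDeviceGet are assumed to have produced `hFunc` and `dev`, and
 * `params`/`stream` are assumed inputs; error checking is omitted.
 * \code
 *   int numBlocksPerSm = 0, smCount = 0, blockSize = 256;
 *   cuOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, hFunc,
 *                                               blockSize, 0);
 *   cuDeviceGetAttribute(&smCount,
 *                        CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev);
 *   unsigned int maxGrid = (unsigned int)(numBlocksPerSm * smCount);
 *   // gridDimX beyond maxGrid fails with
 *   // CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE.
 *   cuLaunchCooperativeKernel(hFunc, maxGrid, 1, 1, blockSize, 1, 1,
 *                             0, stream, params);
 * \endcode
 *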
+ */ + CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE = 720, + + /** + * This error indicates that the attempted operation is not permitted. + */ + CUDA_ERROR_NOT_PERMITTED = 800, + + /** + * This error indicates that the attempted operation is not supported + * on the current system or device. + */ + CUDA_ERROR_NOT_SUPPORTED = 801, + + /** + * This error indicates that the system is not yet ready to start any CUDA + * work. To continue using CUDA, verify the system configuration is in a + * valid state and all required driver daemons are actively running. + * More information about this error can be found in the system specific + * user guide. + */ + CUDA_ERROR_SYSTEM_NOT_READY = 802, + + /** + * This error indicates that there is a mismatch between the versions of + * the display driver and the CUDA driver. Refer to the compatibility documentation + * for supported versions. + */ + CUDA_ERROR_SYSTEM_DRIVER_MISMATCH = 803, + + /** + * This error indicates that the system was upgraded to run with forward compatibility + * but the visible hardware detected by CUDA does not support this configuration. + * Refer to the compatibility documentation for the supported hardware matrix or ensure + * that only supported hardware is visible during initialization via the CUDA_VISIBLE_DEVICES + * environment variable. + */ + CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE = 804, + + /** + * This error indicates that the MPS client failed to connect to the MPS control daemon or the MPS server. + */ + CUDA_ERROR_MPS_CONNECTION_FAILED = 805, + + /** + * This error indicates that the remote procedural call between the MPS server and the MPS client failed. + */ + CUDA_ERROR_MPS_RPC_FAILURE = 806, + + /** + * This error indicates that the MPS server is not ready to accept new MPS client requests. + * This error can be returned when the MPS server is in the process of recovering from a fatal failure. + */ + CUDA_ERROR_MPS_SERVER_NOT_READY = 807, + + /** + * This error indicates that the hardware resources required to create MPS client have been exhausted. + */ + CUDA_ERROR_MPS_MAX_CLIENTS_REACHED = 808, + + /** + * This error indicates the the hardware resources required to support device connections have been exhausted. + */ + CUDA_ERROR_MPS_MAX_CONNECTIONS_REACHED = 809, + + /** + * This error indicates that the operation is not permitted when + * the stream is capturing. + */ + CUDA_ERROR_STREAM_CAPTURE_UNSUPPORTED = 900, + + /** + * This error indicates that the current capture sequence on the stream + * has been invalidated due to a previous error. + */ + CUDA_ERROR_STREAM_CAPTURE_INVALIDATED = 901, + + /** + * This error indicates that the operation would have resulted in a merge + * of two independent capture sequences. + */ + CUDA_ERROR_STREAM_CAPTURE_MERGE = 902, + + /** + * This error indicates that the capture was not initiated in this stream. + */ + CUDA_ERROR_STREAM_CAPTURE_UNMATCHED = 903, + + /** + * This error indicates that the capture sequence contains a fork that was + * not joined to the primary stream. + */ + CUDA_ERROR_STREAM_CAPTURE_UNJOINED = 904, + + /** + * This error indicates that a dependency would have been created which + * crosses the capture sequence boundary. Only implicit in-stream ordering + * dependencies are allowed to cross the boundary. + */ + CUDA_ERROR_STREAM_CAPTURE_ISOLATION = 905, + + /** + * This error indicates a disallowed implicit dependency on a current capture + * sequence from cudaStreamLegacy. 
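 *
 * Editor's illustration (not part of the original header): the stream-capture
 * errors above arise around sequences like the one below, which records work
 * from a non-default `stream` into a graph (CUDA 11.x-era signatures assumed;
 * the work enqueued during capture is elided).
 * \code
 *   CUgraph graph = NULL;
 *   CUgraphExec graphExec = NULL;
 *   cuStreamBeginCapture(stream, CU_STREAM_CAPTURE_MODE_GLOBAL);
 *   // ... enqueue kernels / async copies into `stream` here ...
 *   cuStreamEndCapture(stream, &graph);
 *   cuGraphInstantiate(&graphExec, graph, NULL, NULL, 0);
 *   cuGraphLaunch(graphExec, stream);
 * \endcode
 *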
+ */ + CUDA_ERROR_STREAM_CAPTURE_IMPLICIT = 906, + + /** + * This error indicates that the operation is not permitted on an event which + * was last recorded in a capturing stream. + */ + CUDA_ERROR_CAPTURED_EVENT = 907, + + /** + * A stream capture sequence not initiated with the ::CU_STREAM_CAPTURE_MODE_RELAXED + * argument to ::cuStreamBeginCapture was passed to ::cuStreamEndCapture in a + * different thread. + */ + CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD = 908, + + /** + * This error indicates that the timeout specified for the wait operation has lapsed. + */ + CUDA_ERROR_TIMEOUT = 909, + + /** + * This error indicates that the graph update was not performed because it included + * changes which violated constraints specific to instantiated graph update. + */ + CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE = 910, + + /** + * This indicates that an async error has occurred in a device outside of CUDA. + * If CUDA was waiting for an external device's signal before consuming shared data, + * the external device signaled an error indicating that the data is not valid for + * consumption. This leaves the process in an inconsistent state and any further CUDA + * work will return the same error. To continue using CUDA, the process must be + * terminated and relaunched. + */ + CUDA_ERROR_EXTERNAL_DEVICE = 911, + + + + + + + + + /** + * This indicates that an unknown internal error has occurred. + */ + CUDA_ERROR_UNKNOWN = 999 +} CUresult; + +/** + * P2P Attributes + */ +typedef enum CUdevice_P2PAttribute_enum { + CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK = 0x01, /**< A relative value indicating the performance of the link between two devices */ + CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED = 0x02, /**< P2P Access is enable */ + CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED = 0x03, /**< Atomic operation over the link supported */ + CU_DEVICE_P2P_ATTRIBUTE_ACCESS_ACCESS_SUPPORTED = 0x04, /**< \deprecated use CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED instead */ + CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED = 0x04 /**< Accessing CUDA arrays over the link supported */ +} CUdevice_P2PAttribute; + + + + + + + + + + + + +/** + * CUDA stream callback + * \param hStream The stream the callback was added to, as passed to ::cuStreamAddCallback. May be NULL. + * \param status ::CUDA_SUCCESS or any persistent error on the stream. + * \param userData User parameter provided at registration. + */ +typedef void (CUDA_CB *CUstreamCallback)(CUstream hStream, CUresult status, void *userData); + +/** + * Block size to per-block dynamic shared memory mapping for a certain + * kernel \param blockSize Block size of the kernel. + * + * \return The dynamic shared memory needed by a block. + */ +typedef size_t (CUDA_CB *CUoccupancyB2DSize)(int blockSize); + +/** + * If set, host memory is portable between CUDA contexts. + * Flag for ::cuMemHostAlloc() + */ +#define CU_MEMHOSTALLOC_PORTABLE 0x01 + +/** + * If set, host memory is mapped into CUDA address space and + * ::cuMemHostGetDevicePointer() may be called on the host pointer. + * Flag for ::cuMemHostAlloc() + */ +#define CU_MEMHOSTALLOC_DEVICEMAP 0x02 + +/** + * If set, host memory is allocated as write-combined - fast to write, + * faster to DMA, slow to read except via SSE4 streaming load instruction + * (MOVNTDQA). + * Flag for ::cuMemHostAlloc() + */ +#define CU_MEMHOSTALLOC_WRITECOMBINED 0x04 + +/** + * If set, host memory is portable between CUDA contexts. 
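 *
 * Editor's illustration (not part of the original header): page-locking an
 * existing host allocation so the device can address it directly. `buf` and
 * `bytes` are assumed application inputs; error checking is omitted.
 * \code
 *   cuMemHostRegister(buf, bytes,
 *                     CU_MEMHOSTREGISTER_PORTABLE |
 *                     CU_MEMHOSTREGISTER_DEVICEMAP);
 *   CUdeviceptr dptr = 0;
 *   cuMemHostGetDevicePointer(&dptr, buf, 0);
 *   // ... use dptr in kernels or async copies ...
 *   cuMemHostUnregister(buf);
 * \endcode
 *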
+ * Flag for ::cuMemHostRegister() + */ +#define CU_MEMHOSTREGISTER_PORTABLE 0x01 + +/** + * If set, host memory is mapped into CUDA address space and + * ::cuMemHostGetDevicePointer() may be called on the host pointer. + * Flag for ::cuMemHostRegister() + */ +#define CU_MEMHOSTREGISTER_DEVICEMAP 0x02 + +/** + * If set, the passed memory pointer is treated as pointing to some + * memory-mapped I/O space, e.g. belonging to a third-party PCIe device. + * On Windows the flag is a no-op. + * On Linux that memory is marked as non cache-coherent for the GPU and + * is expected to be physically contiguous. It may return + * ::CUDA_ERROR_NOT_PERMITTED if run as an unprivileged user, + * ::CUDA_ERROR_NOT_SUPPORTED on older Linux kernel versions. + * On all other platforms, it is not supported and ::CUDA_ERROR_NOT_SUPPORTED + * is returned. + * Flag for ::cuMemHostRegister() + */ +#define CU_MEMHOSTREGISTER_IOMEMORY 0x04 + +/** +* If set, the passed memory pointer is treated as pointing to memory that is +* considered read-only by the device. On platforms without +* ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is +* required in order to register memory mapped to the CPU as read-only. Support +* for the use of this flag can be queried from the device attribute +* ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with +* a current context associated with a device that does not have this attribute +* set will cause ::cuMemHostRegister to error with ::CUDA_ERROR_NOT_SUPPORTED. +*/ +#define CU_MEMHOSTREGISTER_READ_ONLY 0x08 + +/** + * 2D memory copy parameters + */ +typedef struct CUDA_MEMCPY2D_st { + size_t srcXInBytes; /**< Source X in bytes */ + size_t srcY; /**< Source Y */ + + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + size_t srcPitch; /**< Source pitch (ignored when src is array) */ + + size_t dstXInBytes; /**< Destination X in bytes */ + size_t dstY; /**< Destination Y */ + + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + size_t dstPitch; /**< Destination pitch (ignored when dst is array) */ + + size_t WidthInBytes; /**< Width of 2D memory copy in bytes */ + size_t Height; /**< Height of 2D memory copy */ +} CUDA_MEMCPY2D_v2; +typedef CUDA_MEMCPY2D_v2 CUDA_MEMCPY2D; + +/** + * 3D memory copy parameters + */ +typedef struct CUDA_MEMCPY3D_st { + size_t srcXInBytes; /**< Source X in bytes */ + size_t srcY; /**< Source Y */ + size_t srcZ; /**< Source Z */ + size_t srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + void *reserved0; /**< Must be NULL */ + size_t srcPitch; /**< Source pitch (ignored when src is array) */ + size_t srcHeight; /**< Source height (ignored when src is array; may be 0 if Depth==1) */ + + size_t dstXInBytes; /**< Destination X in bytes */ + size_t dstY; /**< Destination Y */ + size_t dstZ; /**< Destination Z */ + size_t dstLOD; /**< Destination LOD */ + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, 
array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + void *reserved1; /**< Must be NULL */ + size_t dstPitch; /**< Destination pitch (ignored when dst is array) */ + size_t dstHeight; /**< Destination height (ignored when dst is array; may be 0 if Depth==1) */ + + size_t WidthInBytes; /**< Width of 3D memory copy in bytes */ + size_t Height; /**< Height of 3D memory copy */ + size_t Depth; /**< Depth of 3D memory copy */ +} CUDA_MEMCPY3D_v2; +typedef CUDA_MEMCPY3D_v2 CUDA_MEMCPY3D; + +/** + * 3D memory cross-context copy parameters + */ +typedef struct CUDA_MEMCPY3D_PEER_st { + size_t srcXInBytes; /**< Source X in bytes */ + size_t srcY; /**< Source Y */ + size_t srcZ; /**< Source Z */ + size_t srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + CUcontext srcContext; /**< Source context (ignored with srcMemoryType is ::CU_MEMORYTYPE_ARRAY) */ + size_t srcPitch; /**< Source pitch (ignored when src is array) */ + size_t srcHeight; /**< Source height (ignored when src is array; may be 0 if Depth==1) */ + + size_t dstXInBytes; /**< Destination X in bytes */ + size_t dstY; /**< Destination Y */ + size_t dstZ; /**< Destination Z */ + size_t dstLOD; /**< Destination LOD */ + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + CUcontext dstContext; /**< Destination context (ignored with dstMemoryType is ::CU_MEMORYTYPE_ARRAY) */ + size_t dstPitch; /**< Destination pitch (ignored when dst is array) */ + size_t dstHeight; /**< Destination height (ignored when dst is array; may be 0 if Depth==1) */ + + size_t WidthInBytes; /**< Width of 3D memory copy in bytes */ + size_t Height; /**< Height of 3D memory copy */ + size_t Depth; /**< Depth of 3D memory copy */ +} CUDA_MEMCPY3D_PEER_v1; +typedef CUDA_MEMCPY3D_PEER_v1 CUDA_MEMCPY3D_PEER; + +/** + * Array descriptor + */ +typedef struct CUDA_ARRAY_DESCRIPTOR_st +{ + size_t Width; /**< Width of array */ + size_t Height; /**< Height of array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ +} CUDA_ARRAY_DESCRIPTOR_v2; +typedef CUDA_ARRAY_DESCRIPTOR_v2 CUDA_ARRAY_DESCRIPTOR; + +/** + * 3D array descriptor + */ +typedef struct CUDA_ARRAY3D_DESCRIPTOR_st +{ + size_t Width; /**< Width of 3D array */ + size_t Height; /**< Height of 3D array */ + size_t Depth; /**< Depth of 3D array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + unsigned int Flags; /**< Flags */ +} CUDA_ARRAY3D_DESCRIPTOR_v2; +typedef CUDA_ARRAY3D_DESCRIPTOR_v2 CUDA_ARRAY3D_DESCRIPTOR; + +/** + * Indicates that the layered sparse CUDA array or CUDA mipmapped array has a single mip tail region for all layers + */ +#define CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL 0x1 + +/** + * CUDA array sparse properties + */ +typedef struct CUDA_ARRAY_SPARSE_PROPERTIES_st { + struct { + unsigned int width; /**< Width of sparse tile in elements */ + unsigned int height; /**< Height of sparse tile in elements */ + unsigned int depth; /**< Depth of sparse 
tile in elements */ + } tileExtent; + + /** + * First mip level at which the mip tail begins. + */ + unsigned int miptailFirstLevel; + /** + * Total size of the mip tail. + */ + unsigned long long miptailSize; + /** + * Flags will either be zero or ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL + */ + unsigned int flags; + unsigned int reserved[4]; +} CUDA_ARRAY_SPARSE_PROPERTIES_v1; +typedef CUDA_ARRAY_SPARSE_PROPERTIES_v1 CUDA_ARRAY_SPARSE_PROPERTIES; + + +/** + * CUDA array memory requirements + */ +typedef struct CUDA_ARRAY_MEMORY_REQUIREMENTS_st { + size_t size; /**< Total required memory size */ + size_t alignment; /**< alignment requirement */ + unsigned int reserved[4]; +} CUDA_ARRAY_MEMORY_REQUIREMENTS_v1; +typedef CUDA_ARRAY_MEMORY_REQUIREMENTS_v1 CUDA_ARRAY_MEMORY_REQUIREMENTS; + + +/** + * CUDA Resource descriptor + */ +typedef struct CUDA_RESOURCE_DESC_st +{ + CUresourcetype resType; /**< Resource type */ + + union { + struct { + CUarray hArray; /**< CUDA array */ + } array; + struct { + CUmipmappedArray hMipmappedArray; /**< CUDA mipmapped array */ + } mipmap; + struct { + CUdeviceptr devPtr; /**< Device pointer */ + CUarray_format format; /**< Array format */ + unsigned int numChannels; /**< Channels per array element */ + size_t sizeInBytes; /**< Size in bytes */ + } linear; + struct { + CUdeviceptr devPtr; /**< Device pointer */ + CUarray_format format; /**< Array format */ + unsigned int numChannels; /**< Channels per array element */ + size_t width; /**< Width of the array in elements */ + size_t height; /**< Height of the array in elements */ + size_t pitchInBytes; /**< Pitch between two rows in bytes */ + } pitch2D; + struct { + int reserved[32]; + } reserved; + } res; + + unsigned int flags; /**< Flags (must be zero) */ +} CUDA_RESOURCE_DESC_v1; +typedef CUDA_RESOURCE_DESC_v1 CUDA_RESOURCE_DESC; + +/** + * Texture descriptor + */ +typedef struct CUDA_TEXTURE_DESC_st { + CUaddress_mode addressMode[3]; /**< Address modes */ + CUfilter_mode filterMode; /**< Filter mode */ + unsigned int flags; /**< Flags */ + unsigned int maxAnisotropy; /**< Maximum anisotropy ratio */ + CUfilter_mode mipmapFilterMode; /**< Mipmap filter mode */ + float mipmapLevelBias; /**< Mipmap level bias */ + float minMipmapLevelClamp; /**< Mipmap minimum level clamp */ + float maxMipmapLevelClamp; /**< Mipmap maximum level clamp */ + float borderColor[4]; /**< Border Color */ + int reserved[12]; +} CUDA_TEXTURE_DESC_v1; +typedef CUDA_TEXTURE_DESC_v1 CUDA_TEXTURE_DESC; + +/** + * Resource view format + */ +typedef enum CUresourceViewFormat_enum +{ + CU_RES_VIEW_FORMAT_NONE = 0x00, /**< No resource view format (use underlying resource format) */ + CU_RES_VIEW_FORMAT_UINT_1X8 = 0x01, /**< 1 channel unsigned 8-bit integers */ + CU_RES_VIEW_FORMAT_UINT_2X8 = 0x02, /**< 2 channel unsigned 8-bit integers */ + CU_RES_VIEW_FORMAT_UINT_4X8 = 0x03, /**< 4 channel unsigned 8-bit integers */ + CU_RES_VIEW_FORMAT_SINT_1X8 = 0x04, /**< 1 channel signed 8-bit integers */ + CU_RES_VIEW_FORMAT_SINT_2X8 = 0x05, /**< 2 channel signed 8-bit integers */ + CU_RES_VIEW_FORMAT_SINT_4X8 = 0x06, /**< 4 channel signed 8-bit integers */ + CU_RES_VIEW_FORMAT_UINT_1X16 = 0x07, /**< 1 channel unsigned 16-bit integers */ + CU_RES_VIEW_FORMAT_UINT_2X16 = 0x08, /**< 2 channel unsigned 16-bit integers */ + CU_RES_VIEW_FORMAT_UINT_4X16 = 0x09, /**< 4 channel unsigned 16-bit integers */ + CU_RES_VIEW_FORMAT_SINT_1X16 = 0x0a, /**< 1 channel signed 16-bit integers */ + CU_RES_VIEW_FORMAT_SINT_2X16 = 0x0b, /**< 2 channel signed 16-bit 
integers */ + CU_RES_VIEW_FORMAT_SINT_4X16 = 0x0c, /**< 4 channel signed 16-bit integers */ + CU_RES_VIEW_FORMAT_UINT_1X32 = 0x0d, /**< 1 channel unsigned 32-bit integers */ + CU_RES_VIEW_FORMAT_UINT_2X32 = 0x0e, /**< 2 channel unsigned 32-bit integers */ + CU_RES_VIEW_FORMAT_UINT_4X32 = 0x0f, /**< 4 channel unsigned 32-bit integers */ + CU_RES_VIEW_FORMAT_SINT_1X32 = 0x10, /**< 1 channel signed 32-bit integers */ + CU_RES_VIEW_FORMAT_SINT_2X32 = 0x11, /**< 2 channel signed 32-bit integers */ + CU_RES_VIEW_FORMAT_SINT_4X32 = 0x12, /**< 4 channel signed 32-bit integers */ + CU_RES_VIEW_FORMAT_FLOAT_1X16 = 0x13, /**< 1 channel 16-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_2X16 = 0x14, /**< 2 channel 16-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_4X16 = 0x15, /**< 4 channel 16-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_1X32 = 0x16, /**< 1 channel 32-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_2X32 = 0x17, /**< 2 channel 32-bit floating point */ + CU_RES_VIEW_FORMAT_FLOAT_4X32 = 0x18, /**< 4 channel 32-bit floating point */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC1 = 0x19, /**< Block compressed 1 */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC2 = 0x1a, /**< Block compressed 2 */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC3 = 0x1b, /**< Block compressed 3 */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC4 = 0x1c, /**< Block compressed 4 unsigned */ + CU_RES_VIEW_FORMAT_SIGNED_BC4 = 0x1d, /**< Block compressed 4 signed */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC5 = 0x1e, /**< Block compressed 5 unsigned */ + CU_RES_VIEW_FORMAT_SIGNED_BC5 = 0x1f, /**< Block compressed 5 signed */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC6H = 0x20, /**< Block compressed 6 unsigned half-float */ + CU_RES_VIEW_FORMAT_SIGNED_BC6H = 0x21, /**< Block compressed 6 signed half-float */ + CU_RES_VIEW_FORMAT_UNSIGNED_BC7 = 0x22 /**< Block compressed 7 */ +} CUresourceViewFormat; + +/** + * Resource view descriptor + */ +typedef struct CUDA_RESOURCE_VIEW_DESC_st +{ + CUresourceViewFormat format; /**< Resource view format */ + size_t width; /**< Width of the resource view */ + size_t height; /**< Height of the resource view */ + size_t depth; /**< Depth of the resource view */ + unsigned int firstMipmapLevel; /**< First defined mipmap level */ + unsigned int lastMipmapLevel; /**< Last defined mipmap level */ + unsigned int firstLayer; /**< First layer index */ + unsigned int lastLayer; /**< Last layer index */ + unsigned int reserved[16]; +} CUDA_RESOURCE_VIEW_DESC_v1; +typedef CUDA_RESOURCE_VIEW_DESC_v1 CUDA_RESOURCE_VIEW_DESC; + +/** + * GPU Direct v3 tokens + */ +typedef struct CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_st { + unsigned long long p2pToken; + unsigned int vaSpaceToken; +} CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_v1; +typedef CUDA_POINTER_ATTRIBUTE_P2P_TOKENS_v1 CUDA_POINTER_ATTRIBUTE_P2P_TOKENS; + +/** +* Access flags that specify the level of access the current context's device has +* on the memory referenced. +*/ +typedef enum CUDA_POINTER_ATTRIBUTE_ACCESS_FLAGS_enum { + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_NONE = 0x0, /**< No access, meaning the device cannot access this memory at all, thus must be staged through accessible memory in order to complete certain operations */ + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_READ = 0x1, /**< Read-only access, meaning writes to this memory are considered invalid accesses and thus return error in that case. 
*/ + CU_POINTER_ATTRIBUTE_ACCESS_FLAG_READWRITE = 0x3 /**< Read-write access, the device has full read-write access to the memory */ +} CUDA_POINTER_ATTRIBUTE_ACCESS_FLAGS; + +/** + * Kernel launch parameters + */ +typedef struct CUDA_LAUNCH_PARAMS_st { + CUfunction function; /**< Kernel to launch */ + unsigned int gridDimX; /**< Width of grid in blocks */ + unsigned int gridDimY; /**< Height of grid in blocks */ + unsigned int gridDimZ; /**< Depth of grid in blocks */ + unsigned int blockDimX; /**< X dimension of each thread block */ + unsigned int blockDimY; /**< Y dimension of each thread block */ + unsigned int blockDimZ; /**< Z dimension of each thread block */ + unsigned int sharedMemBytes; /**< Dynamic shared-memory size per thread block in bytes */ + CUstream hStream; /**< Stream identifier */ + void **kernelParams; /**< Array of pointers to kernel parameters */ +} CUDA_LAUNCH_PARAMS_v1; +typedef CUDA_LAUNCH_PARAMS_v1 CUDA_LAUNCH_PARAMS; + +/** + * External memory handle types + */ +typedef enum CUexternalMemoryHandleType_enum { + /** + * Handle is an opaque file descriptor + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, + /** + * Handle is an opaque shared NT handle + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, + /** + * Handle is an opaque, globally shared handle + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + /** + * Handle is a D3D12 heap object + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, + /** + * Handle is a D3D12 committed resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5, + /** + * Handle is a shared NT handle to a D3D11 resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE = 6, + /** + * Handle is a globally shared handle to a D3D11 resource + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT = 7, + /** + * Handle is an NvSciBuf object + */ + CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF = 8 +} CUexternalMemoryHandleType; + +/** + * Indicates that the external memory object is a dedicated resource + */ +#define CUDA_EXTERNAL_MEMORY_DEDICATED 0x1 + +/** When the \p flags parameter of ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS + * contains this flag, it indicates that signaling an external semaphore object + * should skip performing appropriate memory synchronization operations over all + * the external memory objects that are imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, + * which otherwise are performed by default to ensure data coherency with other + * importers of the same NvSciBuf memory objects. + */ +#define CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC 0x01 + +/** When the \p flags parameter of ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS + * contains this flag, it indicates that waiting on an external semaphore object + * should skip performing appropriate memory synchronization operations over all + * the external memory objects that are imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, + * which otherwise are performed by default to ensure data coherency with other + * importers of the same NvSciBuf memory objects. + */ +#define CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC 0x02 + +/** + * When \p flags of ::cuDeviceGetNvSciSyncAttributes is set to this, + * it indicates that application needs signaler specific NvSciSyncAttr + * to be filled by ::cuDeviceGetNvSciSyncAttributes. 
+ */ +#define CUDA_NVSCISYNC_ATTR_SIGNAL 0x1 + +/** + * When \p flags of ::cuDeviceGetNvSciSyncAttributes is set to this, + * it indicates that application needs waiter specific NvSciSyncAttr + * to be filled by ::cuDeviceGetNvSciSyncAttributes. + */ +#define CUDA_NVSCISYNC_ATTR_WAIT 0x2 +/** + * External memory handle descriptor + */ +typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { + /** + * Type of the handle + */ + CUexternalMemoryHandleType type; + union { + /** + * File descriptor referencing the memory object. Valid + * when type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD + */ + int fd; + /** + * Win32 handle referencing the semaphore object. Valid when + * type is one of the following: + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE + * - ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT + * Exactly one of 'handle' and 'name' must be non-NULL. If + * type is one of the following: + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT + * then 'name' must be NULL. + */ + struct { + /** + * Valid NT handle. Must be NULL if 'name' is non-NULL + */ + void *handle; + /** + * Name of a valid memory object. + * Must be NULL if 'handle' is non-NULL. + */ + const void *name; + } win32; + /** + * A handle representing an NvSciBuf Object. Valid when type + * is ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF + */ + const void *nvSciBufObject; + } handle; + /** + * Size of the memory allocation + */ + unsigned long long size; + /** + * Flags must either be zero or ::CUDA_EXTERNAL_MEMORY_DEDICATED + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_MEMORY_HANDLE_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_HANDLE_DESC_v1 CUDA_EXTERNAL_MEMORY_HANDLE_DESC; + +/** + * External memory buffer descriptor + */ +typedef struct CUDA_EXTERNAL_MEMORY_BUFFER_DESC_st { + /** + * Offset into the memory object where the buffer's base is + */ + unsigned long long offset; + /** + * Size of the buffer + */ + unsigned long long size; + /** + * Flags reserved for future use. Must be zero. + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_MEMORY_BUFFER_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_BUFFER_DESC_v1 CUDA_EXTERNAL_MEMORY_BUFFER_DESC; + +/** + * External memory mipmap descriptor + */ +typedef struct CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_st { + /** + * Offset into the memory object where the base level of the + * mipmap chain is. 
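 *
 * Editor's illustration (not part of the original header): importing external
 * memory through the handle and buffer descriptors defined above. `fd` and
 * `size` are assumed to come from the exporting API (an opaque POSIX file
 * descriptor on Linux); error checking is omitted.
 * \code
 *   CUDA_EXTERNAL_MEMORY_HANDLE_DESC hDesc = {0};
 *   hDesc.type = CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD;
 *   hDesc.handle.fd = fd;
 *   hDesc.size = size;
 *   CUexternalMemory extMem = NULL;
 *   cuImportExternalMemory(&extMem, &hDesc);
 *
 *   CUDA_EXTERNAL_MEMORY_BUFFER_DESC bDesc = {0};
 *   bDesc.offset = 0;
 *   bDesc.size = size;
 *   CUdeviceptr dptr = 0;
 *   cuExternalMemoryGetMappedBuffer(&dptr, extMem, &bDesc);
 *   // ... use dptr; later cuMemFree(dptr) and cuDestroyExternalMemory(extMem).
 * \endcode
 *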
+ */ + unsigned long long offset; + /** + * Format, dimension and type of base level of the mipmap chain + */ + CUDA_ARRAY3D_DESCRIPTOR arrayDesc; + /** + * Total number of levels in the mipmap chain + */ + unsigned int numLevels; + unsigned int reserved[16]; +} CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_v1; +typedef CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_v1 CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC; + +/** + * External semaphore handle types + */ +typedef enum CUexternalSemaphoreHandleType_enum { + /** + * Handle is an opaque file descriptor + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, + /** + * Handle is an opaque shared NT handle + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, + /** + * Handle is an opaque, globally shared handle + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + /** + * Handle is a shared NT handle referencing a D3D12 fence object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4, + /** + * Handle is a shared NT handle referencing a D3D11 fence object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE = 5, + /** + * Opaque handle to NvSciSync Object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC = 6, + /** + * Handle is a shared NT handle referencing a D3D11 keyed mutex object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX = 7, + /** + * Handle is a globally shared handle referencing a D3D11 keyed mutex object + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT = 8, + /** + * Handle is an opaque file descriptor referencing a timeline semaphore + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD = 9, + /** + * Handle is an opaque shared NT handle referencing a timeline semaphore + */ + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 = 10 +} CUexternalSemaphoreHandleType; + +/** + * External semaphore handle descriptor + */ +typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { + /** + * Type of the handle + */ + CUexternalSemaphoreHandleType type; + union { + /** + * File descriptor referencing the semaphore object. Valid + * when type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD + */ + int fd; + /** + * Win32 handle referencing the semaphore object. Valid when + * type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * Exactly one of 'handle' and 'name' must be non-NULL. If + * type is one of the following: + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * - ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then 'name' must be NULL. + */ + struct { + /** + * Valid NT handle. Must be NULL if 'name' is non-NULL + */ + void *handle; + /** + * Name of a valid synchronization primitive. + * Must be NULL if 'handle' is non-NULL. + */ + const void *name; + } win32; + /** + * Valid NvSciSyncObj. Must be non NULL + */ + const void* nvSciSyncObj; + } handle; + /** + * Flags reserved for the future. Must be zero. 
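 *
 * Editor's illustration (not part of the original header): importing an
 * opaque-FD semaphore with this descriptor and signaling it on a stream.
 * The signal/wait parameter structures are defined just below; `fd` and
 * `stream` are assumed inputs.
 * \code
 *   CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC sDesc = {0};
 *   sDesc.type = CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD;
 *   sDesc.handle.fd = fd;
 *   CUexternalSemaphore extSem = NULL;
 *   cuImportExternalSemaphore(&extSem, &sDesc);
 *
 *   CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS sig = {0};
 *   cuSignalExternalSemaphoresAsync(&extSem, &sig, 1, stream);
 *   // ... later cuDestroyExternalSemaphore(extSem).
 * \endcode
 *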
+ */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_v1 CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; + +/** + * External semaphore signal parameters + */ +typedef struct CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_st { + struct { + /** + * Parameters for fence objects + */ + struct { + /** + * Value of fence to be signaled + */ + unsigned long long value; + } fence; + union { + /** + * Pointer to NvSciSyncFence. Valid if ::CUexternalSemaphoreHandleType + * is of type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC. + */ + void *fence; + unsigned long long reserved; + } nvSciSync; + /** + * Parameters for keyed mutex objects + */ + struct { + /** + * Value of key to release the mutex with + */ + unsigned long long key; + } keyedMutex; + unsigned int reserved[12]; + } params; + /** + * Only when ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS is used to + * signal a ::CUexternalSemaphore of type + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, the valid flag is + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC which indicates + * that while signaling the ::CUexternalSemaphore, no memory synchronization + * operations should be performed for any external memory object imported + * as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. + * For all other types of ::CUexternalSemaphore, flags must be zero. + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS_v1 CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS; + +/** + * External semaphore wait parameters + */ +typedef struct CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_st { + struct { + /** + * Parameters for fence objects + */ + struct { + /** + * Value of fence to be waited on + */ + unsigned long long value; + } fence; + /** + * Pointer to NvSciSyncFence. Valid if CUexternalSemaphoreHandleType + * is of type CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC. + */ + union { + void *fence; + unsigned long long reserved; + } nvSciSync; + /** + * Parameters for keyed mutex objects + */ + struct { + /** + * Value of key to acquire the mutex with + */ + unsigned long long key; + /** + * Timeout in milliseconds to wait to acquire the mutex + */ + unsigned int timeoutMs; + } keyedMutex; + unsigned int reserved[10]; + } params; + /** + * Only when ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS is used to wait on + * a ::CUexternalSemaphore of type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, + * the valid flag is ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC + * which indicates that while waiting for the ::CUexternalSemaphore, no memory + * synchronization operations should be performed for any external memory + * object imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. + * For all other types of ::CUexternalSemaphore, flags must be zero. + */ + unsigned int flags; + unsigned int reserved[16]; +} CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_v1; +typedef CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS_v1 CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS; + +/** + * Semaphore signal node parameters + */ +typedef struct CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_st { + CUexternalSemaphore* extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS* paramsArray; /**< Array of external semaphore signal parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in extSemArray and paramsArray. 
*/ +} CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v1; +typedef CUDA_EXT_SEM_SIGNAL_NODE_PARAMS_v1 CUDA_EXT_SEM_SIGNAL_NODE_PARAMS; + +/** + * Semaphore wait node parameters + */ +typedef struct CUDA_EXT_SEM_WAIT_NODE_PARAMS_st { + CUexternalSemaphore* extSemArray; /**< Array of external semaphore handles. */ + const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS* paramsArray; /**< Array of external semaphore wait parameters. */ + unsigned int numExtSems; /**< Number of handles and parameters supplied in extSemArray and paramsArray. */ +} CUDA_EXT_SEM_WAIT_NODE_PARAMS_v1; +typedef CUDA_EXT_SEM_WAIT_NODE_PARAMS_v1 CUDA_EXT_SEM_WAIT_NODE_PARAMS; + +typedef unsigned long long CUmemGenericAllocationHandle_v1; +typedef CUmemGenericAllocationHandle_v1 CUmemGenericAllocationHandle; + +/** + * Flags for specifying particular handle types + */ +typedef enum CUmemAllocationHandleType_enum { + CU_MEM_HANDLE_TYPE_NONE = 0x0, /**< Does not allow any export mechanism. > */ + CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = 0x1, /**< Allows a file descriptor to be used for exporting. Permitted only on POSIX systems. (int) */ + CU_MEM_HANDLE_TYPE_WIN32 = 0x2, /**< Allows a Win32 NT handle to be used for exporting. (HANDLE) */ + CU_MEM_HANDLE_TYPE_WIN32_KMT = 0x4, /**< Allows a Win32 KMT handle to be used for exporting. (D3DKMT_HANDLE) */ + CU_MEM_HANDLE_TYPE_MAX = 0x7FFFFFFF +} CUmemAllocationHandleType; + +/** + * Specifies the memory protection flags for mapping. + */ +typedef enum CUmemAccess_flags_enum { + CU_MEM_ACCESS_FLAGS_PROT_NONE = 0x0, /**< Default, make the address range not accessible */ + CU_MEM_ACCESS_FLAGS_PROT_READ = 0x1, /**< Make the address range read accessible */ + CU_MEM_ACCESS_FLAGS_PROT_READWRITE = 0x3, /**< Make the address range read-write accessible */ + CU_MEM_ACCESS_FLAGS_PROT_MAX = 0x7FFFFFFF +} CUmemAccess_flags; + +/** + * Specifies the type of location + */ +typedef enum CUmemLocationType_enum { + CU_MEM_LOCATION_TYPE_INVALID = 0x0, + CU_MEM_LOCATION_TYPE_DEVICE = 0x1, /**< Location is a device location, thus id is a device ordinal */ + CU_MEM_LOCATION_TYPE_MAX = 0x7FFFFFFF +} CUmemLocationType; + +/** +* Defines the allocation types available +*/ +typedef enum CUmemAllocationType_enum { + CU_MEM_ALLOCATION_TYPE_INVALID = 0x0, + + /** This allocation type is 'pinned', i.e. cannot migrate from its current + * location while the application is actively using it + */ + CU_MEM_ALLOCATION_TYPE_PINNED = 0x1, + CU_MEM_ALLOCATION_TYPE_MAX = 0x7FFFFFFF +} CUmemAllocationType; + +/** +* Flag for requesting different optimal and required granularities for an allocation. 
+*/ +typedef enum CUmemAllocationGranularity_flags_enum { + CU_MEM_ALLOC_GRANULARITY_MINIMUM = 0x0, /**< Minimum required granularity for allocation */ + CU_MEM_ALLOC_GRANULARITY_RECOMMENDED = 0x1 /**< Recommended granularity for allocation for best performance */ +} CUmemAllocationGranularity_flags; + +/** + * Sparse subresource types + */ +typedef enum CUarraySparseSubresourceType_enum { + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL = 0, + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL = 1 +} CUarraySparseSubresourceType; + +/** + * Memory operation types + */ +typedef enum CUmemOperationType_enum { + CU_MEM_OPERATION_TYPE_MAP = 1, + CU_MEM_OPERATION_TYPE_UNMAP = 2 +} CUmemOperationType; + +/** + * Memory handle types + */ +typedef enum CUmemHandleType_enum { + CU_MEM_HANDLE_TYPE_GENERIC = 0 +} CUmemHandleType; + +/** + * Specifies the CUDA array or CUDA mipmapped array memory mapping information + */ +typedef struct CUarrayMapInfo_st { + CUresourcetype resourceType; /**< Resource type */ + + union { + CUmipmappedArray mipmap; + CUarray array; + } resource; + + CUarraySparseSubresourceType subresourceType; /**< Sparse subresource type */ + + union { + struct { + unsigned int level; /**< For CUDA mipmapped arrays must a valid mipmap level. For CUDA arrays must be zero */ + unsigned int layer; /**< For CUDA layered arrays must be a valid layer index. Otherwise, must be zero */ + unsigned int offsetX; /**< Starting X offset in elements */ + unsigned int offsetY; /**< Starting Y offset in elements */ + unsigned int offsetZ; /**< Starting Z offset in elements */ + unsigned int extentWidth; /**< Width in elements */ + unsigned int extentHeight; /**< Height in elements */ + unsigned int extentDepth; /**< Depth in elements */ + } sparseLevel; + struct { + unsigned int layer; /**< For CUDA layered arrays must be a valid layer index. Otherwise, must be zero */ + unsigned long long offset; /**< Offset within mip tail */ + unsigned long long size; /**< Extent in bytes */ + } miptail; + } subresource; + + CUmemOperationType memOperationType; /**< Memory operation type */ + CUmemHandleType memHandleType; /**< Memory handle type */ + + union { + CUmemGenericAllocationHandle memHandle; + } memHandle; + + unsigned long long offset; /**< Offset within the memory */ + unsigned int deviceBitMask; /**< Device ordinal bit mask */ + unsigned int flags; /**< flags for future use, must be zero now. */ + unsigned int reserved[2]; /**< Reserved for future use, must be zero now. */ +} CUarrayMapInfo_v1; +typedef CUarrayMapInfo_v1 CUarrayMapInfo; + +/** + * Specifies a memory location. + */ +typedef struct CUmemLocation_st { + CUmemLocationType type; /**< Specifies the location type, which modifies the meaning of id. */ + int id; /**< identifier for a given this location's ::CUmemLocationType. */ +} CUmemLocation_v1; +typedef CUmemLocation_v1 CUmemLocation; + +/** + * Specifies compression attribute for an allocation. + */ +typedef enum CUmemAllocationCompType_enum { + CU_MEM_ALLOCATION_COMP_NONE = 0x0, /**< Allocating non-compressible memory */ + CU_MEM_ALLOCATION_COMP_GENERIC = 0x1 /**< Allocating compressible memory */ +} CUmemAllocationCompType; + +/** + * This flag if set indicates that the memory will be used as a tile pool. + */ +#define CU_MEM_CREATE_USAGE_TILE_POOL 0x1 + +/** +* Specifies the allocation properties for a allocation. 
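 *
 * Editor's illustration (not part of the original header): a minimal
 * reserve/create/map/set-access sequence using the allocation property and
 * access descriptor structures defined just below. `dev` and `bytes` are
 * assumed inputs; error checking and teardown (::cuMemUnmap, ::cuMemRelease,
 * ::cuMemAddressFree) are omitted.
 * \code
 *   CUmemAllocationProp prop = {0};
 *   prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
 *   prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
 *   prop.location.id = dev;
 *
 *   size_t gran = 0;
 *   cuMemGetAllocationGranularity(&gran, &prop,
 *                                 CU_MEM_ALLOC_GRANULARITY_MINIMUM);
 *   size_t size = ((bytes + gran - 1) / gran) * gran;  // round up
 *
 *   CUmemGenericAllocationHandle handle;
 *   CUdeviceptr ptr = 0;
 *   cuMemCreate(&handle, size, &prop, 0);
 *   cuMemAddressReserve(&ptr, size, 0, 0, 0);
 *   cuMemMap(ptr, size, 0, handle, 0);
 *
 *   CUmemAccessDesc access = {0};
 *   access.location = prop.location;
 *   access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
 *   cuMemSetAccess(ptr, size, &access, 1);
 * \endcode
 *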
+*/ +typedef struct CUmemAllocationProp_st { + /** Allocation type */ + CUmemAllocationType type; + /** requested ::CUmemAllocationHandleType */ + CUmemAllocationHandleType requestedHandleTypes; + /** Location of allocation */ + CUmemLocation location; + /** + * Windows-specific POBJECT_ATTRIBUTES required when + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This object atributes structure + * includes security attributes that define + * the scope of which exported allocations may be tranferred to other + * processes. In all other cases, this field is required to be zero. + */ + void *win32HandleMetaData; + struct { + /** + * Allocation hint for requesting compressible memory. + * On devices that support Compute Data Compression, compressible + * memory can be used to accelerate accesses to data with unstructured + * sparsity and other compressible data patterns. Applications are + * expected to query allocation property of the handle obtained with + * ::cuMemCreate using ::cuMemGetAllocationPropertiesFromHandle to + * validate if the obtained allocation is compressible or not. Note that + * compressed memory may not be mappable on all devices. + */ + unsigned char compressionType; + unsigned char gpuDirectRDMACapable; + /** Bitmask indicating intended usage for this allocation */ + unsigned short usage; + unsigned char reserved[4]; + } allocFlags; +} CUmemAllocationProp_v1; +typedef CUmemAllocationProp_v1 CUmemAllocationProp; + +/** + * Memory access descriptor + */ +typedef struct CUmemAccessDesc_st { + CUmemLocation location; /**< Location on which the request is to change it's accessibility */ + CUmemAccess_flags flags; /**< ::CUmemProt accessibility flags to set on the request */ +} CUmemAccessDesc_v1; +typedef CUmemAccessDesc_v1 CUmemAccessDesc; + +typedef enum CUgraphExecUpdateResult_enum { + CU_GRAPH_EXEC_UPDATE_SUCCESS = 0x0, /**< The update succeeded */ + CU_GRAPH_EXEC_UPDATE_ERROR = 0x1, /**< The update failed for an unexpected reason which is described in the return value of the function */ + CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED = 0x2, /**< The update failed because the topology changed */ + CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED = 0x3, /**< The update failed because a node type changed */ + CU_GRAPH_EXEC_UPDATE_ERROR_FUNCTION_CHANGED = 0x4, /**< The update failed because the function of a kernel node changed (CUDA driver < 11.2) */ + CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED = 0x5, /**< The update failed because the parameters changed in a way that is not supported */ + CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED = 0x6, /**< The update failed because something about the node is not supported */ + CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE = 0x7, /**< The update failed because the function of a kernel node changed in an unsupported way */ + CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED = 0x8 /**< The update failed because the node attributes changed in a way that is not supported */ +} CUgraphExecUpdateResult; + +/** + * CUDA memory pool attributes + */ +typedef enum CUmemPool_attribute_enum { + /** + * (value type = int) + * Allow cuMemAllocAsync to use memory asynchronously freed + * in another streams as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. 
(default enabled) + */ + CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES = 1, + + /** + * (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + */ + CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, + + /** + * (value type = int) + * Allow cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by cuFreeAsync (default enabled). + */ + CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES, + + /** + * (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + */ + CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, + + /** + * (value type = cuuint64_t) + * Amount of backing memory currently allocated for the mempool. + */ + CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of backing memory allocated for the mempool since the + * last time it was reset. High watermark can only be reset to zero. + */ + CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, + + /** + * (value type = cuuint64_t) + * Amount of memory from the pool that is currently in use by the application. + */ + CU_MEMPOOL_ATTR_USED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of the amount of memory from the pool that was in use by the application since + * the last time it was reset. High watermark can only be reset to zero. + */ + CU_MEMPOOL_ATTR_USED_MEM_HIGH +} CUmemPool_attribute; + +/** + * Specifies the properties of allocations made from the pool. + */ +typedef struct CUmemPoolProps_st { + CUmemAllocationType allocType; /**< Allocation type. Currently must be specified as CU_MEM_ALLOCATION_TYPE_PINNED */ + CUmemAllocationHandleType handleTypes; /**< Handle types that will be supported by allocations from the pool. */ + CUmemLocation location; /**< Location where allocations should reside. */ + /** + * Windows-specific LPSECURITYATTRIBUTES required when + * ::CU_MEM_HANDLE_TYPE_WIN32 is specified. This security attribute defines + * the scope of which exported allocations may be tranferred to other + * processes. In all other cases, this field is required to be zero. + */ + void *win32SecurityAttributes; + unsigned char reserved[64]; /**< reserved for future use, must be 0 */ +} CUmemPoolProps_v1; +typedef CUmemPoolProps_v1 CUmemPoolProps; + +/** + * Opaque data for exporting a pool allocation + */ +typedef struct CUmemPoolPtrExportData_st { + unsigned char reserved[64]; +} CUmemPoolPtrExportData_v1; +typedef CUmemPoolPtrExportData_v1 CUmemPoolPtrExportData; + +/** + * Memory allocation node parameters + */ +typedef struct CUDA_MEM_ALLOC_NODE_PARAMS_st { + /** + * in: location where the allocation should reside (specified in ::location). + * ::handleTypes must be ::CU_MEM_HANDLE_TYPE_NONE. IPC is not supported. + */ + CUmemPoolProps poolProps; + const CUmemAccessDesc *accessDescs; /**< in: array of memory access descriptors. Used to describe peer GPU access */ + size_t accessDescCount; /**< in: number of memory access descriptors. Must not exceed the number of GPUs. 
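 *
 * Editor's illustration (not part of the original header), tied to the pool
 * properties and attributes defined above rather than to this field:
 * creating an explicit pool, setting a release threshold, and allocating
 * from it asynchronously. `dev`, `bytes` and `stream` are assumed inputs.
 * \code
 *   CUmemPoolProps props = {0};
 *   props.allocType = CU_MEM_ALLOCATION_TYPE_PINNED;
 *   props.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
 *   props.location.id = dev;
 *   CUmemoryPool pool = NULL;
 *   cuMemPoolCreate(&pool, &props);
 *
 *   cuuint64_t threshold = 64ull << 20;   // keep up to 64 MiB cached
 *   cuMemPoolSetAttribute(pool, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD,
 *                         &threshold);
 *
 *   CUdeviceptr dptr = 0;
 *   cuMemAllocFromPoolAsync(&dptr, bytes, pool, stream);
 *   // ... use dptr on `stream`, then cuMemFreeAsync(dptr, stream).
 * \endcode
 *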
*/ + size_t bytesize; /**< in: size in bytes of the requested allocation */ + CUdeviceptr dptr; /**< out: address of the allocation returned by CUDA */ +} CUDA_MEM_ALLOC_NODE_PARAMS; + +typedef enum CUgraphMem_attribute_enum { + /** + * (value type = cuuint64_t) + * Amount of memory, in bytes, currently associated with graphs + */ + CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. + */ + CU_GRAPH_MEM_ATTR_USED_MEM_HIGH, + + /** + * (value type = cuuint64_t) + * Amount of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + */ + CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT, + + /** + * (value type = cuuint64_t) + * High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + */ + CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH +} CUgraphMem_attribute; + +/** + * If set, each kernel launched as part of ::cuLaunchCooperativeKernelMultiDevice only + * waits for prior work in the stream corresponding to that GPU to complete before the + * kernel begins execution. + */ +#define CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_PRE_LAUNCH_SYNC 0x01 + +/** + * If set, any subsequent work pushed in a stream that participated in a call to + * ::cuLaunchCooperativeKernelMultiDevice will only wait for the kernel launched on + * the GPU corresponding to that stream to complete before it begins execution. + */ +#define CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_POST_LAUNCH_SYNC 0x02 + +/** + * If set, the CUDA array is a collection of layers, where each layer is either a 1D + * or a 2D array and the Depth member of CUDA_ARRAY3D_DESCRIPTOR specifies the number + * of layers, not the depth of a 3D array. + */ +#define CUDA_ARRAY3D_LAYERED 0x01 + +/** + * Deprecated, use CUDA_ARRAY3D_LAYERED + */ +#define CUDA_ARRAY3D_2DARRAY 0x01 + +/** + * This flag must be set in order to bind a surface reference + * to the CUDA array + */ +#define CUDA_ARRAY3D_SURFACE_LDST 0x02 + +/** + * If set, the CUDA array is a collection of six 2D arrays, representing faces of a cube. The + * width of such a CUDA array must be equal to its height, and Depth must be six. + * If ::CUDA_ARRAY3D_LAYERED flag is also set, then the CUDA array is a collection of cubemaps + * and Depth must be a multiple of six. + */ +#define CUDA_ARRAY3D_CUBEMAP 0x04 + +/** + * This flag must be set in order to perform texture gather operations + * on a CUDA array. + */ +#define CUDA_ARRAY3D_TEXTURE_GATHER 0x08 + +/** + * This flag if set indicates that the CUDA + * array is a DEPTH_TEXTURE. + */ +#define CUDA_ARRAY3D_DEPTH_TEXTURE 0x10 + +/** + * This flag indicates that the CUDA array may be bound as a color target + * in an external graphics API + */ +#define CUDA_ARRAY3D_COLOR_ATTACHMENT 0x20 + +/** + * This flag if set indicates that the CUDA array or CUDA mipmapped array + * is a sparse CUDA array or CUDA mipmapped array respectively + */ +#define CUDA_ARRAY3D_SPARSE 0x40 + + +/** + * This flag if set indicates that the CUDA array or CUDA mipmapped array + * will allow deferred memory mapping + */ +#define CUDA_ARRAY3D_DEFERRED_MAPPING 0x80 + + +/** + * Override the texref format with a format inferred from the array. + * Flag for ::cuTexRefSetArray() + */ +#define CU_TRSA_OVERRIDE_FORMAT 0x01 + +/** + * Read the texture as integers rather than promoting the values to floats + * in the range [0,1]. 
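 *
 * Editor's illustration (not part of the original header): creating a texture
 * object over a CUDA array with normalized coordinates, using the resource
 * and texture descriptors defined earlier in this header. `hArray` is assumed
 * to have been created with ::cuArrayCreate; error checking is omitted.
 * \code
 *   CUDA_RESOURCE_DESC resDesc = {0};
 *   resDesc.resType = CU_RESOURCE_TYPE_ARRAY;
 *   resDesc.res.array.hArray = hArray;
 *
 *   CUDA_TEXTURE_DESC texDesc = {0};
 *   texDesc.addressMode[0] = CU_TR_ADDRESS_MODE_CLAMP;
 *   texDesc.addressMode[1] = CU_TR_ADDRESS_MODE_CLAMP;
 *   texDesc.filterMode = CU_TR_FILTER_MODE_LINEAR;
 *   texDesc.flags = CU_TRSF_NORMALIZED_COORDINATES;
 *
 *   CUtexObject tex = 0;
 *   cuTexObjectCreate(&tex, &resDesc, &texDesc, NULL);
 *   // ... sample in kernels, then cuTexObjectDestroy(tex).
 * \endcode
 *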
+ * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_READ_AS_INTEGER 0x01 + +/** + * Use normalized texture coordinates in the range [0,1) instead of [0,dim). + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_NORMALIZED_COORDINATES 0x02 + +/** + * Perform sRGB->linear conversion during texture read. + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_SRGB 0x10 + + /** + * Disable any trilinear filtering optimizations. + * Flag for ::cuTexRefSetFlags() and ::cuTexObjectCreate() + */ +#define CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION 0x20 + +/** + * Enable seamless cube map filtering. + * Flag for ::cuTexObjectCreate() + */ +#define CU_TRSF_SEAMLESS_CUBEMAP 0x40 + +/** + * End of array terminator for the \p extra parameter to + * ::cuLaunchKernel + */ +#define CU_LAUNCH_PARAM_END ((void*)0x00) + +/** + * Indicator that the next value in the \p extra parameter to + * ::cuLaunchKernel will be a pointer to a buffer containing all kernel + * parameters used for launching kernel \p f. This buffer needs to + * honor all alignment/padding requirements of the individual parameters. + * If ::CU_LAUNCH_PARAM_BUFFER_SIZE is not also specified in the + * \p extra array, then ::CU_LAUNCH_PARAM_BUFFER_POINTER will have no + * effect. + */ +#define CU_LAUNCH_PARAM_BUFFER_POINTER ((void*)0x01) + +/** + * Indicator that the next value in the \p extra parameter to + * ::cuLaunchKernel will be a pointer to a size_t which contains the + * size of the buffer specified with ::CU_LAUNCH_PARAM_BUFFER_POINTER. + * It is required that ::CU_LAUNCH_PARAM_BUFFER_POINTER also be specified + * in the \p extra array if the value associated with + * ::CU_LAUNCH_PARAM_BUFFER_SIZE is not zero. + */ +#define CU_LAUNCH_PARAM_BUFFER_SIZE ((void*)0x02) + +/** + * For texture references loaded into the module, use default texunit from + * texture reference. + */ +#define CU_PARAM_TR_DEFAULT -1 + +/** + * Device that represents the CPU + */ +#define CU_DEVICE_CPU ((CUdevice)-1) + +/** + * Device that represents an invalid device + */ +#define CU_DEVICE_INVALID ((CUdevice)-2) + +/** + * Bitmasks for ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS + */ +typedef enum CUflushGPUDirectRDMAWritesOptions_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_HOST = 1<<0, /**< ::cuFlushGPUDirectRDMAWrites() and its CUDA Runtime API counterpart are supported on the device. */ + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_MEMOPS = 1<<1 /**< The ::CU_STREAM_WAIT_VALUE_FLUSH flag and the ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES MemOp are supported on the device. */ +} CUflushGPUDirectRDMAWritesOptions; + +/** + * Platform native ordering for GPUDirect RDMA writes + */ +typedef enum CUGPUDirectRDMAWritesOrdering_enum { + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_NONE = 0, /**< The device does not natively support ordering of remote writes. ::cuFlushGPUDirectRDMAWrites() can be leveraged if supported. */ + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_OWNER = 100, /**< Natively, the device can consistently consume remote writes, although other CUDA devices may not. */ + CU_GPU_DIRECT_RDMA_WRITES_ORDERING_ALL_DEVICES = 200 /**< Any CUDA device in the system can consistently consume remote writes to this device. 
*/ +} CUGPUDirectRDMAWritesOrdering; + +/** + * The scopes for ::cuFlushGPUDirectRDMAWrites + */ +typedef enum CUflushGPUDirectRDMAWritesScope_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER = 100, /**< Blocks until remote writes are visible to the CUDA device context owning the data. */ + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_ALL_DEVICES = 200 /**< Blocks until remote writes are visible to all CUDA device contexts. */ +} CUflushGPUDirectRDMAWritesScope; + +/** + * The targets for ::cuFlushGPUDirectRDMAWrites + */ +typedef enum CUflushGPUDirectRDMAWritesTarget_enum { + CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX = 0 /**< Sets the target for ::cuFlushGPUDirectRDMAWrites() to the currently active CUDA device context. */ +} CUflushGPUDirectRDMAWritesTarget; + +/** + * The additional write options for ::cuGraphDebugDotPrint + */ +typedef enum CUgraphDebugDot_flags_enum { + CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE = 1<<0, /** Output all debug data as if every debug flag is enabled */ + CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES = 1<<1, /** Use CUDA Runtime structures for output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS = 1<<2, /** Adds CUDA_KERNEL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS = 1<<3, /** Adds CUDA_MEMCPY3D values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS = 1<<4, /** Adds CUDA_MEMSET_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS = 1<<5, /** Adds CUDA_HOST_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS = 1<<6, /** Adds CUevent handle from record and wait nodes to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS = 1<<7, /** Adds CUDA_EXT_SEM_SIGNAL_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS = 1<<8, /** Adds CUDA_EXT_SEM_WAIT_NODE_PARAMS values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES = 1<<9, /** Adds CUkernelNodeAttrValue values to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES = 1<<10, /** Adds node handles and every kernel function handle to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS = 1<<11, /** Adds memory alloc node parameters to output */ + CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS = 1<<12 /** Adds memory free node parameters to output */ +} CUgraphDebugDot_flags; + +/** + * Flags for user objects for graphs + */ +typedef enum CUuserObject_flags_enum { + CU_USER_OBJECT_NO_DESTRUCTOR_SYNC = 1 /**< Indicates the destructor execution is not synchronized by any CUDA handle. */ +} CUuserObject_flags; + +/** + * Flags for retaining user object references for graphs + */ +typedef enum CUuserObjectRetain_flags_enum { + CU_GRAPH_USER_OBJECT_MOVE = 1 /**< Transfer references from the caller rather than creating new references. */ +} CUuserObjectRetain_flags; + +/** + * Flags for instantiating a graph + */ +typedef enum CUgraphInstantiate_flags_enum { + CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH = 1 /**< Automatically free memory allocated in a graph before relaunching. 
*/ +} CUgraphInstantiate_flags; + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +/** @} */ /* END CUDA_TYPES */ + +#if defined(__GNUC__) + #if defined(__CUDA_API_PUSH_VISIBILITY_DEFAULT) + #pragma GCC visibility push(default) + #endif +#endif + +#ifdef _WIN32 +#define CUDAAPI __stdcall +#else +#define CUDAAPI +#endif + +/** + * \defgroup CUDA_ERROR Error Handling + * + * ___MANBRIEF___ error handling functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the error handling functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Gets the string description of an error code + * + * Sets \p *pStr to the address of a NULL-terminated string description + * of the error code \p error. + * If the error code is not recognized, ::CUDA_ERROR_INVALID_VALUE + * will be returned and \p *pStr will be set to the NULL address. + * + * \param error - Error code to convert to string + * \param pStr - Address of the string pointer. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::CUresult, + * ::cudaGetErrorString + */ +CUresult CUDAAPI cuGetErrorString(CUresult error, const char **pStr); + +/** + * \brief Gets the string representation of an error code enum name + * + * Sets \p *pStr to the address of a NULL-terminated string representation + * of the name of the enum error code \p error. + * If the error code is not recognized, ::CUDA_ERROR_INVALID_VALUE + * will be returned and \p *pStr will be set to the NULL address. + * + * \param error - Error code to convert to string + * \param pStr - Address of the string pointer. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::CUresult, + * ::cudaGetErrorName + */ +CUresult CUDAAPI cuGetErrorName(CUresult error, const char **pStr); + +/** @} */ /* END CUDA_ERROR */ + +/** + * \defgroup CUDA_INITIALIZE Initialization + * + * ___MANBRIEF___ initialization functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the initialization functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Initialize the CUDA driver API + * + * Initializes the driver API and must be called before any other function from + * the driver API. Currently, the \p Flags parameter must be 0. If ::cuInit() + * has not been called, any function from the driver API will return + * ::CUDA_ERROR_NOT_INITIALIZED. + * + * \param Flags - Initialization flag for CUDA. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_SYSTEM_DRIVER_MISMATCH, + * ::CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE + * \notefnerr + */ +CUresult CUDAAPI cuInit(unsigned int Flags); + +/** @} */ /* END CUDA_INITIALIZE */ + +/** + * \defgroup CUDA_VERSION Version Management + * + * ___MANBRIEF___ version management functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the version management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns the latest CUDA version supported by driver + * + * Returns in \p *driverVersion the version of CUDA supported by + * the driver. The version is returned as + * (1000 × major + 10 × minor). For example, CUDA 9.2 + * would be represented by 9020. 
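+ * As a minimal illustrative sketch (error checking omitted), the returned
+ * value can be decoded into major and minor components as follows:
+ *
+ * \code
+   int driverVersion = 0;
+   cuDriverGetVersion(&driverVersion);
+   int major = driverVersion / 1000;           /* e.g. 9020 -> 9 */
+   int minor = (driverVersion % 1000) / 10;    /* e.g. 9020 -> 2 */
+ * \endcode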
+ * + * This function automatically returns ::CUDA_ERROR_INVALID_VALUE if + * \p driverVersion is NULL. + * + * \param driverVersion - Returns the CUDA driver version + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cudaDriverGetVersion, + * ::cudaRuntimeGetVersion + */ +CUresult CUDAAPI cuDriverGetVersion(int *driverVersion); + +/** @} */ /* END CUDA_VERSION */ + +/** + * \defgroup CUDA_DEVICE Device Management + * + * ___MANBRIEF___ device management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the device management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns a handle to a compute device + * + * Returns in \p *device a device handle given an ordinal in the range [0, + * ::cuDeviceGetCount()-1]. + * + * \param device - Returned device handle + * \param ordinal - Device number to get handle for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGetLuid, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport + */ +CUresult CUDAAPI cuDeviceGet(CUdevice *device, int ordinal); + +/** + * \brief Returns the number of compute-capable devices + * + * Returns in \p *count the number of devices with compute capability greater + * than or equal to 2.0 that are available for execution. If there is no such + * device, ::cuDeviceGetCount() returns 0. + * + * \param count - Returned number of compute-capable devices + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceCount + */ +CUresult CUDAAPI cuDeviceGetCount(int *count); + +/** + * \brief Returns an identifer string for the device + * + * Returns an ASCII string identifying the device \p dev in the NULL-terminated + * string pointed to by \p name. \p len specifies the maximum length of the + * string that may be returned. + * + * \param name - Returned identifier string for the device + * \param len - Maximum length of string to store in \p name + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetUuid, + * ::cuDeviceGetLuid, + * ::cuDeviceGetCount, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetName(char *name, int len, CUdevice dev); + +/** + * \brief Return an UUID for the device + * + * Note there is a later version of this API, ::cuDeviceGetUuid_v2. It will + * supplant this version in 12.0, which is retained for minor version compatibility. + * + * Returns 16-octets identifing the device \p dev in the structure + * pointed by the \p uuid. 
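+ * A minimal sketch printing the returned UUID bytes (error checking omitted;
+ * assumes a valid device handle \p dev and <stdio.h>):
+ *
+ * \code
+   CUuuid uuid;
+   cuDeviceGetUuid(&uuid, dev);
+   for (int i = 0; i < 16; ++i)
+       printf("%02x", (unsigned char)uuid.bytes[i]);
+ * \endcode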
+ * + * \param uuid - Returned UUID + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetUuid_v2 + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetUuid(CUuuid *uuid, CUdevice dev); + +/** + * \brief Return an UUID for the device (11.4+) + * + * Returns 16-octets identifing the device \p dev in the structure + * pointed by the \p uuid. If the device is in MIG mode, returns its + * MIG UUID which uniquely identifies the subscribed MIG compute instance. + * + * \param uuid - Returned UUID + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetLuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetUuid_v2(CUuuid *uuid, CUdevice dev); + +/** + * \brief Return an LUID and device node mask for the device + * + * Return identifying information (\p luid and \p deviceNodeMask) to allow + * matching device with graphics APIs. + * + * \param luid - Returned LUID + * \param deviceNodeMask - Returned device node mask + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetLuid(char *luid, unsigned int *deviceNodeMask, CUdevice dev); + +/** + * \brief Returns the total amount of memory on the device + * + * Returns in \p *bytes the total amount of memory available on the device + * \p dev in bytes. + * + * \param bytes - Returned memory available on device in bytes + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaMemGetInfo + */ +CUresult CUDAAPI cuDeviceTotalMem(size_t *bytes, CUdevice dev); + +/** + * \brief Returns the maximum number of elements allocatable in a 1D linear texture for a given texture element size. + * + * Returns in \p maxWidthInElements the maximum number of texture elements allocatable in a 1D linear texture + * for given \p format and \p numChannels. + * + * \param maxWidthInElements - Returned maximum number of texture elements allocatable for given \p format and \p numChannels. + * \param format - Texture format. + * \param numChannels - Number of channels per texture element. + * \param dev - Device handle. 
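+ * A minimal sketch querying the limit for single-channel float texture
+ * elements (error checking omitted; assumes a valid device handle \p dev):
+ *
+ * \code
+   size_t maxWidthInElements = 0;
+   cuDeviceGetTexture1DLinearMaxWidth(&maxWidthInElements,
+                                      CU_AD_FORMAT_FLOAT, 1, dev);
+ * \endcode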
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cudaMemGetInfo, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetTexture1DLinearMaxWidth(size_t *maxWidthInElements, CUarray_format format, unsigned numChannels, CUdevice dev); + +/** + * \brief Returns information about the device + * + * Returns in \p *pi the integer value of the attribute \p attrib on device + * \p dev. The supported attributes are: + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK: Maximum number of threads per + * block; + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X: Maximum x-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Y: Maximum y-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_Z: Maximum z-dimension of a block + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X: Maximum x-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y: Maximum y-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_Z: Maximum z-dimension of a grid + * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK: Maximum amount of + * shared memory available to a thread block in bytes + * - ::CU_DEVICE_ATTRIBUTE_TOTAL_CONSTANT_MEMORY: Memory available on device for + * __constant__ variables in a CUDA C kernel in bytes + * - ::CU_DEVICE_ATTRIBUTE_WARP_SIZE: Warp size in threads + * - ::CU_DEVICE_ATTRIBUTE_MAX_PITCH: Maximum pitch in bytes allowed by the + * memory copy functions that involve memory regions allocated through + * ::cuMemAllocPitch() + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH: Maximum 1D + * texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH: Maximum width + * for a 1D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH: Maximum + * mipmapped 1D texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_WIDTH: Maximum 2D + * texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_HEIGHT: Maximum 2D + * texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH: Maximum width + * for a 2D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT: Maximum height + * for a 2D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH: Maximum pitch + * in bytes for a 2D texture bound to linear memory + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_WIDTH: Maximum + * mipmapped 2D texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_MIPMAPPED_HEIGHT: Maximum + * mipmapped 2D texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH: Maximum 3D + * texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT: Maximum 3D + * texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH: Maximum 3D + * texture depth + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_WIDTH_ALTERNATE: + * Alternate maximum 3D texture width, 0 if no alternate + * maximum 3D texture size is supported + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_HEIGHT_ALTERNATE: + * Alternate maximum 3D texture height, 0 if no alternate + * maximum 3D texture size is supported + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE3D_DEPTH_ALTERNATE: + * Alternate maximum 3D texture depth, 0 if no alternate + * maximum 3D texture size is supported + * - 
::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_WIDTH: + * Maximum cubemap texture width or height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_WIDTH: + * Maximum 1D layered texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LAYERED_LAYERS: + * Maximum layers in a 1D layered texture + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_WIDTH: + * Maximum 2D layered texture width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_HEIGHT: + * Maximum 2D layered texture height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LAYERED_LAYERS: + * Maximum layers in a 2D layered texture + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_WIDTH: + * Maximum cubemap layered texture width or height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURECUBEMAP_LAYERED_LAYERS: + * Maximum layers in a cubemap layered texture + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_WIDTH: + * Maximum 1D surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_WIDTH: + * Maximum 2D surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_HEIGHT: + * Maximum 2D surface height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_WIDTH: + * Maximum 3D surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_HEIGHT: + * Maximum 3D surface height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE3D_DEPTH: + * Maximum 3D surface depth + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_WIDTH: + * Maximum 1D layered surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE1D_LAYERED_LAYERS: + * Maximum layers in a 1D layered surface + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_WIDTH: + * Maximum 2D layered surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_HEIGHT: + * Maximum 2D layered surface height + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACE2D_LAYERED_LAYERS: + * Maximum layers in a 2D layered surface + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_WIDTH: + * Maximum cubemap surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_WIDTH: + * Maximum cubemap layered surface width + * - ::CU_DEVICE_ATTRIBUTE_MAXIMUM_SURFACECUBEMAP_LAYERED_LAYERS: + * Maximum layers in a cubemap layered surface + * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK: Maximum number of 32-bit + * registers available to a thread block + * - ::CU_DEVICE_ATTRIBUTE_CLOCK_RATE: The typical clock frequency in kilohertz + * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT: Alignment requirement; texture + * base addresses aligned to ::textureAlign bytes do not need an offset + * applied to texture fetches + * - ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT: Pitch alignment requirement + * for 2D texture references bound to pitched memory + * - ::CU_DEVICE_ATTRIBUTE_GPU_OVERLAP: 1 if the device can concurrently copy + * memory between host and device while executing a kernel, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT: Number of multiprocessors on + * the device + * - ::CU_DEVICE_ATTRIBUTE_KERNEL_EXEC_TIMEOUT: 1 if there is a run time limit + * for kernels executed on the device, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_INTEGRATED: 1 if the device is integrated with the + * memory subsystem, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_CAN_MAP_HOST_MEMORY: 1 if the device can map host + * memory into the CUDA address space, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE: Compute mode that device is currently + * in. Available modes are as follows: + * - ::CU_COMPUTEMODE_DEFAULT: Default mode - Device is not restricted and + * can have multiple CUDA contexts present at a single time. 
+ * - ::CU_COMPUTEMODE_PROHIBITED: Compute-prohibited mode - Device is + * prohibited from creating new CUDA contexts. + * - ::CU_COMPUTEMODE_EXCLUSIVE_PROCESS: Compute-exclusive-process mode - Device + * can have only one context used by a single process at a time. + * - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_KERNELS: 1 if the device supports + * executing multiple kernels within the same context simultaneously, or 0 if + * not. It is not guaranteed that multiple kernels will be resident + * on the device concurrently so this feature should not be relied upon for + * correctness. + * - ::CU_DEVICE_ATTRIBUTE_ECC_ENABLED: 1 if error correction is enabled on the + * device, 0 if error correction is disabled or not supported by the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_BUS_ID: PCI bus identifier of the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID: PCI device (also known as slot) identifier + * of the device + * - ::CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID: PCI domain identifier of the device + * - ::CU_DEVICE_ATTRIBUTE_TCC_DRIVER: 1 if the device is using a TCC driver. TCC + * is only available on Tesla hardware running Windows Vista or later + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE: Peak memory clock frequency in kilohertz + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH: Global memory bus width in bits + * - ::CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE: Size of L2 cache in bytes. 0 if the device doesn't have L2 cache + * - ::CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR: Maximum resident threads per multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING: 1 if the device shares a unified address space with + * the host, or 0 if not + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR: Major compute capability version number + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR: Minor compute capability version number + * - ::CU_DEVICE_ATTRIBUTE_GLOBAL_L1_CACHE_SUPPORTED: 1 if device supports caching globals + * in L1 cache, 0 if caching globals in L1 cache is not supported by the device + * - ::CU_DEVICE_ATTRIBUTE_LOCAL_L1_CACHE_SUPPORTED: 1 if device supports caching locals + * in L1 cache, 0 if caching locals in L1 cache is not supported by the device + * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR: Maximum amount of + * shared memory available to a multiprocessor in bytes; this amount is shared + * by all thread blocks simultaneously resident on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR: Maximum number of 32-bit + * registers available to a multiprocessor; this number is shared by all thread + * blocks simultaneously resident on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY: 1 if device supports allocating managed memory + * on this system, 0 if allocating managed memory is not supported by the device on this system. + * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD: 1 if device is on a multi-GPU board, 0 if not. + * - ::CU_DEVICE_ATTRIBUTE_MULTI_GPU_BOARD_GROUP_ID: Unique identifier for a group of devices + * associated with the same board. Devices on the same multi-GPU board will share the same identifier. + * - ::CU_DEVICE_ATTRIBUTE_HOST_NATIVE_ATOMIC_SUPPORTED: 1 if Link between the device and the host + * supports native atomic operations. + * - ::CU_DEVICE_ATTRIBUTE_SINGLE_TO_DOUBLE_PRECISION_PERF_RATIO: Ratio of single precision performance + * (in floating-point operations per second) to double precision performance. 
+ * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS: Device suppports coherently accessing + * pageable memory without calling cudaHostRegister on it. + * - ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS: Device can coherently access managed memory + * concurrently with the CPU. + * - ::CU_DEVICE_ATTRIBUTE_COMPUTE_PREEMPTION_SUPPORTED: Device supports Compute Preemption. + * - ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM: Device can access host registered + * memory at the same virtual address as the CPU. + * - ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN: The maximum per block shared memory size + * suported on this device. This is the maximum value that can be opted into when using the cuFuncSetAttribute() call. + * For more details see ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES + * - ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES: Device accesses pageable memory via the host's + * page tables. + * - ::CU_DEVICE_ATTRIBUTE_DIRECT_MANAGED_MEM_ACCESS_FROM_HOST: The host can directly access managed memory on the device without migration. + * - ::CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED: Device supports virtual memory management APIs like ::cuMemAddressReserve, ::cuMemCreate, ::cuMemMap and related APIs + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED: Device supports exporting memory to a posix file descriptor with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 NT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_WIN32_KMT_HANDLE_SUPPORTED: Device supports exporting memory to a Win32 KMT handle with ::cuMemExportToShareableHandle, if requested via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_BLOCKS_PER_MULTIPROCESSOR: Maximum number of thread blocks that can reside on a multiprocessor + * - ::CU_DEVICE_ATTRIBUTE_GENERIC_COMPRESSION_SUPPORTED: Device supports compressible memory allocation via ::cuMemCreate + * - ::CU_DEVICE_ATTRIBUTE_MAX_PERSISTING_L2_CACHE_SIZE: Maximum L2 persisting lines capacity setting in bytes + * - ::CU_DEVICE_ATTRIBUTE_MAX_ACCESS_POLICY_WINDOW_SIZE: Maximum value of CUaccessPolicyWindow::num_bytes + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED: Device supports specifying the GPUDirect RDMA flag with ::cuMemCreate. + * - ::CU_DEVICE_ATTRIBUTE_RESERVED_SHARED_MEMORY_PER_BLOCK: Amount of shared memory per block reserved by CUDA driver in bytes + * - ::CU_DEVICE_ATTRIBUTE_SPARSE_CUDA_ARRAY_SUPPORTED: Device supports sparse CUDA arrays and sparse CUDA mipmapped arrays. 
+ * - ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED: Device supports using the ::cuMemHostRegister flag ::CU_MEMHOSTERGISTER_READ_ONLY to register memory that must be mapped as read-only to the GPU + * - ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED: Device supports using the ::cuMemAllocAsync and ::cuMemPool family of APIs + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_SUPPORTED: Device supports GPUDirect RDMA APIs, like nvidia_p2p_get_pages (see https://docs.nvidia.com/cuda/gpudirect-rdma for more information) + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS: The returned attribute shall be interpreted as a bitmask, where the individual bits are described by the ::CUflushGPUDirectRDMAWritesOptions enum + * - ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING: GPUDirect RDMA writes to the device do not need to be flushed for consumers within the scope indicated by the returned attribute. See ::CUGPUDirectRDMAWritesOrdering for the numerical values returned here. + * - ::CU_DEVICE_ATTRIBUTE_MEMPOOL_SUPPORTED_HANDLE_TYPES: Bitmask of handle types supported with mempool based IPC + + * - ::CU_DEVICE_ATTRIBUTE_DEFERRED_MAPPING_CUDA_ARRAY_SUPPORTED: Device supports deferred mapping CUDA arrays and CUDA mipmapped arrays. + + * + * \param pi - Returned device attribute value + * \param attrib - Device attribute to query + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem, + * ::cuDeviceGetExecAffinitySupport, + * ::cudaDeviceGetAttribute, + * ::cudaGetDeviceProperties + */ +CUresult CUDAAPI cuDeviceGetAttribute(int *pi, CUdevice_attribute attrib, CUdevice dev); + +/** + * \brief Return NvSciSync attributes that this device can support. + * + * Returns in \p nvSciSyncAttrList, the properties of NvSciSync that + * this CUDA device, \p dev can support. The returned \p nvSciSyncAttrList + * can be used to create an NvSciSync object that matches this device's capabilities. + * + * If NvSciSyncAttrKey_RequiredPerm field in \p nvSciSyncAttrList is + * already set this API will return ::CUDA_ERROR_INVALID_VALUE. + * + * The applications should set \p nvSciSyncAttrList to a valid + * NvSciSyncAttrList failing which this API will return + * ::CUDA_ERROR_INVALID_HANDLE. + * + * The \p flags controls how applications intends to use + * the NvSciSync created from the \p nvSciSyncAttrList. The valid flags are: + * - ::CUDA_NVSCISYNC_ATTR_SIGNAL, specifies that the applications intends to + * signal an NvSciSync on this CUDA device. + * - ::CUDA_NVSCISYNC_ATTR_WAIT, specifies that the applications intends to + * wait on an NvSciSync on this CUDA device. + * + * At least one of these flags must be set, failing which the API + * returns ::CUDA_ERROR_INVALID_VALUE. Both the flags are orthogonal + * to one another: a developer may set both these flags that allows to + * set both wait and signal specific attributes in the same \p nvSciSyncAttrList. + * + * \param nvSciSyncAttrList - Return NvSciSync attributes supported. + * \param dev - Valid Cuda Device to get NvSciSync attributes for. + * \param flags - flags describing NvSciSync usage. 
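+ * A minimal sketch (the attribute list itself must be created with the
+ * separate NvSciSync library, which is assumed here and not shown; passing a
+ * list that is not valid fails with ::CUDA_ERROR_INVALID_HANDLE):
+ *
+ * \code
+   void *nvSciSyncAttrList = NULL;  /* replace with a valid NvSciSyncAttrList */
+   cuDeviceGetNvSciSyncAttributes(nvSciSyncAttrList, dev,
+                                  CUDA_NVSCISYNC_ATTR_SIGNAL | CUDA_NVSCISYNC_ATTR_WAIT);
+ * \endcode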
+ * + * \return + * + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa + * ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuDeviceGetNvSciSyncAttributes(void *nvSciSyncAttrList, CUdevice dev, int flags); + +/** + * \brief Sets the current memory pool of a device + * + * The memory pool must be local to the specified device. + * ::cuMemAllocAsync allocates from the current mempool of the provided stream's device. + * By default, a device's current memory pool is its default memory pool. + * + * \note Use ::cuMemAllocFromPoolAsync to specify asynchronous allocations from a device different + * than the one the stream runs on. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, ::cuMemPoolDestroy, ::cuMemAllocFromPoolAsync + */ +CUresult CUDAAPI cuDeviceSetMemPool(CUdevice dev, CUmemoryPool pool); + +/** + * \brief Gets the current mempool for a device + * + * Returns the last pool provided to ::cuDeviceSetMemPool for this device + * or the device's default memory pool if ::cuDeviceSetMemPool has never been called. + * By default the current mempool is the default mempool for a device. + * Otherwise the returned pool must have been set with ::cuDeviceSetMemPool. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuDeviceGetDefaultMemPool, ::cuMemPoolCreate, ::cuDeviceSetMemPool + */ +CUresult CUDAAPI cuDeviceGetMemPool(CUmemoryPool *pool, CUdevice dev); + +/** + * \brief Returns the default mempool of a device + * + * The default mempool of a device contains device memory from that device. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuMemAllocAsync, ::cuMemPoolTrimTo, ::cuMemPoolGetAttribute, ::cuMemPoolSetAttribute, cuMemPoolSetAccess, ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuDeviceGetDefaultMemPool(CUmemoryPool *pool_out, CUdevice dev); + +/** + * \brief Blocks until remote writes are visible to the specified scope + * + * Blocks until GPUDirect RDMA writes to the target context via mappings + * created through APIs like nvidia_p2p_get_pages (see + * https://docs.nvidia.com/cuda/gpudirect-rdma for more information), are + * visible to the specified scope. + * + * If the scope equals or lies within the scope indicated by + * ::CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WRITES_ORDERING, the call + * will be a no-op and can be safely omitted for performance. This can be + * determined by comparing the numerical values between the two enums, with + * smaller scopes having smaller values. + * + * Users may query support for this API via + * ::CU_DEVICE_ATTRIBUTE_FLUSH_FLUSH_GPU_DIRECT_RDMA_OPTIONS. 
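+ * A minimal sketch that checks the device's flush-writes options before
+ * issuing the flush (error checking omitted; assumes a valid device \p dev
+ * with a current context):
+ *
+ * \code
+   int flushOptions = 0;
+   cuDeviceGetAttribute(&flushOptions,
+                        CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_FLUSH_WRITES_OPTIONS, dev);
+   if (flushOptions & CU_FLUSH_GPU_DIRECT_RDMA_WRITES_OPTION_HOST) {
+       cuFlushGPUDirectRDMAWrites(CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TARGET_CURRENT_CTX,
+                                  CU_FLUSH_GPU_DIRECT_RDMA_WRITES_TO_OWNER);
+   }
+ * \endcode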
+ * + * \param target - The target of the operation, see ::CUflushGPUDirectRDMAWritesTarget + * \param scope - The scope of the operation, see ::CUflushGPUDirectRDMAWritesScope + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + */ +CUresult CUDAAPI cuFlushGPUDirectRDMAWrites(CUflushGPUDirectRDMAWritesTarget target, CUflushGPUDirectRDMAWritesScope scope); + +/** @} */ /* END CUDA_DEVICE */ + +/** + * \defgroup CUDA_DEVICE_DEPRECATED Device Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated device management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the device management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns properties for a selected device + * + * \deprecated + * + * This function was deprecated as of CUDA 5.0 and replaced by ::cuDeviceGetAttribute(). + * + * Returns in \p *prop the properties of device \p dev. The ::CUdevprop + * structure is defined as: + * + * \code + typedef struct CUdevprop_st { + int maxThreadsPerBlock; + int maxThreadsDim[3]; + int maxGridSize[3]; + int sharedMemPerBlock; + int totalConstantMemory; + int SIMDWidth; + int memPitch; + int regsPerBlock; + int clockRate; + int textureAlign + } CUdevprop; + * \endcode + * where: + * + * - ::maxThreadsPerBlock is the maximum number of threads per block; + * - ::maxThreadsDim[3] is the maximum sizes of each dimension of a block; + * - ::maxGridSize[3] is the maximum sizes of each dimension of a grid; + * - ::sharedMemPerBlock is the total amount of shared memory available per + * block in bytes; + * - ::totalConstantMemory is the total amount of constant memory available on + * the device in bytes; + * - ::SIMDWidth is the warp size; + * - ::memPitch is the maximum pitch allowed by the memory copy functions that + * involve memory regions allocated through ::cuMemAllocPitch(); + * - ::regsPerBlock is the total number of registers available per block; + * - ::clockRate is the clock frequency in kilohertz; + * - ::textureAlign is the alignment requirement; texture base addresses that + * are aligned to ::textureAlign bytes do not need an offset applied to + * texture fetches. + * + * \param prop - Returned properties of device + * \param dev - Device to get properties for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceGetProperties(CUdevprop *prop, CUdevice dev); + +/** + * \brief Returns the compute capability of the device + * + * \deprecated + * + * This function was deprecated as of CUDA 5.0 and its functionality superceded + * by ::cuDeviceGetAttribute(). + * + * Returns in \p *major and \p *minor the major and minor revision numbers that + * define the compute capability of the device \p dev. 
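+ * As the deprecation note above indicates, new code can obtain the same
+ * information through ::cuDeviceGetAttribute(); a minimal sketch (error
+ * checking omitted):
+ *
+ * \code
+   int major = 0, minor = 0;
+   cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, dev);
+   cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, dev);
+ * \endcode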
+ * + * \param major - Major revision number + * \param minor - Minor revision number + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuDeviceComputeCapability(int *major, int *minor, CUdevice dev); + +/** @} */ /* END CUDA_DEVICE_DEPRECATED */ + +/** + * \defgroup CUDA_PRIMARY_CTX Primary Context Management + * + * ___MANBRIEF___ primary context management functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the primary context management functions of the low-level + * CUDA driver application programming interface. + * + * The primary context is unique per device and shared with the CUDA runtime API. + * These functions allow integration with other libraries using CUDA. + * + * @{ + */ + +/** + * \brief Retain the primary context on the GPU + * + * Retains the primary context on the device. + * Once the user successfully retains the primary context, the primary context + * will be active and available to the user until the user releases it + * with ::cuDevicePrimaryCtxRelease() or resets it with ::cuDevicePrimaryCtxReset(). + * Unlike ::cuCtxCreate() the newly retained context is not pushed onto the stack. + * + * Retaining the primary context for the first time will fail with ::CUDA_ERROR_UNKNOWN + * if the compute mode of the device is ::CU_COMPUTEMODE_PROHIBITED. The function + * ::cuDeviceGetAttribute() can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to + * determine the compute mode of the device. + * The nvidia-smi tool can be used to set the compute mode for + * devices. Documentation for nvidia-smi can be obtained by passing a + * -h option to it. + * + * Please note that the primary context always supports pinned allocations. Other + * flags can be specified by ::cuDevicePrimaryCtxSetFlags(). + * + * \param pctx - Returned context handle of the new context + * \param dev - Device for which primary context is requested + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRelease, + * ::cuDevicePrimaryCtxSetFlags, + * ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuDevicePrimaryCtxRetain(CUcontext *pctx, CUdevice dev); + +/** + * \brief Release the primary context on the GPU + * + * Releases the primary context interop on the device. + * A retained context should always be released once the user is done using + * it. The context is automatically reset once the last reference to it is + * released. This behavior is different when the primary context was retained + * by the CUDA runtime from CUDA 4.0 and earlier. In this case, the primary + * context remains always active. 
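+ * A minimal retain/use/release sketch (error checking omitted; assumes a
+ * valid device handle \p dev):
+ *
+ * \code
+   CUcontext primaryCtx;
+   cuDevicePrimaryCtxRetain(&primaryCtx, dev);
+   cuCtxPushCurrent(primaryCtx);
+   /* ... issue work against the primary context ... */
+   cuCtxPopCurrent(NULL);
+   cuDevicePrimaryCtxRelease(dev);
+ * \endcode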
+ * + * Releasing a primary context that has not been previously retained will + * fail with ::CUDA_ERROR_INVALID_CONTEXT. + * + * Please note that unlike ::cuCtxDestroy() this method does not pop the context + * from stack in any circumstances. + * + * \param dev - Device which primary context is released + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRetain, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); + +/** + * \brief Set flags for the primary context + * + * Sets the flags for the primary context on the device overwriting perviously + * set ones. + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. + * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. + * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * + * \param dev - Device for which the primary context flags are set + * \param flags - New flags for the device + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRetain, + * ::cuDevicePrimaryCtxGetState, + * ::cuCtxCreate, + * ::cuCtxGetFlags, + * ::cudaSetDeviceFlags + */ +CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); + +/** + * \brief Get the state of the primary context + * + * Returns in \p *flags the flags for the primary context of \p dev, and in + * \p *active whether it is active. See ::cuDevicePrimaryCtxSetFlags for flag + * values. + * + * \param dev - Device to get primary context flags for + * \param flags - Pointer to store flags + * \param active - Pointer to store context state; 0 = inactive, 1 = active + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa + * ::cuDevicePrimaryCtxSetFlags, + * ::cuCtxGetFlags, + * ::cudaGetDeviceFlags + */ +CUresult CUDAAPI cuDevicePrimaryCtxGetState(CUdevice dev, unsigned int *flags, int *active); + +/** + * \brief Destroy all allocations and reset all state on the primary context + * + * Explicitly destroys and cleans up all resources associated with the current + * device in the current process. + * + * Note that it is responsibility of the calling function to ensure that no + * other module in the process is using the device any more. For that reason + * it is recommended to use ::cuDevicePrimaryCtxRelease() in most cases. + * However it is safe for other modules to call ::cuDevicePrimaryCtxRelease() + * even after resetting the device. + * Resetting the primary context does not release it, an application that has + * retained the primary context should explicitly release its usage. + * + * \param dev - Device for which primary context is destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE + * \notefnerr + * + * \sa ::cuDevicePrimaryCtxRetain, + * ::cuDevicePrimaryCtxRelease, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaDeviceReset + */ +CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); + +/** @} */ /* END CUDA_PRIMARY_CTX */ + +/** + * \brief Returns information about the execution affinity support of the device. + * + * Returns in \p *pi whether execution affinity type \p type is supported by device \p dev. 
+ * The supported types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: 1 if context with limited SMs is supported by the device, + * or 0 if not; + * + * \param pi - 1 if the execution affinity type \p type is supported by the device, or 0 if not + * \param type - Execution affinity type to query + * \param dev - Device handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGetAttribute, + * ::cuDeviceGetCount, + * ::cuDeviceGetName, + * ::cuDeviceGetUuid, + * ::cuDeviceGet, + * ::cuDeviceTotalMem + */ +CUresult CUDAAPI cuDeviceGetExecAffinitySupport(int *pi, CUexecAffinityType type, CUdevice dev); + +/** + * \defgroup CUDA_CTX Context Management + * + * ___MANBRIEF___ context management functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the context management functions of the low-level + * CUDA driver application programming interface. + * + * Please note that some functions are described in + * \ref CUDA_PRIMARY_CTX "Primary Context Management" section. + * + * @{ + */ + +/** + * \brief Create a CUDA context + * + * \note In most cases it is recommended to use ::cuDevicePrimaryCtxRetain. + * + * Creates a new CUDA context and associates it with the calling thread. The + * \p flags parameter is described below. The context is created with a usage + * count of 1 and the caller of ::cuCtxCreate() must call ::cuCtxDestroy() or + * when done using the context. If a context is already current to the thread, + * it is supplanted by the newly created context and may be restored by a subsequent + * call to ::cuCtxPopCurrent(). + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. + * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. + * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_MAP_HOST: Instruct CUDA to support mapped pinned allocations. + * This flag must be set in order to allocate pinned host memory that is + * accessible to the GPU. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled + * by this flag is now the default and cannot be disabled. + * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit(). + * + * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of + * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute() + * can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to determine the + * compute mode of the device. The nvidia-smi tool can be used to set + * the compute mode for * devices. + * Documentation for nvidia-smi can be obtained by passing a + * -h option to it. + * + * \param pctx - Returned context handle of the new context + * \param flags - Context creation flags + * \param dev - Device to create context on + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); + +/** + * \brief Create a CUDA context with execution affinity + * + * Creates a new CUDA context with execution affinity and associates it with + * the calling thread. The \p paramsArray and \p flags parameter are described below. + * The context is created with a usage count of 1 and the caller of ::cuCtxCreate() must + * call ::cuCtxDestroy() or when done using the context. If a context is already + * current to the thread, it is supplanted by the newly created context and may + * be restored by a subsequent call to ::cuCtxPopCurrent(). + * + * The type and the amount of execution resource the context can use is limited by \p paramsArray + * and \p numParams. The \p paramsArray is an array of \p CUexecAffinityParam and the \p numParams + * describes the size of the array. If two \p CUexecAffinityParam in the array have the same type, + * the latter execution affinity parameter overrides the former execution affinity parameter. + * The supported execution affinity types are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT limits the portion of SMs that the context can use. The portion + * of SMs is specified as the number of SMs via \p CUexecAffinitySmCount. This limit will be internally + * rounded up to the next hardware-supported amount. Hence, it is imperative to query the actual execution + * affinity of the context via \p cuCtxGetExecAffinity after context creation. Currently, this attribute + * is only supported under Volta+ MPS. + * + * The three LSBs of the \p flags parameter can be used to control how the OS + * thread, which owns the CUDA context at the time of an API call, interacts + * with the OS scheduler when waiting for results from the GPU. Only one of + * the scheduling flags can be set when creating a context. + * + * - ::CU_CTX_SCHED_SPIN: Instruct CUDA to actively spin when waiting for + * results from the GPU. This can decrease latency when waiting for the GPU, + * but may lower the performance of CPU threads if they are performing work in + * parallel with the CUDA thread. + * + * - ::CU_CTX_SCHED_YIELD: Instruct CUDA to yield its thread when waiting for + * results from the GPU. 
This can increase latency when waiting for the GPU, + * but can increase the performance of CPU threads performing work in parallel + * with the GPU. + * + * - ::CU_CTX_SCHED_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work. + * + * - ::CU_CTX_BLOCKING_SYNC: Instruct CUDA to block the CPU thread on a + * synchronization primitive when waiting for the GPU to finish work.
+ * Deprecated: This flag was deprecated as of CUDA 4.0 and was + * replaced with ::CU_CTX_SCHED_BLOCKING_SYNC. + * + * - ::CU_CTX_SCHED_AUTO: The default value if the \p flags parameter is zero, + * uses a heuristic based on the number of active CUDA contexts in the + * process \e C and the number of logical processors in the system \e P. If + * \e C > \e P, then CUDA will yield to other OS threads when waiting for + * the GPU (::CU_CTX_SCHED_YIELD), otherwise CUDA will not yield while + * waiting for results and actively spin on the processor (::CU_CTX_SCHED_SPIN). + * Additionally, on Tegra devices, ::CU_CTX_SCHED_AUTO uses a heuristic based on + * the power profile of the platform and may choose ::CU_CTX_SCHED_BLOCKING_SYNC + * for low-powered devices. + * + * - ::CU_CTX_MAP_HOST: Instruct CUDA to support mapped pinned allocations. + * This flag must be set in order to allocate pinned host memory that is + * accessible to the GPU. + * + * - ::CU_CTX_LMEM_RESIZE_TO_MAX: Instruct CUDA to not reduce local memory + * after resizing local memory for a kernel. This can prevent thrashing by + * local memory allocations when launching many kernels with high local + * memory usage at the cost of potentially increased memory usage.
+ * Deprecated: This flag is deprecated and the behavior enabled
+ * by this flag is now the default and cannot be disabled.
+ * Instead, the per-thread stack size can be controlled with ::cuCtxSetLimit().
+ *
+ * Context creation will fail with ::CUDA_ERROR_UNKNOWN if the compute mode of
+ * the device is ::CU_COMPUTEMODE_PROHIBITED. The function ::cuDeviceGetAttribute()
+ * can be used with ::CU_DEVICE_ATTRIBUTE_COMPUTE_MODE to determine the
+ * compute mode of the device. The nvidia-smi tool can be used to set
+ * the compute mode for devices.
+ * Documentation for nvidia-smi can be obtained by passing a
+ * -h option to it.
+ *
+ * \param pctx - Returned context handle of the new context
+ * \param paramsArray - Execution affinity parameters
+ * \param numParams - Number of execution affinity parameters
+ * \param flags - Context creation flags
+ * \param dev - Device to create context on
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_DEVICE,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_OUT_OF_MEMORY,
+ * ::CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY,
+ * ::CUDA_ERROR_UNKNOWN
+ * \notefnerr
+ *
+ * \sa ::cuCtxDestroy,
+ * ::cuCtxGetApiVersion,
+ * ::cuCtxGetCacheConfig,
+ * ::cuCtxGetDevice,
+ * ::cuCtxGetFlags,
+ * ::cuCtxGetLimit,
+ * ::cuCtxPopCurrent,
+ * ::cuCtxPushCurrent,
+ * ::cuCtxSetCacheConfig,
+ * ::cuCtxSetLimit,
+ * ::cuCtxSynchronize,
+ * ::CUexecAffinityParam
+ */
+CUresult CUDAAPI cuCtxCreate_v3(CUcontext *pctx, CUexecAffinityParam *paramsArray, int numParams, unsigned int flags, CUdevice dev);
+
+/**
+ * \brief Destroy a CUDA context
+ *
+ * Destroys the CUDA context specified by \p ctx. The context \p ctx will be
+ * destroyed regardless of how many threads it is current to.
+ * It is the responsibility of the calling function to ensure that no API
+ * call is issued using \p ctx while ::cuCtxDestroy() is executing.
+ *
+ * Destroys and cleans up all resources associated with the context.
+ * It is the caller's responsibility to ensure that the context or its resources
+ * are not accessed or passed in subsequent API calls and doing so will result in undefined behavior.
+ * These resources include CUDA types such as ::CUmodule, ::CUfunction, ::CUstream, ::CUevent,
+ * ::CUarray, ::CUmipmappedArray, ::CUtexObject, ::CUsurfObject, ::CUtexref, ::CUsurfref,
+ * ::CUgraphicsResource, ::CUlinkState, ::CUexternalMemory and ::CUexternalSemaphore.
+ *
+ * If \p ctx is current to the calling thread then \p ctx will also be
+ * popped from the current thread's context stack (as though ::cuCtxPopCurrent()
+ * were called). If \p ctx is current to other threads, then \p ctx will
+ * remain current to those threads, and attempting to access \p ctx from
+ * those threads will result in the error ::CUDA_ERROR_CONTEXT_IS_DESTROYED.
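+ *
+ * For illustration, a minimal create/use/destroy sketch (assumes the driver
+ * has already been initialized with ::cuInit() and that device 0 exists;
+ * error checking omitted):
+ * \code
+   CUdevice dev;
+   CUcontext ctx;
+   cuDeviceGet(&dev, 0);        // pick device 0
+   cuCtxCreate(&ctx, 0, dev);   // ctx is now current to this thread
+   // ... use the context: load modules, allocate memory, launch work ...
+   cuCtxDestroy(ctx);           // ctx must not be used by any thread afterwards
+ * \endcode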
+ * + * \param ctx - Context to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); + +/** + * \brief Pushes a context on the current CPU thread + * + * Pushes the given context \p ctx onto the CPU thread's stack of current + * contexts. The specified context becomes the CPU thread's current context, so + * all CUDA functions that operate on the current context are affected. + * + * The previous current context may be made current again by calling + * ::cuCtxDestroy() or ::cuCtxPopCurrent(). + * + * \param ctx - Context to push + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx); + +/** + * \brief Pops the current CUDA context from the current CPU thread. + * + * Pops the current CUDA context from the CPU thread and passes back the + * old context handle in \p *pctx. That context may then be made current + * to a different CPU thread by calling ::cuCtxPushCurrent(). + * + * If a context was current to the CPU thread before ::cuCtxCreate() or + * ::cuCtxPushCurrent() was called, this function makes that context current to + * the CPU thread again. + * + * \param pctx - Returned popped context handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx); + +/** + * \brief Binds the specified CUDA context to the calling CPU thread + * + * Binds the specified CUDA context to the calling CPU thread. + * If \p ctx is NULL then the CUDA context previously bound to the + * calling CPU thread is unbound and ::CUDA_SUCCESS is returned. + * + * If there exists a CUDA context stack on the calling CPU thread, this + * will replace the top of that stack with \p ctx. + * If \p ctx is NULL then this will be equivalent to popping the top + * of the calling CPU thread's CUDA context stack (or a no-op if the + * calling CPU thread's CUDA context stack is empty). + * + * \param ctx - Context to bind to the calling CPU thread + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa + * ::cuCtxGetCurrent, + * ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cudaSetDevice + */ +CUresult CUDAAPI cuCtxSetCurrent(CUcontext ctx); + +/** + * \brief Returns the CUDA context bound to the calling CPU thread. 
+ * + * Returns in \p *pctx the CUDA context bound to the calling CPU thread. + * If no context is bound to the calling CPU thread then \p *pctx is + * set to NULL and ::CUDA_SUCCESS is returned. + * + * \param pctx - Returned context handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * \notefnerr + * + * \sa + * ::cuCtxSetCurrent, + * ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cudaGetDevice + */ +CUresult CUDAAPI cuCtxGetCurrent(CUcontext *pctx); + +/** + * \brief Returns the device ID for the current context + * + * Returns in \p *device the ordinal of the current context's device. + * + * \param device - Returned device ID for the current context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaGetDevice + */ +CUresult CUDAAPI cuCtxGetDevice(CUdevice *device); + +/** + * \brief Returns the flags for the current context + * + * Returns in \p *flags the flags of the current context. See ::cuCtxCreate + * for flag values. + * + * \param flags - Pointer to store flags of current context + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetCurrent, + * ::cuCtxGetDevice, + * ::cuCtxGetLimit, + * ::cuCtxGetSharedMemConfig, + * ::cuCtxGetStreamPriorityRange, + * ::cudaGetDeviceFlags + */ +CUresult CUDAAPI cuCtxGetFlags(unsigned int *flags); + +/** + * \brief Block for a context's tasks to complete + * + * Blocks until the device has completed all preceding requested tasks. + * ::cuCtxSynchronize() returns an error if one of the preceding tasks failed. + * If the context was created with the ::CU_CTX_SCHED_BLOCKING_SYNC flag, the + * CPU thread will block until the GPU context has finished its work. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cudaDeviceSynchronize + */ +CUresult CUDAAPI cuCtxSynchronize(void); + +/** + * \brief Set resource limits + * + * Setting \p limit to \p value is a request by the application to update + * the current limit maintained by the context. The driver is free to + * modify the requested value to meet h/w requirements (this could be + * clamping to minimum or maximum values, rounding up to nearest element + * size, etc). The application can use ::cuCtxGetLimit() to find out exactly + * what the limit has been set to. + * + * Setting each ::CUlimit has its own specific restrictions, so each is + * discussed here. + * + * - ::CU_LIMIT_STACK_SIZE controls the stack size in bytes of each GPU thread. + * The driver automatically increases the per-thread stack size + * for each kernel launch as needed. 
This size isn't reset back to the
+ * original value after each launch. Setting this value will take effect
+ * immediately, and if necessary, the device will block until all preceding
+ * requested tasks are complete.
+ *
+ * - ::CU_LIMIT_PRINTF_FIFO_SIZE controls the size in bytes of the FIFO used
+ * by the ::printf() device system call. Setting ::CU_LIMIT_PRINTF_FIFO_SIZE
+ * must be performed before launching any kernel that uses the ::printf()
+ * device system call, otherwise ::CUDA_ERROR_INVALID_VALUE will be returned.
+ *
+ * - ::CU_LIMIT_MALLOC_HEAP_SIZE controls the size in bytes of the heap used
+ * by the ::malloc() and ::free() device system calls. Setting
+ * ::CU_LIMIT_MALLOC_HEAP_SIZE must be performed before launching any kernel
+ * that uses the ::malloc() or ::free() device system calls, otherwise
+ * ::CUDA_ERROR_INVALID_VALUE will be returned.
+ *
+ * - ::CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH controls the maximum nesting depth of
+ * a grid at which a thread can safely call ::cudaDeviceSynchronize(). Setting
+ * this limit must be performed before any launch of a kernel that uses the
+ * device runtime and calls ::cudaDeviceSynchronize() above the default sync
+ * depth, two levels of grids. Calls to ::cudaDeviceSynchronize() will fail
+ * with error code ::cudaErrorSyncDepthExceeded if the limitation is
+ * violated. This limit can be set smaller than the default or up to the maximum
+ * launch depth of 24. When setting this limit, keep in mind that additional
+ * levels of sync depth require the driver to reserve large amounts of device
+ * memory which can no longer be used for user allocations. If these
+ * reservations of device memory fail, ::cuCtxSetLimit() will return
+ * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value.
+ * This limit is only applicable to devices of compute capability 3.5 and
+ * higher. Attempting to set this limit on devices of compute capability less
+ * than 3.5 will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being
+ * returned.
+ *
+ * - ::CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT controls the maximum number of
+ * outstanding device runtime launches that can be made from the current
+ * context. A grid is outstanding from the point of launch up until the grid
+ * is known to have been completed. Device runtime launches which violate
+ * this limitation fail and return ::cudaErrorLaunchPendingCountExceeded when
+ * ::cudaGetLastError() is called after launch. If more pending launches than
+ * the default (2048 launches) are needed for a module using the device
+ * runtime, this limit can be increased. Keep in mind that being able to
+ * sustain additional pending launches will require the driver to reserve
+ * larger amounts of device memory upfront which can no longer be used for
+ * allocations. If these reservations fail, ::cuCtxSetLimit() will return
+ * ::CUDA_ERROR_OUT_OF_MEMORY, and the limit can be reset to a lower value.
+ * This limit is only applicable to devices of compute capability 3.5 and
+ * higher. Attempting to set this limit on devices of compute capability less
+ * than 3.5 will result in the error ::CUDA_ERROR_UNSUPPORTED_LIMIT being
+ * returned.
+ *
+ * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY controls the L2 cache fetch granularity.
+ * Values can range from 0B to 128B. This is purely a performance hint and
+ * it can be ignored or clamped depending on the platform.
+ *
+ * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE controls size in bytes available for
+ * persisting L2 cache.
This is purely a performance hint and it can be + * ignored or clamped depending on the platform. + * + * \param limit - Limit to set + * \param value - Size of limit + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_LIMIT, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSynchronize, + * ::cudaDeviceSetLimit + */ +CUresult CUDAAPI cuCtxSetLimit(CUlimit limit, size_t value); + +/** + * \brief Returns resource limits + * + * Returns in \p *pvalue the current size of \p limit. The supported + * ::CUlimit values are: + * - ::CU_LIMIT_STACK_SIZE: stack size in bytes of each GPU thread. + * - ::CU_LIMIT_PRINTF_FIFO_SIZE: size in bytes of the FIFO used by the + * ::printf() device system call. + * - ::CU_LIMIT_MALLOC_HEAP_SIZE: size in bytes of the heap used by the + * ::malloc() and ::free() device system calls. + * - ::CU_LIMIT_DEV_RUNTIME_SYNC_DEPTH: maximum grid depth at which a thread + * can issue the device runtime call ::cudaDeviceSynchronize() to wait on + * child grid launches to complete. + * - ::CU_LIMIT_DEV_RUNTIME_PENDING_LAUNCH_COUNT: maximum number of outstanding + * device runtime launches that can be made from this context. + * - ::CU_LIMIT_MAX_L2_FETCH_GRANULARITY: L2 cache fetch granularity. + * - ::CU_LIMIT_PERSISTING_L2_CACHE_SIZE: Persisting L2 cache size in bytes + * + * \param limit - Limit to query + * \param pvalue - Returned size of limit + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_LIMIT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaDeviceGetLimit + */ +CUresult CUDAAPI cuCtxGetLimit(size_t *pvalue, CUlimit limit); + +/** + * \brief Returns the preferred cache configuration for the current context. + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this function returns through \p pconfig the preferred cache configuration + * for the current context. This is only a preference. The driver will use + * the requested configuration if possible, but it is free to choose a different + * configuration if required to execute functions. + * + * This will return a \p pconfig of ::CU_FUNC_CACHE_PREFER_NONE on devices + * where the size of the L1 cache and shared memory are fixed. 
+ * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \param pconfig - Returned cache configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuFuncSetCacheConfig, + * ::cudaDeviceGetCacheConfig + */ +CUresult CUDAAPI cuCtxGetCacheConfig(CUfunc_cache *pconfig); + +/** + * \brief Sets the preferred cache configuration for the current context. + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets through \p config the preferred cache configuration for + * the current context. This is only a preference. The driver will use + * the requested configuration if possible, but it is free to choose a different + * configuration if required to execute the function. Any function preference + * set via ::cuFuncSetCacheConfig() will be preferred over this context-wide + * setting. Setting the context-wide cache configuration to + * ::CU_FUNC_CACHE_PREFER_NONE will cause subsequent kernel launches to prefer + * to not change the cache configuration unless required to launch the kernel. + * + * This setting does nothing on devices where the size of the L1 cache and + * shared memory are fixed. + * + * Launching a kernel with a different preference than the most recent + * preference setting may insert a device-side synchronization point. + * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \param config - Requested cache configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cuFuncSetCacheConfig, + * ::cudaDeviceSetCacheConfig + */ +CUresult CUDAAPI cuCtxSetCacheConfig(CUfunc_cache config); + +/** + * \brief Returns the current shared memory configuration for the current context. + * + * This function will return in \p pConfig the current size of shared memory banks + * in the current context. On devices with configurable shared memory banks, + * ::cuCtxSetSharedMemConfig can be used to change this setting, so that all + * subsequent kernel launches will by default use the new bank size. 
When
+ * ::cuCtxGetSharedMemConfig is called on devices without configurable shared
+ * memory, it will return the fixed bank size of the hardware.
+ *
+ * The returned bank configurations can be either:
+ * - ::CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE: shared memory bank width is
+ * four bytes.
+ * - ::CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE: shared memory bank width will
+ * be eight bytes.
+ *
+ * \param pConfig - returned shared memory configuration
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ *
+ * \sa ::cuCtxCreate,
+ * ::cuCtxDestroy,
+ * ::cuCtxGetApiVersion,
+ * ::cuCtxGetCacheConfig,
+ * ::cuCtxGetDevice,
+ * ::cuCtxGetFlags,
+ * ::cuCtxGetLimit,
+ * ::cuCtxPopCurrent,
+ * ::cuCtxPushCurrent,
+ * ::cuCtxSetLimit,
+ * ::cuCtxSynchronize,
+ * ::cuCtxGetSharedMemConfig,
+ * ::cuFuncSetCacheConfig,
+ * ::cudaDeviceGetSharedMemConfig
+ */
+CUresult CUDAAPI cuCtxGetSharedMemConfig(CUsharedconfig *pConfig);
+
+/**
+ * \brief Sets the shared memory configuration for the current context.
+ *
+ * On devices with configurable shared memory banks, this function will set
+ * the context's shared memory bank size which is used for subsequent kernel
+ * launches.
+ *
+ * Changing the shared memory configuration between launches may insert a device
+ * side synchronization point between those launches.
+ *
+ * Changing the shared memory bank size will not increase shared memory usage
+ * or affect occupancy of kernels, but may have major effects on performance.
+ * Larger bank sizes will allow for greater potential bandwidth to shared memory,
+ * but will change what kinds of accesses to shared memory will result in bank
+ * conflicts.
+ *
+ * This function will do nothing on devices with fixed shared memory bank size.
+ *
+ * The supported bank configurations are:
+ * - ::CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE: set bank width to the default initial
+ * setting (currently, four bytes).
+ * - ::CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE: set shared memory bank width to
+ * be natively four bytes.
+ * - ::CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE: set shared memory bank width to
+ * be natively eight bytes.
+ *
+ * \param config - requested shared memory configuration
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ *
+ * \sa ::cuCtxCreate,
+ * ::cuCtxDestroy,
+ * ::cuCtxGetApiVersion,
+ * ::cuCtxGetCacheConfig,
+ * ::cuCtxGetDevice,
+ * ::cuCtxGetFlags,
+ * ::cuCtxGetLimit,
+ * ::cuCtxPopCurrent,
+ * ::cuCtxPushCurrent,
+ * ::cuCtxSetLimit,
+ * ::cuCtxSynchronize,
+ * ::cuCtxGetSharedMemConfig,
+ * ::cuFuncSetCacheConfig,
+ * ::cudaDeviceSetSharedMemConfig
+ */
+CUresult CUDAAPI cuCtxSetSharedMemConfig(CUsharedconfig config);
+
+/**
+ * \brief Gets the context's API version.
+ *
+ * Returns a version number in \p version corresponding to the capabilities of
+ * the context (e.g. 3010 or 3020), which library developers can use to direct
+ * callers to a specific API version. If \p ctx is NULL, returns the API version
+ * used to create the currently bound context.
+ *
+ * Note that new API versions are only introduced when context capabilities are
+ * changed that break binary compatibility, so the API version and driver version
+ * may be different.
For example, it is valid for the API version to be 3020 while + * the driver version is 4020. + * + * \param ctx - Context to check + * \param version - Pointer to version + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +CUresult CUDAAPI cuCtxGetApiVersion(CUcontext ctx, unsigned int *version); + +/** + * \brief Returns numerical values that correspond to the least and + * greatest stream priorities. + * + * Returns in \p *leastPriority and \p *greatestPriority the numerical values that correspond + * to the least and greatest stream priorities respectively. Stream priorities + * follow a convention where lower numbers imply greater priorities. The range of + * meaningful stream priorities is given by [\p *greatestPriority, \p *leastPriority]. + * If the user attempts to create a stream with a priority value that is + * outside the meaningful range as specified by this API, the priority is + * automatically clamped down or up to either \p *leastPriority or \p *greatestPriority + * respectively. See ::cuStreamCreateWithPriority for details on creating a + * priority stream. + * A NULL may be passed in for \p *leastPriority or \p *greatestPriority if the value + * is not desired. + * + * This function will return '0' in both \p *leastPriority and \p *greatestPriority if + * the current context's device does not support stream priorities + * (see ::cuDeviceGetAttribute). + * + * \param leastPriority - Pointer to an int in which the numerical value for least + * stream priority is returned + * \param greatestPriority - Pointer to an int in which the numerical value for greatest + * stream priority is returned + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \notefnerr + * + * \sa ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize, + * ::cudaDeviceGetStreamPriorityRange + */ +CUresult CUDAAPI cuCtxGetStreamPriorityRange(int *leastPriority, int *greatestPriority); + +/** + * \brief Resets all persisting lines in cache to normal status. + * + * ::cuCtxResetPersistingL2Cache Resets all persisting lines in cache to normal + * status. Takes effect on function return. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuCtxResetPersistingL2Cache(void); + +/** + * \brief Returns the execution affinity setting for the current context. + * + * Returns in \p *pExecAffinity the current value of \p type. The supported + * ::CUexecAffinityType values are: + * - ::CU_EXEC_AFFINITY_TYPE_SM_COUNT: number of SMs the context is limited to use. 
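+ *
+ * For illustration, a sketch that queries the SM count actually granted to
+ * the current context (error checking omitted; assumes a context created
+ * with SM-count affinity, e.g. under Volta+ MPS):
+ * \code
+   CUexecAffinityParam affinity;
+   cuCtxGetExecAffinity(&affinity, CU_EXEC_AFFINITY_TYPE_SM_COUNT);
+   // affinity.param.smCount.val holds the number of SMs the context may use
+ * \endcode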
+ * + * \param type - Execution affinity type to query + * \param pExecAffinity - Returned execution affinity + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNSUPPORTED_EXEC_AFFINITY + * \notefnerr + * + * \sa + * ::CUexecAffinityParam + */ +CUresult CUDAAPI cuCtxGetExecAffinity(CUexecAffinityParam *pExecAffinity, CUexecAffinityType type); + + +/** @} */ /* END CUDA_CTX */ + +/** + * \defgroup CUDA_CTX_DEPRECATED Context Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated context management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated context management functions of the low-level + * CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Increment a context's usage-count + * + * \deprecated + * + * Note that this function is deprecated and should not be used. + * + * Increments the usage count of the context and passes back a context handle + * in \p *pctx that must be passed to ::cuCtxDetach() when the application is + * done with the context. ::cuCtxAttach() fails if there is no context current + * to the thread. + * + * Currently, the \p flags parameter must be 0. + * + * \param pctx - Returned context handle of the current context + * \param flags - Context attach flags (must be 0) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxDetach, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxAttach(CUcontext *pctx, unsigned int flags); + +/** + * \brief Decrement a context's usage-count + * + * \deprecated + * + * Note that this function is deprecated and should not be used. + * + * Decrements the usage count of the context \p ctx, and destroys the context + * if the usage count goes to 0. The context must be a handle that was passed + * back by ::cuCtxCreate() or ::cuCtxAttach(), and must be current to the + * calling thread. + * + * \param ctx - Context to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxCreate, + * ::cuCtxDestroy, + * ::cuCtxGetApiVersion, + * ::cuCtxGetCacheConfig, + * ::cuCtxGetDevice, + * ::cuCtxGetFlags, + * ::cuCtxGetLimit, + * ::cuCtxPopCurrent, + * ::cuCtxPushCurrent, + * ::cuCtxSetCacheConfig, + * ::cuCtxSetLimit, + * ::cuCtxSynchronize + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuCtxDetach(CUcontext ctx); + +/** @} */ /* END CUDA_CTX_DEPRECATED */ + + +/** + * \defgroup CUDA_MODULE Module Management + * + * ___MANBRIEF___ module management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the module management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Loads a compute module + * + * Takes a filename \p fname and loads the corresponding module \p module into + * the current context. 
The CUDA driver API does not attempt to lazily + * allocate the resources needed by a module; if the memory for functions and + * data (constant and global) needed by the module cannot be allocated, + * ::cuModuleLoad() fails. The file should be a \e cubin file as output by + * \b nvcc, or a \e PTX file either as output by \b nvcc or handwritten, or + * a \e fatbin file as output by \b nvcc from toolchain 4.0 or later. + * + * \param module - Returned module + * \param fname - Filename of module to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_FILE_NOT_FOUND, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoad(CUmodule *module, const char *fname); + +/** + * \brief Load a module's data + * + * Takes a pointer \p image and loads the corresponding module \p module into + * the current context. The pointer may be obtained by mapping a \e cubin or + * \e PTX or \e fatbin file, passing a \e cubin or \e PTX or \e fatbin file + * as a NULL-terminated text string, or incorporating a \e cubin or \e fatbin + * object into the executable resources and using operating system calls such + * as Windows \c FindResource() to obtain the pointer. + * + * \param module - Returned module + * \param image - Module data to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoadData(CUmodule *module, const void *image); + +/** + * \brief Load a module's data with options + * + * Takes a pointer \p image and loads the corresponding module \p module into + * the current context. The pointer may be obtained by mapping a \e cubin or + * \e PTX or \e fatbin file, passing a \e cubin or \e PTX or \e fatbin file + * as a NULL-terminated text string, or incorporating a \e cubin or \e fatbin + * object into the executable resources and using operating system calls such + * as Windows \c FindResource() to obtain the pointer. Options are passed as + * an array via \p options and any corresponding parameters are passed in + * \p optionValues. The number of total options is supplied via \p numOptions. + * Any outputs will be returned via \p optionValues. 
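+ *
+ * For illustration, a sketch that captures the JIT error log while loading
+ * PTX (\c ptxSource is a placeholder for module data supplied by the
+ * application; further error handling omitted):
+ * \code
+   char errorLog[8192];
+   CUjit_option options[] = { CU_JIT_ERROR_LOG_BUFFER,
+                              CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES };
+   void *optionValues[]   = { errorLog,
+                              (void *)(size_t)sizeof(errorLog) };
+   CUmodule module;
+   if (cuModuleLoadDataEx(&module, ptxSource, 2, options, optionValues) != CUDA_SUCCESS)
+       fprintf(stderr, "JIT failed: %s\n", errorLog);
+ * \endcode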
+ * + * \param module - Returned module + * \param image - Module data to load + * \param numOptions - Number of options + * \param options - Options for JIT + * \param optionValues - Option values for JIT + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoadDataEx(CUmodule *module, const void *image, unsigned int numOptions, CUjit_option *options, void **optionValues); + +/** + * \brief Load a module's data + * + * Takes a pointer \p fatCubin and loads the corresponding module \p module + * into the current context. The pointer represents a fat binary object, + * which is a collection of different \e cubin and/or \e PTX files, all + * representing the same device code, but compiled and optimized for different + * architectures. + * + * Prior to CUDA 4.0, there was no documented API for constructing and using + * fat binary objects by programmers. Starting with CUDA 4.0, fat binary + * objects can be constructed by providing the -fatbin option to \b nvcc. + * More information can be found in the \b nvcc document. + * + * \param module - Returned module + * \param fatCubin - Fat binary to load + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU, + * ::CUDA_ERROR_SHARED_OBJECT_SYMBOL_NOT_FOUND, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleLoadFatBinary(CUmodule *module, const void *fatCubin); + +/** + * \brief Unloads a module + * + * Unloads a module \p hmod from the current context. + * + * \param hmod - Module to unload + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_destroy_ub + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary + */ +CUresult CUDAAPI cuModuleUnload(CUmodule hmod); + +/** + * \brief Returns a function handle + * + * Returns in \p *hfunc the handle of the function of name \p name located in + * module \p hmod. If no function of that name exists, ::cuModuleGetFunction() + * returns ::CUDA_ERROR_NOT_FOUND. 
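+ *
+ * For example, a sketch that loads a module and retrieves a kernel handle
+ * ("my_module.cubin" and "my_kernel" are placeholder names; error checking
+ * omitted):
+ * \code
+   CUmodule module;
+   CUfunction kernel;
+   cuModuleLoad(&module, "my_module.cubin");
+   cuModuleGetFunction(&kernel, module, "my_kernel");
+   // kernel can now be passed to ::cuLaunchKernel
+ * \endcode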
+ * + * \param hfunc - Returned function handle + * \param hmod - Module to retrieve function from + * \param name - Name of function to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload + */ +CUresult CUDAAPI cuModuleGetFunction(CUfunction *hfunc, CUmodule hmod, const char *name); + +/** + * \brief Returns a global pointer from a module + * + * Returns in \p *dptr and \p *bytes the base pointer and size of the + * global of name \p name located in module \p hmod. If no variable of that name + * exists, ::cuModuleGetGlobal() returns ::CUDA_ERROR_NOT_FOUND. Both + * parameters \p dptr and \p bytes are optional. If one of them is + * NULL, it is ignored. + * + * \param dptr - Returned global device pointer + * \param bytes - Returned global size in bytes + * \param hmod - Module to retrieve global from + * \param name - Name of global to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload, + * ::cudaGetSymbolAddress, + * ::cudaGetSymbolSize + */ +CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr *dptr, size_t *bytes, CUmodule hmod, const char *name); + +/** + * \brief Returns a handle to a texture reference + * + * Returns in \p *pTexRef the handle of the texture reference of name \p name + * in the module \p hmod. If no texture reference of that name exists, + * ::cuModuleGetTexRef() returns ::CUDA_ERROR_NOT_FOUND. This texture reference + * handle should not be destroyed, since it will be destroyed when the module + * is unloaded. + * + * \param pTexRef - Returned texture reference + * \param hmod - Module to retrieve texture reference from + * \param name - Name of texture reference to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetSurfRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload, + * ::cudaGetTextureReference + */ +CUresult CUDAAPI cuModuleGetTexRef(CUtexref *pTexRef, CUmodule hmod, const char *name); + +/** + * \brief Returns a handle to a surface reference + * + * Returns in \p *pSurfRef the handle of the surface reference of name \p name + * in the module \p hmod. If no surface reference of that name exists, + * ::cuModuleGetSurfRef() returns ::CUDA_ERROR_NOT_FOUND. 
+ * + * \param pSurfRef - Returned surface reference + * \param hmod - Module to retrieve surface reference from + * \param name - Name of surface reference to retrieve + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + * \sa ::cuModuleGetFunction, + * ::cuModuleGetGlobal, + * ::cuModuleGetTexRef, + * ::cuModuleLoad, + * ::cuModuleLoadData, + * ::cuModuleLoadDataEx, + * ::cuModuleLoadFatBinary, + * ::cuModuleUnload, + * ::cudaGetSurfaceReference + */ +CUresult CUDAAPI cuModuleGetSurfRef(CUsurfref *pSurfRef, CUmodule hmod, const char *name); + +/** + * \brief Creates a pending JIT linker invocation. + * + * If the call is successful, the caller owns the returned CUlinkState, which + * should eventually be destroyed with ::cuLinkDestroy. The + * device code machine size (32 or 64 bit) will match the calling application. + * + * Both linker and compiler options may be specified. Compiler options will + * be applied to inputs to this linker action which must be compiled from PTX. + * The options ::CU_JIT_WALL_TIME, + * ::CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES, and ::CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES + * will accumulate data until the CUlinkState is destroyed. + * + * \p optionValues must remain valid for the life of the CUlinkState if output + * options are used. No other references to inputs are maintained after this + * call returns. + * + * \param numOptions Size of options arrays + * \param options Array of linker and compiler options + * \param optionValues Array of option values, each cast to void * + * \param stateOut On success, this will contain a CUlinkState to specify + * and complete this action + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_JIT_COMPILER_NOT_FOUND + * \notefnerr + * + * \sa ::cuLinkAddData, + * ::cuLinkAddFile, + * ::cuLinkComplete, + * ::cuLinkDestroy + */ +CUresult CUDAAPI +cuLinkCreate(unsigned int numOptions, CUjit_option *options, void **optionValues, CUlinkState *stateOut); + +/** + * \brief Add an input to a pending linker invocation + * + * Ownership of \p data is retained by the caller. No reference is retained to any + * inputs after this call returns. + * + * This method accepts only compiler options, which are used if the data must + * be compiled from PTX, and does not accept any of + * ::CU_JIT_WALL_TIME, ::CU_JIT_INFO_LOG_BUFFER, ::CU_JIT_ERROR_LOG_BUFFER, + * ::CU_JIT_TARGET_FROM_CUCONTEXT, or ::CU_JIT_TARGET. + * + * \param state A pending linker action. + * \param type The type of the input data. + * \param data The input data. PTX must be NULL-terminated. + * \param size The length of the input data. + * \param name An optional name for this input in log messages. + * \param numOptions Size of options. + * \param options Options to be applied only for this input (overrides options from ::cuLinkCreate). + * \param optionValues Array of option values, each cast to void *. 
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU + * + * \sa ::cuLinkCreate, + * ::cuLinkAddFile, + * ::cuLinkComplete, + * ::cuLinkDestroy + */ +CUresult CUDAAPI +cuLinkAddData(CUlinkState state, CUjitInputType type, void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, void **optionValues); + +/** + * \brief Add a file input to a pending linker invocation + * + * No reference is retained to any inputs after this call returns. + * + * This method accepts only compiler options, which are used if the input + * must be compiled from PTX, and does not accept any of + * ::CU_JIT_WALL_TIME, ::CU_JIT_INFO_LOG_BUFFER, ::CU_JIT_ERROR_LOG_BUFFER, + * ::CU_JIT_TARGET_FROM_CUCONTEXT, or ::CU_JIT_TARGET. + * + * This method is equivalent to invoking ::cuLinkAddData on the contents + * of the file. + * + * \param state A pending linker action + * \param type The type of the input data + * \param path Path to the input file + * \param numOptions Size of options + * \param options Options to be applied only for this input (overrides options from ::cuLinkCreate) + * \param optionValues Array of option values, each cast to void * + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_FILE_NOT_FOUND + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_PTX, + * ::CUDA_ERROR_UNSUPPORTED_PTX_VERSION, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NO_BINARY_FOR_GPU + * + * \sa ::cuLinkCreate, + * ::cuLinkAddData, + * ::cuLinkComplete, + * ::cuLinkDestroy + */ +CUresult CUDAAPI +cuLinkAddFile(CUlinkState state, CUjitInputType type, const char *path, + unsigned int numOptions, CUjit_option *options, void **optionValues); + +/** + * \brief Complete a pending linker invocation + * + * Completes the pending linker action and returns the cubin image for the linked + * device code, which can be used with ::cuModuleLoadData. The cubin is owned by + * \p state, so it should be loaded before \p state is destroyed via ::cuLinkDestroy. + * This call does not destroy \p state. + * + * \param state A pending linker invocation + * \param cubinOut On success, this will point to the output image + * \param sizeOut Optional parameter to receive the size of the generated image + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuLinkCreate, + * ::cuLinkAddData, + * ::cuLinkAddFile, + * ::cuLinkDestroy, + * ::cuModuleLoadData + */ +CUresult CUDAAPI +cuLinkComplete(CUlinkState state, void **cubinOut, size_t *sizeOut); + +/** + * \brief Destroys state for a JIT linker invocation. + * + * \param state State object for the linker invocation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \sa ::cuLinkCreate + */ +CUresult CUDAAPI +cuLinkDestroy(CUlinkState state); + +/** @} */ /* END CUDA_MODULE */ + + +/** + * \defgroup CUDA_MEM Memory Management + * + * ___MANBRIEF___ memory management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the memory management functions of the low-level CUDA + * driver application programming interface. 
+ *
+ * @{
+ */
+
+/**
+ * \brief Gets free and total memory
+ *
+ * Returns in \p *total the total amount of memory available to the current context.
+ * Returns in \p *free the amount of memory on the device that is free according to the OS.
+ * CUDA is not guaranteed to be able to allocate all of the memory that the OS reports as free.
+ * In a multi-tenant situation, the free estimate returned is prone to a race condition where
+ * a new allocation/free done by a different process or a different thread in the same
+ * process, between the time when free memory was estimated and reported, will result in
+ * a deviation between the reported free value and the actual free memory.
+ *
+ * The integrated GPU on Tegra shares memory with the CPU and other components
+ * of the SoC. The free and total values returned by the API exclude
+ * the SWAP memory space maintained by the OS on some platforms.
+ * The OS may move some of the memory pages into the swap area as the GPU or
+ * CPU allocate or access memory. See the Tegra app note on how to calculate
+ * total and free memory on Tegra.
+ *
+ * \param free - Returned free memory in bytes
+ * \param total - Returned total memory in bytes
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ *
+ * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate,
+ * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost,
+ * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned,
+ * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD,
+ * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync,
+ * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync,
+ * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost,
+ * ::cuMemGetAddressRange, ::cuMemHostAlloc,
+ * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16,
+ * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32,
+ * ::cudaMemGetInfo
+ */
+CUresult CUDAAPI cuMemGetInfo(size_t *free, size_t *total);
+
+/**
+ * \brief Allocates device memory
+ *
+ * Allocates \p bytesize bytes of linear memory on the device and returns in
+ * \p *dptr a pointer to the allocated memory. The allocated memory is suitably
+ * aligned for any kind of variable. The memory is not cleared. If \p bytesize
+ * is 0, ::cuMemAlloc() returns ::CUDA_ERROR_INVALID_VALUE.
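+ *
+ * A minimal allocation sketch (error checking omitted):
+ * \code
+   CUdeviceptr dptr;
+   cuMemAlloc(&dptr, 1024 * sizeof(float));   // room for 1024 floats
+   // ... use dptr, e.g. with ::cuMemcpyHtoD / ::cuMemcpyDtoH ...
+   cuMemFree(dptr);
+ * \endcode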
+ * + * \param dptr - Returned device pointer + * \param bytesize - Requested allocation size in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMalloc + */ +CUresult CUDAAPI cuMemAlloc(CUdeviceptr *dptr, size_t bytesize); + +/** + * \brief Allocates pitched device memory + * + * Allocates at least \p WidthInBytes * \p Height bytes of linear memory on + * the device and returns in \p *dptr a pointer to the allocated memory. The + * function may pad the allocation to ensure that corresponding pointers in + * any given row will continue to meet the alignment requirements for + * coalescing as the address is updated from row to row. \p ElementSizeBytes + * specifies the size of the largest reads and writes that will be performed + * on the memory range. \p ElementSizeBytes may be 4, 8 or 16 (since coalesced + * memory transactions are not possible on other data sizes). If + * \p ElementSizeBytes is smaller than the actual read/write size of a kernel, + * the kernel will run correctly, but possibly at reduced speed. The pitch + * returned in \p *pPitch by ::cuMemAllocPitch() is the width in bytes of the + * allocation. The intended usage of pitch is as a separate parameter of the + * allocation, used to compute addresses within the 2D array. Given the row + * and column of an array element of type \b T, the address is computed as: + * \code + T* pElement = (T*)((char*)BaseAddress + Row * Pitch) + Column; + * \endcode + * + * The pitch returned by ::cuMemAllocPitch() is guaranteed to work with + * ::cuMemcpy2D() under all circumstances. For allocations of 2D arrays, it is + * recommended that programmers consider performing pitch allocations using + * ::cuMemAllocPitch(). Due to alignment restrictions in the hardware, this is + * especially true if the application will be performing 2D memory copies + * between different regions of device memory (whether linear memory or CUDA + * arrays). + * + * The byte alignment of the pitch returned by ::cuMemAllocPitch() is guaranteed + * to match or exceed the alignment requirement for texture binding with + * ::cuTexRefSetAddress2D(). 
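+ *
+ * For illustration, a pitched allocation for a \c width x \c height array of
+ * floats (\c width and \c height are placeholders; error checking omitted):
+ * \code
+   CUdeviceptr dptr;
+   size_t pitch;
+   cuMemAllocPitch(&dptr, &pitch, width * sizeof(float), height, sizeof(float));
+   // row i starts at dptr + i * pitch
+ * \endcode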
+ * + * \param dptr - Returned device pointer + * \param pPitch - Returned pitch of allocation in bytes + * \param WidthInBytes - Requested allocation width in bytes + * \param Height - Requested allocation height in rows + * \param ElementSizeBytes - Size of largest reads/writes for range + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMallocPitch + */ +CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr *dptr, size_t *pPitch, size_t WidthInBytes, size_t Height, unsigned int ElementSizeBytes); + +/** + * \brief Frees device memory + * + * Frees the memory space pointed to by \p dptr, which must have been returned + * by a previous call to ::cuMemAlloc() or ::cuMemAllocPitch(). + * + * \param dptr - Pointer to memory to free + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaFree + */ +CUresult CUDAAPI cuMemFree(CUdeviceptr dptr); + +/** + * \brief Get information on memory allocations + * + * Returns the base address in \p *pbase and size in \p *psize of the + * allocation by ::cuMemAlloc() or ::cuMemAllocPitch() that contains the input + * pointer \p dptr. Both parameters \p pbase and \p psize are optional. If one + * of them is NULL, it is ignored. 
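+ *
+ * For illustration (\c dptr is assumed to point somewhere inside an
+ * allocation made with ::cuMemAlloc() or ::cuMemAllocPitch()):
+ * \code
+   CUdeviceptr base;
+   size_t size;
+   cuMemGetAddressRange(&base, &size, dptr);
+   // the allocation containing dptr spans [base, base + size)
+ * \endcode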
+ * + * \param pbase - Returned base address + * \param psize - Returned size of device memory allocation + * \param dptr - Device pointer to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_NOT_FOUND, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32 + */ +CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr *pbase, size_t *psize, CUdeviceptr dptr); + +/** + * \brief Allocates page-locked host memory + * + * Allocates \p bytesize bytes of host memory that is page-locked and + * accessible to the device. The driver tracks the virtual memory ranges + * allocated with this function and automatically accelerates calls to + * functions such as ::cuMemcpy(). Since the memory can be accessed directly by + * the device, it can be read or written with much higher bandwidth than + * pageable memory obtained with functions such as ::malloc(). Allocating + * excessive amounts of memory with ::cuMemAllocHost() may degrade system + * performance, since it reduces the amount of memory available to the system + * for paging. As a result, this function is best used sparingly to allocate + * staging areas for data exchange between host and device. + * + * Note all host memory allocated using ::cuMemHostAlloc() will automatically + * be immediately accessible to all contexts on all devices which support unified + * addressing (as may be queried using ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING). + * The device pointer that may be used to access this host memory from those + * contexts is always equal to the returned host pointer \p *pp. + * See \ref CUDA_UNIFIED for additional details. 
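+ *
+ * A minimal staging-buffer sketch (error checking omitted):
+ * \code
+   float *hostBuf;
+   cuMemAllocHost((void **)&hostBuf, 1024 * sizeof(float));
+   // ... fill hostBuf, then copy to the device with ::cuMemcpyHtoD ...
+   cuMemFreeHost(hostBuf);
+ * \endcode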
+ * + * \param pp - Returned host pointer to page-locked memory + * \param bytesize - Requested allocation size in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMallocHost + */ +CUresult CUDAAPI cuMemAllocHost(void **pp, size_t bytesize); + +/** + * \brief Frees page-locked host memory + * + * Frees the memory space pointed to by \p p, which must have been returned by + * a previous call to ::cuMemAllocHost(). + * + * \param p - Pointer to memory to free + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaFreeHost + */ +CUresult CUDAAPI cuMemFreeHost(void *p); + +/** + * \brief Allocates page-locked host memory + * + * Allocates \p bytesize bytes of host memory that is page-locked and accessible + * to the device. The driver tracks the virtual memory ranges allocated with + * this function and automatically accelerates calls to functions such as + * ::cuMemcpyHtoD(). Since the memory can be accessed directly by the device, + * it can be read or written with much higher bandwidth than pageable memory + * obtained with functions such as ::malloc(). Allocating excessive amounts of + * pinned memory may degrade system performance, since it reduces the amount + * of memory available to the system for paging. As a result, this function is + * best used sparingly to allocate staging areas for data exchange between + * host and device. + * + * The \p Flags parameter enables different options to be specified that + * affect the allocation, as follows. + * + * - ::CU_MEMHOSTALLOC_PORTABLE: The memory returned by this call will be + * considered as pinned memory by all CUDA contexts, not just the one that + * performed the allocation. + * + * - ::CU_MEMHOSTALLOC_DEVICEMAP: Maps the allocation into the CUDA address + * space. 
The device pointer to the memory may be obtained by calling + * ::cuMemHostGetDevicePointer(). + * + * - ::CU_MEMHOSTALLOC_WRITECOMBINED: Allocates the memory as write-combined + * (WC). WC memory can be transferred across the PCI Express bus more + * quickly on some system configurations, but cannot be read efficiently by + * most CPUs. WC memory is a good option for buffers that will be written by + * the CPU and read by the GPU via mapped pinned memory or host->device + * transfers. + * + * All of these flags are orthogonal to one another: a developer may allocate + * memory that is portable, mapped and/or write-combined with no restrictions. + * + * The ::CU_MEMHOSTALLOC_DEVICEMAP flag may be specified on CUDA contexts for + * devices that do not support mapped pinned memory. The failure is deferred + * to ::cuMemHostGetDevicePointer() because the memory may be mapped into + * other CUDA contexts via the ::CU_MEMHOSTALLOC_PORTABLE flag. + * + * The memory allocated by this function must be freed with ::cuMemFreeHost(). + * + * Note all host memory allocated using ::cuMemHostAlloc() will automatically + * be immediately accessible to all contexts on all devices which support unified + * addressing (as may be queried using ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING). + * Unless the flag ::CU_MEMHOSTALLOC_WRITECOMBINED is specified, the device pointer + * that may be used to access this host memory from those contexts is always equal + * to the returned host pointer \p *pp. If the flag ::CU_MEMHOSTALLOC_WRITECOMBINED + * is specified, then the function ::cuMemHostGetDevicePointer() must be used + * to query the device pointer, even if the context supports unified addressing. + * See \ref CUDA_UNIFIED for additional details. + * + * \param pp - Returned host pointer to page-locked memory + * \param bytesize - Requested allocation size in bytes + * \param Flags - Flags for allocation request + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaHostAlloc + */ +CUresult CUDAAPI cuMemHostAlloc(void **pp, size_t bytesize, unsigned int Flags); + +/** + * \brief Passes back device pointer of mapped pinned memory + * + * Passes back the device pointer \p pdptr corresponding to the mapped, pinned + * host buffer \p p allocated by ::cuMemHostAlloc. + * + * ::cuMemHostGetDevicePointer() will fail if the ::CU_MEMHOSTALLOC_DEVICEMAP + * flag was not specified at the time the memory was allocated, or if the + * function is called on a GPU that does not support mapped pinned memory. 
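+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch of obtaining a device pointer to mapped pinned memory,
+ * assuming the device supports mapped pinned memory, might be:
+ *
+ * \code
+   void *hPtr;
+   CUdeviceptr dPtr;
+   cuMemHostAlloc(&hPtr, 4096, CU_MEMHOSTALLOC_DEVICEMAP);  // illustrative only
+   cuMemHostGetDevicePointer(&dPtr, hPtr, 0);               // device-visible alias of hPtr
+   // ... launch kernels that read or write dPtr ...
+   cuMemFreeHost(hPtr);
+ * \endcode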
+ * + * For devices that have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, the memory + * can also be accessed from the device using the host pointer \p p. + * The device pointer returned by ::cuMemHostGetDevicePointer() may or may not + * match the original host pointer \p p and depends on the devices visible to the + * application. If all devices visible to the application have a non-zero value for the + * device attribute, the device pointer returned by ::cuMemHostGetDevicePointer() + * will match the original pointer \p p. If any device visible to the application + * has a zero value for the device attribute, the device pointer returned by + * ::cuMemHostGetDevicePointer() will not match the original host pointer \p p, + * but it will be suitable for use on all devices provided Unified Virtual Addressing + * is enabled. In such systems, it is valid to access the memory using either pointer + * on devices that have a non-zero value for the device attribute. Note however that + * such devices should access the memory using only one of the two pointers and not both. + * + * \p Flags provides for future releases. For now, it must be set to 0. + * + * \param pdptr - Returned device pointer + * \param p - Host pointer + * \param Flags - Options (must be 0) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaHostGetDevicePointer + */ +CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr *pdptr, void *p, unsigned int Flags); + +/** + * \brief Passes back flags that were used for a pinned allocation + * + * Passes back the flags \p pFlags that were specified when allocating + * the pinned host buffer \p p allocated by ::cuMemHostAlloc. + * + * ::cuMemHostGetFlags() will fail if the pointer does not reside in + * an allocation performed by ::cuMemAllocHost() or ::cuMemHostAlloc(). + * + * \param pFlags - Returned flags word + * \param p - Host pointer + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuMemAllocHost, + * ::cuMemHostAlloc, + * ::cudaHostGetFlags + */ +CUresult CUDAAPI cuMemHostGetFlags(unsigned int *pFlags, void *p); + +/** + * \brief Allocates memory that will be automatically managed by the Unified Memory system + * + * Allocates \p bytesize bytes of managed memory on the device and returns in + * \p *dptr a pointer to the allocated memory. If the device doesn't support + * allocating managed memory, ::CUDA_ERROR_NOT_SUPPORTED is returned. 
Support
+ * for managed memory can be queried using the device attribute
+ * ::CU_DEVICE_ATTRIBUTE_MANAGED_MEMORY. The allocated memory is suitably
+ * aligned for any kind of variable. The memory is not cleared. If \p bytesize
+ * is 0, ::cuMemAllocManaged returns ::CUDA_ERROR_INVALID_VALUE. The pointer
+ * is valid on the CPU and on all GPUs in the system that support managed memory.
+ * All accesses to this pointer must obey the Unified Memory programming model.
+ *
+ * \p flags specifies the default stream association for this allocation.
+ * \p flags must be one of ::CU_MEM_ATTACH_GLOBAL or ::CU_MEM_ATTACH_HOST. If
+ * ::CU_MEM_ATTACH_GLOBAL is specified, then this memory is accessible from
+ * any stream on any device. If ::CU_MEM_ATTACH_HOST is specified, then the
+ * allocation should not be accessed from devices that have a zero value for the
+ * device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS; an explicit call to
+ * ::cuStreamAttachMemAsync will be required to enable access on such devices.
+ *
+ * If the association is later changed via ::cuStreamAttachMemAsync to
+ * a single stream, the default association as specified during ::cuMemAllocManaged
+ * is restored when that stream is destroyed. For __managed__ variables, the
+ * default association is always ::CU_MEM_ATTACH_GLOBAL. Note that destroying a
+ * stream is an asynchronous operation, and as a result, the change to default
+ * association won't happen until all work in the stream has completed.
+ *
+ * Memory allocated with ::cuMemAllocManaged should be released with ::cuMemFree.
+ *
+ * Device memory oversubscription is possible for GPUs that have a non-zero value for the
+ * device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Managed memory on
+ * such GPUs may be evicted from device memory to host memory at any time by the Unified
+ * Memory driver in order to make room for other allocations.
+ *
+ * In a multi-GPU system where all GPUs have a non-zero value for the device attribute
+ * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, managed memory may not be populated when this
+ * API returns and instead may be populated on access. In such systems, managed memory can
+ * migrate to any processor's memory at any time. The Unified Memory driver will employ heuristics to
+ * maintain data locality and prevent excessive page faults to the extent possible. The application
+ * can also guide the driver about memory usage patterns via ::cuMemAdvise. The application
+ * can also explicitly migrate memory to a desired processor's memory via
+ * ::cuMemPrefetchAsync.
+ *
+ * In a multi-GPU system where all of the GPUs have a zero value for the device attribute
+ * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS and all the GPUs have peer-to-peer support
+ * with each other, the physical storage for managed memory is created on the GPU which is active
+ * at the time ::cuMemAllocManaged is called. All other GPUs will reference the data at reduced
+ * bandwidth via peer mappings over the PCIe bus. The Unified Memory driver does not migrate
+ * memory among such GPUs.
+ *
+ * In a multi-GPU system where not all GPUs have peer-to-peer support with each other and
+ * where the value of the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS
+ * is zero for at least one of those GPUs, the location chosen for physical storage of managed
+ * memory is system-dependent.
+ * - On Linux, the location chosen will be device memory as long as the current set of active + * contexts are on devices that either have peer-to-peer support with each other or have a + * non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. + * If there is an active context on a GPU that does not have a non-zero value for that device + * attribute and it does not have peer-to-peer support with the other devices that have active + * contexts on them, then the location for physical storage will be 'zero-copy' or host memory. + * Note that this means that managed memory that is located in device memory is migrated to + * host memory if a new context is created on a GPU that doesn't have a non-zero value for + * the device attribute and does not support peer-to-peer with at least one of the other devices + * that has an active context. This in turn implies that context creation may fail if there is + * insufficient host memory to migrate all managed allocations. + * - On Windows, the physical storage is always created in 'zero-copy' or host memory. + * All GPUs will reference the data at reduced bandwidth over the PCIe bus. In these + * circumstances, use of the environment variable CUDA_VISIBLE_DEVICES is recommended to + * restrict CUDA to only use those GPUs that have peer-to-peer support. + * Alternatively, users can also set CUDA_MANAGED_FORCE_DEVICE_ALLOC to a + * non-zero value to force the driver to always use device memory for physical storage. + * When this environment variable is set to a non-zero value, all contexts created in + * that process on devices that support managed memory have to be peer-to-peer compatible + * with each other. Context creation will fail if a context is created on a device that + * supports managed memory and is not peer-to-peer compatible with any of the other + * managed memory supporting devices on which contexts were previously created, even if + * those contexts have been destroyed. These environment variables are described + * in the CUDA programming guide under the "CUDA environment variables" section. + * - On ARM, managed memory is not available on discrete gpu with Drive PX-2. 
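+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch of allocating managed memory that is visible to both host
+ * and device, assuming the device supports managed memory, might be:
+ *
+ * \code
+   CUdeviceptr p;
+   cuMemAllocManaged(&p, 4096, CU_MEM_ATTACH_GLOBAL);  // illustrative only; error checking omitted
+   ((char *)(uintptr_t)p)[0] = 0;                      // direct host access to managed memory
+   // ... launch kernels that access p ...
+   cuMemFree(p);
+ * \endcode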
+ * + * \param dptr - Returned device pointer + * \param bytesize - Requested allocation size in bytes + * \param flags - Must be one of ::CU_MEM_ATTACH_GLOBAL or ::CU_MEM_ATTACH_HOST + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cuDeviceGetAttribute, ::cuStreamAttachMemAsync, + * ::cudaMallocManaged + */ +CUresult CUDAAPI cuMemAllocManaged(CUdeviceptr *dptr, size_t bytesize, unsigned int flags); + +/** + * \brief Returns a handle to a compute device + * + * Returns in \p *device a device handle given a PCI bus ID string. + * + * \param dev - Returned device handle + * + * \param pciBusId - String in one of the following forms: + * [domain]:[bus]:[device].[function] + * [domain]:[bus]:[device] + * [bus]:[device].[function] + * where \p domain, \p bus, \p device, and \p function are all hexadecimal values + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGet, + * ::cuDeviceGetAttribute, + * ::cuDeviceGetPCIBusId, + * ::cudaDeviceGetByPCIBusId + */ +CUresult CUDAAPI cuDeviceGetByPCIBusId(CUdevice *dev, const char *pciBusId); + +/** + * \brief Returns a PCI Bus Id string for the device + * + * Returns an ASCII string identifying the device \p dev in the NULL-terminated + * string pointed to by \p pciBusId. \p len specifies the maximum length of the + * string that may be returned. + * + * \param pciBusId - Returned identifier string for the device in the following format + * [domain]:[bus]:[device].[function] + * where \p domain, \p bus, \p device, and \p function are all hexadecimal values. + * pciBusId should be large enough to store 13 characters including the NULL-terminator. + * + * \param len - Maximum length of string to store in \p name + * + * \param dev - Device to get identifier string for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuDeviceGet, + * ::cuDeviceGetAttribute, + * ::cuDeviceGetByPCIBusId, + * ::cudaDeviceGetPCIBusId + */ +CUresult CUDAAPI cuDeviceGetPCIBusId(char *pciBusId, int len, CUdevice dev); + +/** + * \brief Gets an interprocess handle for a previously allocated event + * + * Takes as input a previously allocated event. This event must have been + * created with the ::CU_EVENT_INTERPROCESS and ::CU_EVENT_DISABLE_TIMING + * flags set. 
This opaque handle may be copied into other processes and + * opened with ::cuIpcOpenEventHandle to allow efficient hardware + * synchronization between GPU work in different processes. + * + * After the event has been opened in the importing process, + * ::cuEventRecord, ::cuEventSynchronize, ::cuStreamWaitEvent and + * ::cuEventQuery may be used in either process. Performing operations + * on the imported event after the exported event has been freed + * with ::cuEventDestroy will result in undefined behavior. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * + * \param pHandle - Pointer to a user allocated CUipcEventHandle + * in which to return the opaque event handle + * \param event - Event allocated with ::CU_EVENT_INTERPROCESS and + * ::CU_EVENT_DISABLE_TIMING flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuEventCreate, + * ::cuEventDestroy, + * ::cuEventSynchronize, + * ::cuEventQuery, + * ::cuStreamWaitEvent, + * ::cuIpcOpenEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcOpenMemHandle, + * ::cuIpcCloseMemHandle, + * ::cudaIpcGetEventHandle + */ +CUresult CUDAAPI cuIpcGetEventHandle(CUipcEventHandle *pHandle, CUevent event); + +/** + * \brief Opens an interprocess event handle for use in the current process + * + * Opens an interprocess event handle exported from another process with + * ::cuIpcGetEventHandle. This function returns a ::CUevent that behaves like + * a locally created event with the ::CU_EVENT_DISABLE_TIMING flag specified. + * This event must be freed with ::cuEventDestroy. + * + * Performing operations on the imported event after the exported event has + * been freed with ::cuEventDestroy will result in undefined behavior. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * + * \param phEvent - Returns the imported event + * \param handle - Interprocess handle to open + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuEventCreate, + * ::cuEventDestroy, + * ::cuEventSynchronize, + * ::cuEventQuery, + * ::cuStreamWaitEvent, + * ::cuIpcGetEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcOpenMemHandle, + * ::cuIpcCloseMemHandle, + * ::cudaIpcOpenEventHandle + */ +CUresult CUDAAPI cuIpcOpenEventHandle(CUevent *phEvent, CUipcEventHandle handle); + +/** + * \brief Gets an interprocess memory handle for an existing device memory + * allocation + * + * Takes a pointer to the base of an existing device memory allocation created + * with ::cuMemAlloc and exports it for use in another process. This is a + * lightweight operation and may be called multiple times on an allocation + * without adverse effects. + * + * If a region of memory is freed with ::cuMemFree and a subsequent call + * to ::cuMemAlloc returns memory with the same device address, + * ::cuIpcGetMemHandle will return a unique handle for the + * new memory. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. 
+ * IPC functionality on Windows is restricted to GPUs in TCC mode + * + * \param pHandle - Pointer to user allocated ::CUipcMemHandle to return + * the handle in. + * \param dptr - Base pointer to previously allocated device memory + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuIpcGetEventHandle, + * ::cuIpcOpenEventHandle, + * ::cuIpcOpenMemHandle, + * ::cuIpcCloseMemHandle, + * ::cudaIpcGetMemHandle + */ +CUresult CUDAAPI cuIpcGetMemHandle(CUipcMemHandle *pHandle, CUdeviceptr dptr); + +/** + * \brief Opens an interprocess memory handle exported from another process + * and returns a device pointer usable in the local process. + * + * Maps memory exported from another process with ::cuIpcGetMemHandle into + * the current device address space. For contexts on different devices + * ::cuIpcOpenMemHandle can attempt to enable peer access between the + * devices as if the user called ::cuCtxEnablePeerAccess. This behavior is + * controlled by the ::CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS flag. + * ::cuDeviceCanAccessPeer can determine if a mapping is possible. + * + * Contexts that may open ::CUipcMemHandles are restricted in the following way. + * ::CUipcMemHandles from each ::CUdevice in a given process may only be opened + * by one ::CUcontext per ::CUdevice per other process. + * + * If the memory handle has already been opened by the current context, the + * reference count on the handle is incremented by 1 and the existing device pointer + * is returned. + * + * Memory returned from ::cuIpcOpenMemHandle must be freed with + * ::cuIpcCloseMemHandle. + * + * Calling ::cuMemFree on an exported memory region before calling + * ::cuIpcCloseMemHandle in the importing context will result in undefined + * behavior. + * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * + * \param pdptr - Returned device pointer + * \param handle - ::CUipcMemHandle to open + * \param Flags - Flags for this operation. Must be specified as ::CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_TOO_MANY_PEERS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \note No guarantees are made about the address returned in \p *pdptr. + * In particular, multiple processes may not receive the same address for the same \p handle. + * + * \sa + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuIpcGetEventHandle, + * ::cuIpcOpenEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcCloseMemHandle, + * ::cuCtxEnablePeerAccess, + * ::cuDeviceCanAccessPeer, + * ::cudaIpcOpenMemHandle + */ +CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int Flags); + +/** + * \brief Attempts to close memory mapped with ::cuIpcOpenMemHandle + * + * Decrements the reference count of the memory returned by ::cuIpcOpenMemHandle by 1. + * When the reference count reaches 0, this API unmaps the memory. The original allocation + * in the exporting process as well as imported mappings in other processes + * will be unaffected. + * + * Any resources used to enable peer access will be freed if this is the + * last mapping using them. 
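+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch of the export/import/close sequence across two processes
+ * might be:
+ *
+ * \code
+   // Exporting process (illustrative only; error checking omitted):
+   CUdeviceptr dptr;
+   CUipcMemHandle handle;
+   cuMemAlloc(&dptr, 4096);
+   cuIpcGetMemHandle(&handle, dptr);
+   // ... send 'handle' to another process, e.g. over a pipe ...
+
+   // Importing process:
+   CUdeviceptr mapped;
+   cuIpcOpenMemHandle(&mapped, handle, CU_IPC_MEM_LAZY_ENABLE_PEER_ACCESS);
+   // ... use 'mapped' ...
+   cuIpcCloseMemHandle(mapped);
+ * \endcode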
+ * + * IPC functionality is restricted to devices with support for unified + * addressing on Linux and Windows operating systems. + * IPC functionality on Windows is restricted to GPUs in TCC mode + * + * \param dptr - Device pointer returned by ::cuIpcOpenMemHandle + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_MAP_FAILED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \sa + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuIpcGetEventHandle, + * ::cuIpcOpenEventHandle, + * ::cuIpcGetMemHandle, + * ::cuIpcOpenMemHandle, + * ::cudaIpcCloseMemHandle + */ +CUresult CUDAAPI cuIpcCloseMemHandle(CUdeviceptr dptr); + +/** + * \brief Registers an existing host memory range for use by CUDA + * + * Page-locks the memory range specified by \p p and \p bytesize and maps it + * for the device(s) as specified by \p Flags. This memory range also is added + * to the same tracking mechanism as ::cuMemHostAlloc to automatically accelerate + * calls to functions such as ::cuMemcpyHtoD(). Since the memory can be accessed + * directly by the device, it can be read or written with much higher bandwidth + * than pageable memory that has not been registered. Page-locking excessive + * amounts of memory may degrade system performance, since it reduces the amount + * of memory available to the system for paging. As a result, this function is + * best used sparingly to register staging areas for data exchange between + * host and device. + * + * This function has limited support on Mac OS X. OS 10.7 or higher is required. + * + * The \p Flags parameter enables different options to be specified that + * affect the allocation, as follows. + * + * - ::CU_MEMHOSTREGISTER_PORTABLE: The memory returned by this call will be + * considered as pinned memory by all CUDA contexts, not just the one that + * performed the allocation. + * + * - ::CU_MEMHOSTREGISTER_DEVICEMAP: Maps the allocation into the CUDA address + * space. The device pointer to the memory may be obtained by calling + * ::cuMemHostGetDevicePointer(). + * + * - ::CU_MEMHOSTREGISTER_IOMEMORY: The pointer is treated as pointing to some + * I/O memory space, e.g. the PCI Express resource of a 3rd party device. + * + * - ::CU_MEMHOSTREGISTER_READ_ONLY: The pointer is treated as pointing to memory + * that is considered read-only by the device. On platforms without + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, this flag is + * required in order to register memory mapped to the CPU as read-only. Support + * for the use of this flag can be queried from the device attribute + * ::CU_DEVICE_ATTRIBUTE_READ_ONLY_HOST_REGISTER_SUPPORTED. Using this flag with + * a current context associated with a device that does not have this attribute + * set will cause ::cuMemHostRegister to error with CUDA_ERROR_NOT_SUPPORTED. + * + * All of these flags are orthogonal to one another: a developer may page-lock + * memory that is portable or mapped with no restrictions. + * + * The ::CU_MEMHOSTREGISTER_DEVICEMAP flag may be specified on CUDA contexts for + * devices that do not support mapped pinned memory. The failure is deferred + * to ::cuMemHostGetDevicePointer() because the memory may be mapped into + * other CUDA contexts via the ::CU_MEMHOSTREGISTER_PORTABLE flag. + * + * For devices that have a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, the memory + * can also be accessed from the device using the host pointer \p p. 
+ *
+ * The device pointer returned by ::cuMemHostGetDevicePointer() may or may not
+ * match the original host pointer \p ptr and depends on the devices visible to the
+ * application. If all devices visible to the application have a non-zero value for the
+ * device attribute, the device pointer returned by ::cuMemHostGetDevicePointer()
+ * will match the original pointer \p ptr. If any device visible to the application
+ * has a zero value for the device attribute, the device pointer returned by
+ * ::cuMemHostGetDevicePointer() will not match the original host pointer \p ptr,
+ * but it will be suitable for use on all devices provided Unified Virtual Addressing
+ * is enabled. In such systems, it is valid to access the memory using either pointer
+ * on devices that have a non-zero value for the device attribute. Note however that
+ * such devices should access the memory using only one of the two pointers and not both.
+ *
+ * The memory page-locked by this function must be unregistered with
+ * ::cuMemHostUnregister().
+ *
+ * \param p - Host pointer to memory to page-lock
+ * \param bytesize - Size in bytes of the address range to page-lock
+ * \param Flags - Flags for allocation request
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_OUT_OF_MEMORY,
+ * ::CUDA_ERROR_HOST_MEMORY_ALREADY_REGISTERED,
+ * ::CUDA_ERROR_NOT_PERMITTED,
+ * ::CUDA_ERROR_NOT_SUPPORTED
+ * \notefnerr
+ *
+ * \sa
+ * ::cuMemHostUnregister,
+ * ::cuMemHostGetFlags,
+ * ::cuMemHostGetDevicePointer,
+ * ::cudaHostRegister
+ */
+CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, unsigned int Flags);
+
+/**
+ * \brief Unregisters a memory range that was registered with cuMemHostRegister.
+ *
+ * Unmaps the memory range whose base address is specified by \p p, and makes
+ * it pageable again.
+ *
+ * The base address must be the same one specified to ::cuMemHostRegister().
+ *
+ * \param p - Host pointer to memory to unregister
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_OUT_OF_MEMORY,
+ * ::CUDA_ERROR_HOST_MEMORY_NOT_REGISTERED,
+ * \notefnerr
+ *
+ * \sa
+ * ::cuMemHostRegister,
+ * ::cudaHostUnregister
+ */
+CUresult CUDAAPI cuMemHostUnregister(void *p);
+
+/**
+ * \brief Copies memory
+ *
+ * Copies data between two pointers.
+ * \p dst and \p src are base pointers of the destination and source, respectively.
+ * \p ByteCount specifies the number of bytes to copy.
+ * Note that this function infers the type of the transfer (host to host, host to
+ * device, device to device, or device to host) from the pointer values. This
+ * function is only allowed in contexts which support unified addressing.
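+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch under unified addressing, where the copy direction is
+ * inferred from the two pointers, might be:
+ *
+ * \code
+   float h[256];
+   CUdeviceptr d;
+   cuMemAlloc(&d, sizeof(h));                           // illustrative only; error checking omitted
+   cuMemcpy(d, (CUdeviceptr)(uintptr_t)h, sizeof(h));   // host -> device
+   cuMemcpy((CUdeviceptr)(uintptr_t)h, d, sizeof(h));   // device -> host
+   cuMemFree(d);
+ * \endcode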
+ *
+ * \param dst - Destination unified virtual address space pointer
+ * \param src - Source unified virtual address space pointer
+ * \param ByteCount - Size of memory copy in bytes
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ * \note_sync
+ * \note_memcpy
+ *
+ * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate,
+ * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost,
+ * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned,
+ * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD,
+ * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA,
+ * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync,
+ * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost,
+ * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc,
+ * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16,
+ * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32,
+ * ::cudaMemcpy,
+ * ::cudaMemcpyToSymbol,
+ * ::cudaMemcpyFromSymbol
+ */
+CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount);
+
+/**
+ * \brief Copies device memory between two contexts
+ *
+ * Copies from device memory in one context to device memory in another
+ * context. \p dstDevice is the base device pointer of the destination memory
+ * and \p dstContext is the destination context. \p srcDevice is the base
+ * device pointer of the source memory and \p srcContext is the source context.
+ * \p ByteCount specifies the number of bytes to copy.
+ *
+ * \param dstDevice - Destination device pointer
+ * \param dstContext - Destination context
+ * \param srcDevice - Source device pointer
+ * \param srcContext - Source context
+ * \param ByteCount - Size of memory copy in bytes
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ * \note_sync
+ *
+ * \sa ::cuMemcpyDtoD, ::cuMemcpy3DPeer, ::cuMemcpyDtoDAsync, ::cuMemcpyPeerAsync,
+ * ::cuMemcpy3DPeerAsync,
+ * ::cudaMemcpyPeer
+ */
+CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, CUdeviceptr srcDevice, CUcontext srcContext, size_t ByteCount);
+
+/**
+ * \brief Copies memory from Host to Device
+ *
+ * Copies from host memory to device memory. \p dstDevice and \p srcHost are
+ * the base addresses of the destination and source, respectively. \p ByteCount
+ * specifies the number of bytes to copy.
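+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch of a host-to-device copy followed by the reverse transfer
+ * might be:
+ *
+ * \code
+   float h[256] = {0};
+   CUdeviceptr d;
+   cuMemAlloc(&d, sizeof(h));      // illustrative only; error checking omitted
+   cuMemcpyHtoD(d, h, sizeof(h));  // synchronous host -> device copy
+   cuMemcpyDtoH(h, d, sizeof(h));  // synchronous device -> host copy
+   cuMemFree(d);
+ * \endcode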
+ * + * \param dstDevice - Destination device pointer + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyToSymbol + */ +CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); + +/** + * \brief Copies memory from Device to Host + * + * Copies from device to host memory. \p dstHost and \p srcDevice specify the + * base pointers of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstHost - Destination host pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyFromSymbol + */ +CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); + +/** + * \brief Copies memory from Device to Device + * + * Copies from device memory to device memory. \p dstDevice and \p srcDevice + * are the base pointers of the destination and source, respectively. + * \p ByteCount specifies the number of bytes to copy. 
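+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch of duplicating a device buffer might be:
+ *
+ * \code
+   CUdeviceptr src, dst;
+   cuMemAlloc(&src, 4096);        // illustrative only; error checking omitted
+   cuMemAlloc(&dst, 4096);
+   cuMemcpyDtoD(dst, src, 4096);  // device -> device copy
+   cuMemFree(dst);
+   cuMemFree(src);
+ * \endcode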
+ * + * \param dstDevice - Destination device pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy, + * ::cudaMemcpyToSymbol, + * ::cudaMemcpyFromSymbol + */ +CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr dstDevice, CUdeviceptr srcDevice, size_t ByteCount); + +/** + * \brief Copies memory from Device to Array + * + * Copies from device memory to a 1D CUDA array. \p dstArray and \p dstOffset + * specify the CUDA array handle and starting index of the destination data. + * \p srcDevice specifies the base pointer of the source. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyToArray + */ +CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, size_t dstOffset, CUdeviceptr srcDevice, size_t ByteCount); + +/** + * \brief Copies memory from Array to Device + * + * Copies from one 1D CUDA array to device memory. \p dstDevice specifies the + * base pointer of the destination and must be naturally aligned with the CUDA + * array elements. \p srcArray and \p srcOffset specify the CUDA array handle + * and the offset in bytes into the array where the copy is to begin. + * \p ByteCount specifies the number of bytes to copy and must be evenly + * divisible by the array element size. 
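+ *
+ * As an informal illustration (not part of the original header documentation),
+ * a minimal sketch of copying a 1D CUDA array into linear device memory
+ * might be:
+ *
+ * \code
+   CUDA_ARRAY_DESCRIPTOR desc = {0};
+   CUarray arr;
+   CUdeviceptr dev;
+   desc.Width = 256;                  // Height == 0 gives a 1D array
+   desc.Format = CU_AD_FORMAT_FLOAT;
+   desc.NumChannels = 1;
+   cuArrayCreate(&arr, &desc);        // illustrative only; error checking omitted
+   cuMemAlloc(&dev, 256 * sizeof(float));
+   cuMemcpyAtoD(dev, arr, 0, 256 * sizeof(float));  // whole array -> device
+   cuMemFree(dev);
+   cuArrayDestroy(arr);
+ * \endcode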
+ * + * \param dstDevice - Destination device pointer + * \param srcArray - Source array + * \param srcOffset - Offset in bytes of source array + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyFromArray + */ +CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr dstDevice, CUarray srcArray, size_t srcOffset, size_t ByteCount); + +/** + * \brief Copies memory from Host to Array + * + * Copies from host memory to a 1D CUDA array. \p dstArray and \p dstOffset + * specify the CUDA array handle and starting offset in bytes of the destination + * data. \p pSrc specifies the base address of the source. \p ByteCount specifies + * the number of bytes to copy. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpyToArray + */ +CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, size_t dstOffset, const void *srcHost, size_t ByteCount); + +/** + * \brief Copies memory from Array to Host + * + * Copies from one 1D CUDA array to host memory. \p dstHost specifies the base + * pointer of the destination. \p srcArray and \p srcOffset specify the CUDA + * array handle and starting offset in bytes of the source data. + * \p ByteCount specifies the number of bytes to copy. 
+ *
+ * \param dstHost - Destination host pointer
+ * \param srcArray - Source array
+ * \param srcOffset - Offset in bytes of source array
+ * \param ByteCount - Size of memory copy in bytes
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ * \note_sync
+ * \note_memcpy
+ *
+ * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate,
+ * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost,
+ * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned,
+ * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD,
+ * ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync,
+ * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync,
+ * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost,
+ * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc,
+ * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16,
+ * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32,
+ * ::cudaMemcpyFromArray
+ */
+CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, size_t srcOffset, size_t ByteCount);
+
+/**
+ * \brief Copies memory from Array to Array
+ *
+ * Copies from one 1D CUDA array to another. \p dstArray and \p srcArray
+ * specify the handles of the destination and source CUDA arrays for the copy,
+ * respectively. \p dstOffset and \p srcOffset specify the destination and
+ * source offsets in bytes into the CUDA arrays. \p ByteCount is the number of
+ * bytes to be copied. The elements of the CUDA arrays need not be the same
+ * format, but they must be the same size; \p ByteCount must be evenly
+ * divisible by that size.
+ *
+ * \param dstArray - Destination array
+ * \param dstOffset - Offset in bytes of destination array
+ * \param srcArray - Source array
+ * \param srcOffset - Offset in bytes of source array
+ * \param ByteCount - Size of memory copy in bytes
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ * \note_sync
+ *
+ * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate,
+ * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost,
+ * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned,
+ * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoD,
+ * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync,
+ * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync,
+ * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost,
+ * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc,
+ * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16,
+ * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32,
+ * ::cudaMemcpyArrayToArray
+ */
+CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, size_t dstOffset, CUarray srcArray, size_t srcOffset, size_t ByteCount);
+
+/**
+ * \brief Copies memory for 2D arrays
+ *
+ * Perform a 2D memory copy according to the parameters specified in \p pCopy.
+ * The ::CUDA_MEMCPY2D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY2D_st { + unsigned int srcXInBytes, srcY; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; + + unsigned int dstXInBytes, dstY; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; + + unsigned int WidthInBytes; + unsigned int Height; + } CUDA_MEMCPY2D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost and ::srcPitch + * specify the (host) base address of the source data and the bytes per row to + * apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice and ::srcPitch + * specify the (device) base address of the source data and the bytes per row + * to apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice and ::srcPitch are + * ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice and ::dstPitch are + * ignored. + * + * - ::srcXInBytes and ::srcY specify the base address of the source data for + * the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+srcY*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+srcY*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - ::dstXInBytes and ::dstY specify the base address of the destination data + * for the copy. 
+ * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+dstY*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+dstY*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes and ::Height specify the width (in bytes) and height of + * the 2D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * + * \par + * ::cuMemcpy2D() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). ::cuMemAllocPitch() passes back + * pitches that always work with ::cuMemcpy2D(). On intra-device memory copies + * (device to device, CUDA array to device, CUDA array to CUDA array), + * ::cuMemcpy2D() may fail for pitches not computed by ::cuMemAllocPitch(). + * ::cuMemcpy2DUnaligned() does not have this restriction, but may run + * significantly slower in the cases where ::cuMemcpy2D() would have returned + * an error code. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy2D, + * ::cudaMemcpy2DToArray, + * ::cudaMemcpy2DFromArray + */ +CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D *pCopy); + +/** + * \brief Copies memory for 2D arrays + * + * Perform a 2D memory copy according to the parameters specified in \p pCopy. 
+ * The ::CUDA_MEMCPY2D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY2D_st { + unsigned int srcXInBytes, srcY; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; + unsigned int dstXInBytes, dstY; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; + unsigned int WidthInBytes; + unsigned int Height; + } CUDA_MEMCPY2D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost and ::srcPitch + * specify the (host) base address of the source data and the bytes per row to + * apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice and ::srcPitch + * specify the (device) base address of the source data and the bytes per row + * to apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice and ::srcPitch are + * ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data and the bytes per + * row to apply. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice and ::dstPitch are + * ignored. + * + * - ::srcXInBytes and ::srcY specify the base address of the source data for + * the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+srcY*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+srcY*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - ::dstXInBytes and ::dstY specify the base address of the destination data + * for the copy. 
+ * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+dstY*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+dstY*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes and ::Height specify the width (in bytes) and height of + * the 2D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * + * \par + * ::cuMemcpy2D() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). ::cuMemAllocPitch() passes back + * pitches that always work with ::cuMemcpy2D(). On intra-device memory copies + * (device to device, CUDA array to device, CUDA array to CUDA array), + * ::cuMemcpy2D() may fail for pitches not computed by ::cuMemAllocPitch(). + * ::cuMemcpy2DUnaligned() does not have this restriction, but may run + * significantly slower in the cases where ::cuMemcpy2D() would have returned + * an error code. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy2D, + * ::cudaMemcpy2DToArray, + * ::cudaMemcpy2DFromArray + */ +CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D *pCopy); + +/** + * \brief Copies memory for 3D arrays + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. 
The ::CUDA_MEMCPY3D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY3D_st { + + unsigned int srcXInBytes, srcY, srcZ; + unsigned int srcLOD; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; // ignored when src is array + unsigned int srcHeight; // ignored when src is array; may be 0 if Depth==1 + + unsigned int dstXInBytes, dstY, dstZ; + unsigned int dstLOD; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; // ignored when dst is array + unsigned int dstHeight; // ignored when dst is array; may be 0 if Depth==1 + + unsigned int WidthInBytes; + unsigned int Height; + unsigned int Depth; + } CUDA_MEMCPY3D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost, ::srcPitch and + * ::srcHeight specify the (host) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice, ::srcPitch and + * ::srcHeight specify the (device) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice, ::srcPitch and + * ::srcHeight are ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data, the bytes per row, + * and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data, the bytes per + * row, and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice, ::dstPitch and + * ::dstHeight are ignored. + * + * - ::srcXInBytes, ::srcY and ::srcZ specify the base address of the source + * data for the copy. 
+ * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+(srcZ*srcHeight+srcY)*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+(srcZ*srcHeight+srcY)*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - dstXInBytes, ::dstY and ::dstZ specify the base address of the + * destination data for the copy. + * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+(dstZ*dstHeight+dstY)*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+(dstZ*dstHeight+dstY)*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes, ::Height and ::Depth specify the width (in bytes), height + * and depth of the 3D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * - If specified, ::srcHeight must be greater than or equal to ::Height + + * ::srcY, and ::dstHeight must be greater than or equal to ::Height + ::dstY. + * + * \par + * ::cuMemcpy3D() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). + * + * The ::srcLOD and ::dstLOD members of the ::CUDA_MEMCPY3D structure must be + * set to 0. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMemcpy3D + */ +CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D *pCopy); + +/** + * \brief Copies memory between contexts + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. See the definition of the ::CUDA_MEMCPY3D_PEER structure + * for documentation of its parameters. + * + * \param pCopy - Parameters for the memory copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_sync + * + * \sa ::cuMemcpyDtoD, ::cuMemcpyPeer, ::cuMemcpyDtoDAsync, ::cuMemcpyPeerAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpy3DPeer + */ +CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy); + +/** + * \brief Copies memory asynchronously + * + * Copies data between two pointers. 
+ * \p dst and \p src are base pointers of the destination and source, respectively. + * \p ByteCount specifies the number of bytes to copy. + * Note that this function infers the type of the transfer (host to host, host to + * device, device to device, or device to host) from the pointer values. This + * function is only allowed in contexts which support unified addressing. + * + * \param dst - Destination unified virtual address space pointer + * \param src - Source unified virtual address space pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyToSymbolAsync, + * ::cudaMemcpyFromSymbolAsync + */ +CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies device memory between two contexts asynchronously. + * + * Copies from device memory in one context to device memory in another + * context. \p dstDevice is the base device pointer of the destination memory + * and \p dstContext is the destination context. \p srcDevice is the base + * device pointer of the source memory and \p srcContext is the source pointer. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstDevice - Destination device pointer + * \param dstContext - Destination context + * \param srcDevice - Source device pointer + * \param srcContext - Source context + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpyDtoD, ::cuMemcpyPeer, ::cuMemcpy3DPeer, ::cuMemcpyDtoDAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpyPeerAsync + */ +CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, CUdeviceptr srcDevice, CUcontext srcContext, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Host to Device + * + * Copies from host memory to device memory. \p dstDevice and \p srcHost are + * the base addresses of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. 
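+ *
+ * For illustration, an asynchronous host-to-device upload might be issued as
+ * in the sketch below; devPtr, hostBuf and stream are assumed to have been
+ * created beforehand (for example with ::cuMemAlloc, ::cuMemAllocHost and
+ * ::cuStreamCreate), and hostBuf should be page-locked if the copy is to
+ * overlap with other work:
+ * \code
+ size_t bytes = 1024 * sizeof(float);
+ cuMemcpyHtoDAsync(devPtr, hostBuf, bytes, stream);
+ // ... launch work that consumes devPtr into the same stream ...
+ cuStreamSynchronize(stream); // hostBuf may be reused or freed after this
+ * \endcode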
+ * + * \param dstDevice - Destination device pointer + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyToSymbolAsync + */ +CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Device to Host + * + * Copies from device to host memory. \p dstHost and \p srcDevice specify the + * base pointers of the destination and source, respectively. \p ByteCount + * specifies the number of bytes to copy. + * + * \param dstHost - Destination host pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyFromSymbolAsync + */ +CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Device to Device + * + * Copies from device memory to device memory. \p dstDevice and \p srcDevice + * are the base pointers of the destination and source, respectively. + * \p ByteCount specifies the number of bytes to copy. 
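+ *
+ * As a sketch, a device-to-device copy ordered with other work on a stream
+ * might look like the following; dstPtr and srcPtr are assumed to be device
+ * allocations of at least ByteCount bytes and stream an existing stream:
+ * \code
+ cuMemcpyDtoDAsync(dstPtr, srcPtr, ByteCount, stream);
+ // work enqueued later in 'stream' observes the copied data; other streams
+ // must synchronize (e.g. via events) before reading dstPtr
+ * \endcode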
+ * + * \param dstDevice - Destination device pointer + * \param srcDevice - Source device pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyAsync, + * ::cudaMemcpyToSymbolAsync, + * ::cudaMemcpyFromSymbolAsync + */ +CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr dstDevice, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Host to Array + * + * Copies from host memory to a 1D CUDA array. \p dstArray and \p dstOffset + * specify the CUDA array handle and starting offset in bytes of the + * destination data. \p srcHost specifies the base address of the source. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstArray - Destination array + * \param dstOffset - Offset in bytes of destination array + * \param srcHost - Source host pointer + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyToArrayAsync + */ +CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, size_t dstOffset, const void *srcHost, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory from Array to Host + * + * Copies from one 1D CUDA array to host memory. \p dstHost specifies the base + * pointer of the destination. 
\p srcArray and \p srcOffset specify the CUDA + * array handle and starting offset in bytes of the source data. + * \p ByteCount specifies the number of bytes to copy. + * + * \param dstHost - Destination pointer + * \param srcArray - Source array + * \param srcOffset - Offset in bytes of source array + * \param ByteCount - Size of memory copy in bytes + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * \note_memcpy + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpyFromArrayAsync + */ +CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, size_t srcOffset, size_t ByteCount, CUstream hStream); + +/** + * \brief Copies memory for 2D arrays + * + * Perform a 2D memory copy according to the parameters specified in \p pCopy. + * The ::CUDA_MEMCPY2D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY2D_st { + unsigned int srcXInBytes, srcY; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; + unsigned int dstXInBytes, dstY; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; + unsigned int WidthInBytes; + unsigned int Height; + } CUDA_MEMCPY2D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost and ::srcPitch + * specify the (host) base address of the source data and the bytes per row to + * apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice and ::srcPitch + * specify the (device) base address of the source data and the bytes per row + * to apply. ::srcArray is ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. 
::srcHost, ::srcDevice and ::srcPitch are
+ * ignored.
+ *
+ * \par
+ * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch
+ * specify the (unified virtual address space) base address of the destination data
+ * and the bytes per row to apply. ::dstArray is ignored.
+ * This value may be used only if unified addressing is supported in the calling
+ * context.
+ *
+ * \par
+ * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch
+ * specify the (host) base address of the destination data and the bytes per
+ * row to apply. ::dstArray is ignored.
+ *
+ * \par
+ * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch
+ * specify the (device) base address of the destination data and the bytes per
+ * row to apply. ::dstArray is ignored.
+ *
+ * \par
+ * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the
+ * handle of the destination data. ::dstHost, ::dstDevice and ::dstPitch are
+ * ignored.
+ *
+ * - ::srcXInBytes and ::srcY specify the base address of the source data for
+ * the copy.
+ *
+ * \par
+ * For host pointers, the starting address is
+ * \code
+ void* Start = (void*)((char*)srcHost+srcY*srcPitch + srcXInBytes);
+ * \endcode
+ *
+ * \par
+ * For device pointers, the starting address is
+ * \code
+ CUdeviceptr Start = srcDevice+srcY*srcPitch+srcXInBytes;
+ * \endcode
+ *
+ * \par
+ * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array
+ * element size.
+ *
+ * - ::dstXInBytes and ::dstY specify the base address of the destination data
+ * for the copy.
+ *
+ * \par
+ * For host pointers, the base address is
+ * \code
+ void* dstStart = (void*)((char*)dstHost+dstY*dstPitch + dstXInBytes);
+ * \endcode
+ *
+ * \par
+ * For device pointers, the starting address is
+ * \code
+ CUdeviceptr dstStart = dstDevice+dstY*dstPitch+dstXInBytes;
+ * \endcode
+ *
+ * \par
+ * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array
+ * element size.
+ *
+ * - ::WidthInBytes and ::Height specify the width (in bytes) and height of
+ * the 2D copy being performed.
+ * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes +
+ * ::srcXInBytes, and ::dstPitch must be greater than or equal to
+ * ::WidthInBytes + ::dstXInBytes.
+ *
+ * \par
+ * ::cuMemcpy2DAsync() returns an error if any pitch is greater than the maximum
+ * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). ::cuMemAllocPitch() passes back
+ * pitches that always work with ::cuMemcpy2D(). On intra-device memory copies
+ * (device to device, CUDA array to device, CUDA array to CUDA array),
+ * ::cuMemcpy2DAsync() may fail for pitches not computed by ::cuMemAllocPitch().
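+ *
+ * As an illustrative sketch, copying a widthInBytes x height region from a
+ * pitched host image into a pitched device allocation might be set up as
+ * shown below; hostImage and hostPitch describe the source rows, devImage
+ * and devPitch are assumed to come from ::cuMemAllocPitch(), and stream is
+ * an existing stream:
+ * \code
+ CUDA_MEMCPY2D cpy = {0};            // offsets default to zero
+ cpy.srcMemoryType = CU_MEMORYTYPE_HOST;
+ cpy.srcHost       = hostImage;
+ cpy.srcPitch      = hostPitch;      // bytes per source row
+ cpy.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+ cpy.dstDevice     = devImage;
+ cpy.dstPitch      = devPitch;       // pitch returned by cuMemAllocPitch()
+ cpy.WidthInBytes  = widthInBytes;
+ cpy.Height        = height;
+ cuMemcpy2DAsync(&cpy, stream);
+ * \endcode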
+ * + * \param pCopy - Parameters for the memory copy + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpy2DAsync, + * ::cudaMemcpy2DToArrayAsync, + * ::cudaMemcpy2DFromArrayAsync + */ +CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D *pCopy, CUstream hStream); + +/** + * \brief Copies memory for 3D arrays + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. The ::CUDA_MEMCPY3D structure is defined as: + * + * \code + typedef struct CUDA_MEMCPY3D_st { + + unsigned int srcXInBytes, srcY, srcZ; + unsigned int srcLOD; + CUmemorytype srcMemoryType; + const void *srcHost; + CUdeviceptr srcDevice; + CUarray srcArray; + unsigned int srcPitch; // ignored when src is array + unsigned int srcHeight; // ignored when src is array; may be 0 if Depth==1 + + unsigned int dstXInBytes, dstY, dstZ; + unsigned int dstLOD; + CUmemorytype dstMemoryType; + void *dstHost; + CUdeviceptr dstDevice; + CUarray dstArray; + unsigned int dstPitch; // ignored when dst is array + unsigned int dstHeight; // ignored when dst is array; may be 0 if Depth==1 + + unsigned int WidthInBytes; + unsigned int Height; + unsigned int Depth; + } CUDA_MEMCPY3D; + * \endcode + * where: + * - ::srcMemoryType and ::dstMemoryType specify the type of memory of the + * source and destination, respectively; ::CUmemorytype_enum is defined as: + * + * \code + typedef enum CUmemorytype_enum { + CU_MEMORYTYPE_HOST = 0x01, + CU_MEMORYTYPE_DEVICE = 0x02, + CU_MEMORYTYPE_ARRAY = 0x03, + CU_MEMORYTYPE_UNIFIED = 0x04 + } CUmemorytype; + * \endcode + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::srcDevice and ::srcPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::srcArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_HOST, ::srcHost, ::srcPitch and + * ::srcHeight specify the (host) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. + * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_DEVICE, ::srcDevice, ::srcPitch and + * ::srcHeight specify the (device) base address of the source data, the bytes + * per row, and the height of each 2D slice of the 3D array. ::srcArray is + * ignored. 
+ * + * \par + * If ::srcMemoryType is ::CU_MEMORYTYPE_ARRAY, ::srcArray specifies the + * handle of the source data. ::srcHost, ::srcDevice, ::srcPitch and + * ::srcHeight are ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_UNIFIED, ::dstDevice and ::dstPitch + * specify the (unified virtual address space) base address of the source data + * and the bytes per row to apply. ::dstArray is ignored. + * This value may be used only if unified addressing is supported in the calling + * context. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_HOST, ::dstHost and ::dstPitch + * specify the (host) base address of the destination data, the bytes per row, + * and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_DEVICE, ::dstDevice and ::dstPitch + * specify the (device) base address of the destination data, the bytes per + * row, and the height of each 2D slice of the 3D array. ::dstArray is ignored. + * + * \par + * If ::dstMemoryType is ::CU_MEMORYTYPE_ARRAY, ::dstArray specifies the + * handle of the destination data. ::dstHost, ::dstDevice, ::dstPitch and + * ::dstHeight are ignored. + * + * - ::srcXInBytes, ::srcY and ::srcZ specify the base address of the source + * data for the copy. + * + * \par + * For host pointers, the starting address is + * \code + void* Start = (void*)((char*)srcHost+(srcZ*srcHeight+srcY)*srcPitch + srcXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr Start = srcDevice+(srcZ*srcHeight+srcY)*srcPitch+srcXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::srcXInBytes must be evenly divisible by the array + * element size. + * + * - dstXInBytes, ::dstY and ::dstZ specify the base address of the + * destination data for the copy. + * + * \par + * For host pointers, the base address is + * \code + void* dstStart = (void*)((char*)dstHost+(dstZ*dstHeight+dstY)*dstPitch + dstXInBytes); + * \endcode + * + * \par + * For device pointers, the starting address is + * \code + CUdeviceptr dstStart = dstDevice+(dstZ*dstHeight+dstY)*dstPitch+dstXInBytes; + * \endcode + * + * \par + * For CUDA arrays, ::dstXInBytes must be evenly divisible by the array + * element size. + * + * - ::WidthInBytes, ::Height and ::Depth specify the width (in bytes), height + * and depth of the 3D copy being performed. + * - If specified, ::srcPitch must be greater than or equal to ::WidthInBytes + + * ::srcXInBytes, and ::dstPitch must be greater than or equal to + * ::WidthInBytes + dstXInBytes. + * - If specified, ::srcHeight must be greater than or equal to ::Height + + * ::srcY, and ::dstHeight must be greater than or equal to ::Height + ::dstY. + * + * \par + * ::cuMemcpy3DAsync() returns an error if any pitch is greater than the maximum + * allowed (::CU_DEVICE_ATTRIBUTE_MAX_PITCH). + * + * The ::srcLOD and ::dstLOD members of the ::CUDA_MEMCPY3D structure must be + * set to 0. 
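+ *
+ * For illustration, uploading a tightly packed width x height x depth volume
+ * of floats from the host might be described as in the sketch below; hostVol
+ * is an assumed host buffer, devVol an assumed device allocation large
+ * enough for the volume, and stream an existing stream:
+ * \code
+ CUDA_MEMCPY3D cpy = {0};                    // offsets and LOD fields stay zero
+ cpy.srcMemoryType = CU_MEMORYTYPE_HOST;
+ cpy.srcHost       = hostVol;
+ cpy.srcPitch      = width * sizeof(float);  // bytes per row
+ cpy.srcHeight     = height;                 // rows per 2D slice
+ cpy.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+ cpy.dstDevice     = devVol;
+ cpy.dstPitch      = width * sizeof(float);
+ cpy.dstHeight     = height;
+ cpy.WidthInBytes  = width * sizeof(float);
+ cpy.Height        = height;
+ cpy.Depth         = depth;
+ cuMemcpy3DAsync(&cpy, stream);
+ * \endcode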
+ * + * \param pCopy - Parameters for the memory copy + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemcpy3DAsync + */ +CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D *pCopy, CUstream hStream); + +/** + * \brief Copies memory between contexts asynchronously. + * + * Perform a 3D memory copy according to the parameters specified in + * \p pCopy. See the definition of the ::CUDA_MEMCPY3D_PEER structure + * for documentation of its parameters. + * + * \param pCopy - Parameters for the memory copy + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpyDtoD, ::cuMemcpyPeer, ::cuMemcpyDtoDAsync, ::cuMemcpyPeerAsync, + * ::cuMemcpy3DPeerAsync, + * ::cudaMemcpy3DPeerAsync + */ +CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, CUstream hStream); + +/** + * \brief Initializes device memory + * + * Sets the memory range of \p N 8-bit values to the specified value + * \p uc. 
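+ *
+ * For example, zero-filling a freshly allocated buffer might be written as
+ * in this sketch, where devPtr is assumed to be a device allocation of at
+ * least ByteCount bytes:
+ * \code
+ cuMemsetD8(devPtr, 0x00, ByteCount);  // N counts 8-bit elements, i.e. bytes
+ * \endcode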
+ * + * \param dstDevice - Destination device pointer + * \param uc - Value to set + * \param N - Number of elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset + */ +CUresult CUDAAPI cuMemsetD8(CUdeviceptr dstDevice, unsigned char uc, size_t N); + +/** + * \brief Initializes device memory + * + * Sets the memory range of \p N 16-bit values to the specified value + * \p us. The \p dstDevice pointer must be two byte aligned. + * + * \param dstDevice - Destination device pointer + * \param us - Value to set + * \param N - Number of elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset + */ +CUresult CUDAAPI cuMemsetD16(CUdeviceptr dstDevice, unsigned short us, size_t N); + +/** + * \brief Initializes device memory + * + * Sets the memory range of \p N 32-bit values to the specified value + * \p ui. The \p dstDevice pointer must be four byte aligned. 
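+ *
+ * Because the fill value is an arbitrary 32-bit pattern, this can also be
+ * used to initialize an array of floats with a single value, as in the
+ * sketch below (devPtr is assumed to be a 4-byte aligned device allocation
+ * holding numFloats floats):
+ * \code
+ float value = 1.0f;
+ unsigned int pattern;
+ memcpy(&pattern, &value, sizeof(pattern));  // reinterpret the float's bits
+ cuMemsetD32(devPtr, pattern, numFloats);    // N counts 32-bit elements
+ * \endcode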
+ * + * \param dstDevice - Destination device pointer + * \param ui - Value to set + * \param N - Number of elements + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32Async, + * ::cudaMemset + */ +CUresult CUDAAPI cuMemsetD32(CUdeviceptr dstDevice, unsigned int ui, size_t N); + +/** + * \brief Initializes device memory + * + * Sets the 2D memory range of \p Width 8-bit values to the specified value + * \p uc. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) + * \param uc - Value to set + * \param Width - Width of row + * \param Height - Number of rows + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2D + */ +CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr dstDevice, size_t dstPitch, unsigned char uc, size_t Width, size_t Height); + +/** + * \brief Initializes device memory + * + * Sets the 2D memory range of \p Width 16-bit values to the specified value + * \p us. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be two byte aligned. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). 
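+ *
+ * As a sketch, clearing a pitched width x height image of 16-bit pixels
+ * (with devPtr and pitch assumed to come from ::cuMemAllocPitch()) might be
+ * written:
+ * \code
+ cuMemsetD2D16(devPtr, pitch, 0, width, height);  // Width counts 16-bit elements per row
+ * \endcode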
+ * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) + * \param us - Value to set + * \param Width - Width of row + * \param Height - Number of rows + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2D + */ +CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr dstDevice, size_t dstPitch, unsigned short us, size_t Width, size_t Height); + +/** + * \brief Initializes device memory + * + * Sets the 2D memory range of \p Width 32-bit values to the specified value + * \p ui. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be four byte aligned. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) + * \param ui - Value to set + * \param Width - Width of row + * \param Height - Number of rows + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2D + */ +CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr dstDevice, size_t dstPitch, unsigned int ui, size_t Width, size_t Height); + +/** + * \brief Sets device memory + * + * Sets the memory range of \p N 8-bit values to the specified value + * \p uc. 
+ * + * \param dstDevice - Destination device pointer + * \param uc - Value to set + * \param N - Number of elements + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemsetAsync + */ +CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, size_t N, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the memory range of \p N 16-bit values to the specified value + * \p us. The \p dstDevice pointer must be two byte aligned. + * + * \param dstDevice - Destination device pointer + * \param us - Value to set + * \param N - Number of elements + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemsetAsync + */ +CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, size_t N, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the memory range of \p N 32-bit values to the specified value + * \p ui. The \p dstDevice pointer must be four byte aligned. 
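+ *
+ * For example, an initialization ordered with later work on the same stream
+ * might be sketched as follows (devPtr is assumed to hold numInts 32-bit
+ * values and stream to be an existing stream):
+ * \code
+ cuMemsetD32Async(devPtr, 0, numInts, stream);
+ // kernels launched into 'stream' after this call observe the cleared buffer
+ * \endcode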
+ * + * \param dstDevice - Destination device pointer + * \param ui - Value to set + * \param N - Number of elements + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, ::cuMemsetD32, + * ::cudaMemsetAsync + */ +CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, size_t N, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the 2D memory range of \p Width 8-bit values to the specified value + * \p uc. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) + * \param uc - Value to set + * \param Width - Width of row + * \param Height - Number of rows + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2DAsync + */ +CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned char uc, size_t Width, size_t Height, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the 2D memory range of \p Width 16-bit values to the specified value + * \p us. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be two byte aligned. 
This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). + * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) + * \param us - Value to set + * \param Width - Width of row + * \param Height - Number of rows + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D32, ::cuMemsetD2D32Async, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2DAsync + */ +CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned short us, size_t Width, size_t Height, CUstream hStream); + +/** + * \brief Sets device memory + * + * Sets the 2D memory range of \p Width 32-bit values to the specified value + * \p ui. \p Height specifies the number of rows to set, and \p dstPitch + * specifies the number of bytes between each row. The \p dstDevice pointer + * and \p dstPitch offset must be four byte aligned. This function performs + * fastest when the pitch is one that has been passed back by + * ::cuMemAllocPitch(). 
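+ *
+ * As a sketch, asynchronously clearing a pitched width x height buffer of
+ * 32-bit values (devPtr and pitch assumed to come from ::cuMemAllocPitch(),
+ * stream an existing stream) might be written:
+ * \code
+ cuMemsetD2D32Async(devPtr, pitch, 0, width, height, stream);
+ * \endcode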
+ * + * \param dstDevice - Destination device pointer + * \param dstPitch - Pitch of destination device pointer(Unused if \p Height is 1) + * \param ui - Value to set + * \param Width - Width of row + * \param Height - Number of rows + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * \note_memset + * \note_null_stream + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D8Async, + * ::cuMemsetD2D16, ::cuMemsetD2D16Async, ::cuMemsetD2D32, + * ::cuMemsetD8, ::cuMemsetD8Async, ::cuMemsetD16, ::cuMemsetD16Async, + * ::cuMemsetD32, ::cuMemsetD32Async, + * ::cudaMemset2DAsync + */ +CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned int ui, size_t Width, size_t Height, CUstream hStream); + +/** + * \brief Creates a 1D or 2D CUDA array + * + * Creates a CUDA array according to the ::CUDA_ARRAY_DESCRIPTOR structure + * \p pAllocateArray and returns a handle to the new CUDA array in \p *pHandle. + * The ::CUDA_ARRAY_DESCRIPTOR is defined as: + * + * \code + typedef struct { + unsigned int Width; + unsigned int Height; + CUarray_format Format; + unsigned int NumChannels; + } CUDA_ARRAY_DESCRIPTOR; + * \endcode + * where: + * + * - \p Width, and \p Height are the width, and height of the CUDA array (in + * elements); the CUDA array is one-dimensional if height is 0, two-dimensional + * otherwise; + * - ::Format specifies the format of the elements; ::CUarray_format is + * defined as: + * \code + typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, + CU_AD_FORMAT_SIGNED_INT8 = 0x08, + CU_AD_FORMAT_SIGNED_INT16 = 0x09, + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, + CU_AD_FORMAT_HALF = 0x10, + CU_AD_FORMAT_FLOAT = 0x20 + } CUarray_format; + * \endcode + * - \p NumChannels specifies the number of packed components per CUDA array + * element; it may be 1, 2, or 4; + * + * Here are examples of CUDA array descriptions: + * + * Description for a CUDA array of 2048 floats: + * \code + CUDA_ARRAY_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 2048; + desc.Height = 1; + * \endcode + * + * Description for a 64 x 64 CUDA array of floats: + * \code + CUDA_ARRAY_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 64; + desc.Height = 64; + * \endcode + * + * Description for a \p width x \p height CUDA array of 64-bit, 4x16-bit + * float16's: + * \code + CUDA_ARRAY_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_HALF; + desc.NumChannels = 4; + desc.Width = width; + desc.Height = height; + * \endcode + * + * Description for a \p width x \p height CUDA array of 16-bit elements, each + * of which is two 8-bit unsigned chars: + * \code + 
CUDA_ARRAY_DESCRIPTOR arrayDesc; + desc.Format = CU_AD_FORMAT_UNSIGNED_INT8; + desc.NumChannels = 2; + desc.Width = width; + desc.Height = height; + * \endcode + * + * \param pHandle - Returned array + * \param pAllocateArray - Array descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMallocArray + */ +CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, const CUDA_ARRAY_DESCRIPTOR *pAllocateArray); + +/** + * \brief Get a 1D or 2D CUDA array descriptor + * + * Returns in \p *pArrayDescriptor a descriptor containing information on the + * format and dimensions of the CUDA array \p hArray. It is useful for + * subroutines that have been passed a CUDA array, but need to know the CUDA + * array parameters for validation or other purposes. + * + * \param pArrayDescriptor - Returned array descriptor + * \param hArray - Array to get descriptor of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaArrayGetInfo + */ +CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR *pArrayDescriptor, CUarray hArray); + +/** + * \brief Returns the layout properties of a sparse CUDA array + * + * Returns the layout properties of a sparse CUDA array in \p sparseProperties + * If the CUDA array is not allocated with flag ::CUDA_ARRAY3D_SPARSE + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * If the returned value in ::CUDA_ARRAY_SPARSE_PROPERTIES::flags contains ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL, + * then ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize represents the total size of the array. Otherwise, it will be zero. + * Also, the returned value in ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailFirstLevel is always zero. + * Note that the \p array must have been allocated using ::cuArrayCreate or ::cuArray3DCreate. 
For CUDA arrays obtained + * using ::cuMipmappedArrayGetLevel, ::CUDA_ERROR_INVALID_VALUE will be returned. Instead, ::cuMipmappedArrayGetSparseProperties + * must be used to obtain the sparse properties of the entire CUDA mipmapped array to which \p array belongs to. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] sparseProperties - Pointer to ::CUDA_ARRAY_SPARSE_PROPERTIES + * \param[in] array - CUDA array to get the sparse properties of + * \sa ::cuMipmappedArrayGetSparseProperties, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuArrayGetSparseProperties(CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUarray array); + +/** + * \brief Returns the layout properties of a sparse CUDA mipmapped array + * + * Returns the sparse array layout properties in \p sparseProperties + * If the CUDA mipmapped array is not allocated with flag ::CUDA_ARRAY3D_SPARSE + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * For non-layered CUDA mipmapped arrays, ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize returns the + * size of the mip tail region. The mip tail region includes all mip levels whose width, height or depth + * is less than that of the tile. + * For layered CUDA mipmapped arrays, if ::CUDA_ARRAY_SPARSE_PROPERTIES::flags contains ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL, + * then ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize specifies the size of the mip tail of all layers combined. + * Otherwise, ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize specifies mip tail size per layer. + * The returned value of ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailFirstLevel is valid only if ::CUDA_ARRAY_SPARSE_PROPERTIES::miptailSize is non-zero. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] sparseProperties - Pointer to ::CUDA_ARRAY_SPARSE_PROPERTIES + * \param[in] mipmap - CUDA mipmapped array to get the sparse properties of + * \sa ::cuArrayGetSparseProperties, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuMipmappedArrayGetSparseProperties(CUDA_ARRAY_SPARSE_PROPERTIES *sparseProperties, CUmipmappedArray mipmap); + + +/** + * \brief Returns the memory requirements of a CUDA array + * + * Returns the memory requirements of a CUDA array in \p memoryRequirements + * If the CUDA array is not allocated with flag ::CUDA_ARRAY3D_DEFERRED_MAPPING + * ::CUDA_ERROR_INVALID_VALUE will be returned. + * + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::size + * represents the total size of the CUDA array. + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::alignment + * represents the alignment necessary for mapping the CUDA array. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] memoryRequirements - Pointer to ::CUDA_ARRAY_MEMORY_REQUIREMENTS + * \param[in] array - CUDA array to get the memory requirements of + * \param[in] device - Device to get the memory requirements for + * \sa ::cuMipmappedArrayGetMemoryRequirements, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuArrayGetMemoryRequirements(CUDA_ARRAY_MEMORY_REQUIREMENTS *memoryRequirements, CUarray array, CUdevice device); + +/** + * \brief Returns the memory requirements of a CUDA mipmapped array + * + * Returns the memory requirements of a CUDA mipmapped array in \p memoryRequirements + * If the CUDA mipmapped array is not allocated with flag ::CUDA_ARRAY3D_DEFERRED_MAPPING + * ::CUDA_ERROR_INVALID_VALUE will be returned. 
+ * + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::size + * represents the total size of the CUDA mipmapped array. + * The returned value in ::CUDA_ARRAY_MEMORY_REQUIREMENTS::alignment + * represents the alignment necessary for mapping the CUDA mipmapped + * array. + * + * \return + * ::CUDA_SUCCESS + * ::CUDA_ERROR_INVALID_VALUE + * + * \param[out] memoryRequirements - Pointer to ::CUDA_ARRAY_MEMORY_REQUIREMENTS + * \param[in] mipmap - CUDA mipmapped array to get the memory requirements of + * \param[in] device - Device to get the memory requirements for + * \sa ::cuArrayGetMemoryRequirements, ::cuMemMapArrayAsync + */ +CUresult CUDAAPI cuMipmappedArrayGetMemoryRequirements(CUDA_ARRAY_MEMORY_REQUIREMENTS *memoryRequirements, CUmipmappedArray mipmap, CUdevice device); + + +/** + * \brief Gets a CUDA array plane from a CUDA array + * + * Returns in \p pPlaneArray a CUDA array that represents a single format plane + * of the CUDA array \p hArray. + * + * If \p planeIdx is greater than the maximum number of planes in this array or if the array does + * not have a multi-planar format e.g: ::CU_AD_FORMAT_NV12, then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Note that if the \p hArray has format ::CU_AD_FORMAT_NV12, then passing in 0 for \p planeIdx returns + * a CUDA array of the same size as \p hArray but with one channel and ::CU_AD_FORMAT_UNSIGNED_INT8 as its format. + * If 1 is passed for \p planeIdx, then the returned CUDA array has half the height and width + * of \p hArray with two channels and ::CU_AD_FORMAT_UNSIGNED_INT8 as its format. + * + * \param pPlaneArray - Returned CUDA array referenced by the \p planeIdx + * \param hArray - Multiplanar CUDA array + * \param planeIdx - Plane index + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::cuArrayCreate, + * ::cudaGetArrayPlane + */ +CUresult CUDAAPI cuArrayGetPlane(CUarray *pPlaneArray, CUarray hArray, unsigned int planeIdx); + +/** + * \brief Destroys a CUDA array + * + * Destroys the CUDA array \p hArray. + * + * \param hArray - Array to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ARRAY_IS_MAPPED, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaFreeArray + */ +CUresult CUDAAPI cuArrayDestroy(CUarray hArray); + +/** + * \brief Creates a 3D CUDA array + * + * Creates a CUDA array according to the ::CUDA_ARRAY3D_DESCRIPTOR structure + * \p pAllocateArray and returns a handle to the new CUDA array in \p *pHandle. 
+ * The ::CUDA_ARRAY3D_DESCRIPTOR is defined as: + * + * \code + typedef struct { + unsigned int Width; + unsigned int Height; + unsigned int Depth; + CUarray_format Format; + unsigned int NumChannels; + unsigned int Flags; + } CUDA_ARRAY3D_DESCRIPTOR; + * \endcode + * where: + * + * - \p Width, \p Height, and \p Depth are the width, height, and depth of the + * CUDA array (in elements); the following types of CUDA arrays can be allocated: + * - A 1D array is allocated if \p Height and \p Depth extents are both zero. + * - A 2D array is allocated if only \p Depth extent is zero. + * - A 3D array is allocated if all three extents are non-zero. + * - A 1D layered CUDA array is allocated if only \p Height is zero and the + * ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 1D array. The number + * of layers is determined by the depth extent. + * - A 2D layered CUDA array is allocated if all three extents are non-zero and + * the ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 2D array. The number + * of layers is determined by the depth extent. + * - A cubemap CUDA array is allocated if all three extents are non-zero and the + * ::CUDA_ARRAY3D_CUBEMAP flag is set. \p Width must be equal to \p Height, and + * \p Depth must be six. A cubemap is a special type of 2D layered CUDA array, + * where the six layers represent the six faces of a cube. The order of the six + * layers in memory is the same as that listed in ::CUarray_cubemap_face. + * - A cubemap layered CUDA array is allocated if all three extents are non-zero, + * and both, ::CUDA_ARRAY3D_CUBEMAP and ::CUDA_ARRAY3D_LAYERED flags are set. + * \p Width must be equal to \p Height, and \p Depth must be a multiple of six. + * A cubemap layered CUDA array is a special type of 2D layered CUDA array that + * consists of a collection of cubemaps. The first six layers represent the first + * cubemap, the next six layers form the second cubemap, and so on. + * + * - ::Format specifies the format of the elements; ::CUarray_format is + * defined as: + * \code + typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, + CU_AD_FORMAT_SIGNED_INT8 = 0x08, + CU_AD_FORMAT_SIGNED_INT16 = 0x09, + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, + CU_AD_FORMAT_HALF = 0x10, + CU_AD_FORMAT_FLOAT = 0x20 + } CUarray_format; + * \endcode + * + * - \p NumChannels specifies the number of packed components per CUDA array + * element; it may be 1, 2, or 4; + * + * - ::Flags may be set to + * - ::CUDA_ARRAY3D_LAYERED to enable creation of layered CUDA arrays. If this flag is set, + * \p Depth specifies the number of layers, not the depth of a 3D array. + * - ::CUDA_ARRAY3D_SURFACE_LDST to enable surface references to be bound to the CUDA array. + * If this flag is not set, ::cuSurfRefSetArray will fail when attempting to bind the CUDA array + * to a surface reference. + * - ::CUDA_ARRAY3D_CUBEMAP to enable creation of cubemaps. If this flag is set, \p Width must be + * equal to \p Height, and \p Depth must be six. If the ::CUDA_ARRAY3D_LAYERED flag is also set, + * then \p Depth must be a multiple of six. + * - ::CUDA_ARRAY3D_TEXTURE_GATHER to indicate that the CUDA array will be used for texture gather. + * Texture gather can only be performed on 2D CUDA arrays. + * + * \p Width, \p Height and \p Depth must meet certain size requirements as listed in the following table. + * All values are specified in elements. 
Note that for brevity's sake, the full name of the device attribute + * is not specified. For ex., TEXTURE1D_WIDTH refers to the device attribute + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_WIDTH. + * + * Note that 2D CUDA arrays have different size requirements if the ::CUDA_ARRAY3D_TEXTURE_GATHER flag + * is set. \p Width and \p Height must not be greater than ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_WIDTH + * and ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_GATHER_HEIGHT respectively, in that case. + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
+ * <table>
+ * <tr><td><b>CUDA array type</b></td>
+ * <td><b>Valid extents that must always be met
+ * {(width range in elements), (height range), (depth range)}</b></td>
+ * <td><b>Valid extents with CUDA_ARRAY3D_SURFACE_LDST set
+ * {(width range in elements), (height range), (depth range)}</b></td></tr>
+ * <tr><td>1D</td>
+ * <td>{ (1,TEXTURE1D_WIDTH), 0, 0 }</td>
+ * <td>{ (1,SURFACE1D_WIDTH), 0, 0 }</td></tr>
+ * <tr><td>2D</td>
+ * <td>{ (1,TEXTURE2D_WIDTH), (1,TEXTURE2D_HEIGHT), 0 }</td>
+ * <td>{ (1,SURFACE2D_WIDTH), (1,SURFACE2D_HEIGHT), 0 }</td></tr>
+ * <tr><td>3D</td>
+ * <td>{ (1,TEXTURE3D_WIDTH), (1,TEXTURE3D_HEIGHT), (1,TEXTURE3D_DEPTH) }
+ * OR
+ * { (1,TEXTURE3D_WIDTH_ALTERNATE), (1,TEXTURE3D_HEIGHT_ALTERNATE),
+ * (1,TEXTURE3D_DEPTH_ALTERNATE) }</td>
+ * <td>{ (1,SURFACE3D_WIDTH), (1,SURFACE3D_HEIGHT), (1,SURFACE3D_DEPTH) }</td></tr>
+ * <tr><td>1D Layered</td>
+ * <td>{ (1,TEXTURE1D_LAYERED_WIDTH), 0, (1,TEXTURE1D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE1D_LAYERED_WIDTH), 0, (1,SURFACE1D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>2D Layered</td>
+ * <td>{ (1,TEXTURE2D_LAYERED_WIDTH), (1,TEXTURE2D_LAYERED_HEIGHT), (1,TEXTURE2D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE2D_LAYERED_WIDTH), (1,SURFACE2D_LAYERED_HEIGHT), (1,SURFACE2D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>Cubemap</td>
+ * <td>{ (1,TEXTURECUBEMAP_WIDTH), (1,TEXTURECUBEMAP_WIDTH), 6 }</td>
+ * <td>{ (1,SURFACECUBEMAP_WIDTH), (1,SURFACECUBEMAP_WIDTH), 6 }</td></tr>
+ * <tr><td>Cubemap Layered</td>
+ * <td>{ (1,TEXTURECUBEMAP_LAYERED_WIDTH), (1,TEXTURECUBEMAP_LAYERED_WIDTH), (1,TEXTURECUBEMAP_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACECUBEMAP_LAYERED_WIDTH), (1,SURFACECUBEMAP_LAYERED_WIDTH), (1,SURFACECUBEMAP_LAYERED_LAYERS) }</td></tr>
+ * </table>
+ * + * Here are examples of CUDA array descriptions: + * + * Description for a CUDA array of 2048 floats: + * \code + CUDA_ARRAY3D_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 2048; + desc.Height = 0; + desc.Depth = 0; + * \endcode + * + * Description for a 64 x 64 CUDA array of floats: + * \code + CUDA_ARRAY3D_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_FLOAT; + desc.NumChannels = 1; + desc.Width = 64; + desc.Height = 64; + desc.Depth = 0; + * \endcode + * + * Description for a \p width x \p height x \p depth CUDA array of 64-bit, + * 4x16-bit float16's: + * \code + CUDA_ARRAY3D_DESCRIPTOR desc; + desc.Format = CU_AD_FORMAT_HALF; + desc.NumChannels = 4; + desc.Width = width; + desc.Height = height; + desc.Depth = depth; + * \endcode + * + * \param pHandle - Returned array + * \param pAllocateArray - 3D array descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa ::cuArray3DGetDescriptor, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaMalloc3DArray + */ +CUresult CUDAAPI cuArray3DCreate(CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pAllocateArray); + +/** + * \brief Get a 3D CUDA array descriptor + * + * Returns in \p *pArrayDescriptor a descriptor containing information on the + * format and dimensions of the CUDA array \p hArray. It is useful for + * subroutines that have been passed a CUDA array, but need to know the CUDA + * array parameters for validation or other purposes. + * + * This function may be called on 1D and 2D arrays, in which case the \p Height + * and/or \p Depth members of the descriptor struct will be set to 0. 
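+ *
+ * A minimal usage sketch (illustrative only; \p hArray is assumed to be a
+ * previously created CUDA array):
+ * \code
+    CUDA_ARRAY3D_DESCRIPTOR desc;
+    cuArray3DGetDescriptor(&desc, hArray);
+    // For 2D arrays desc.Depth is 0, and for 1D arrays desc.Height is 0 as well,
+    // so treat zero extents as 1 when computing the element count.
+    size_t elements = (size_t)desc.Width
+                    * (desc.Height ? desc.Height : 1)
+                    * (desc.Depth  ? desc.Depth  : 1);
+ * \endcode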
+ * + * \param pArrayDescriptor - Returned 3D array descriptor + * \param hArray - 3D array to get descriptor of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * \notefnerr + * + * \sa ::cuArray3DCreate, ::cuArrayCreate, + * ::cuArrayDestroy, ::cuArrayGetDescriptor, ::cuMemAlloc, ::cuMemAllocHost, + * ::cuMemAllocPitch, ::cuMemcpy2D, ::cuMemcpy2DAsync, ::cuMemcpy2DUnaligned, + * ::cuMemcpy3D, ::cuMemcpy3DAsync, ::cuMemcpyAtoA, ::cuMemcpyAtoD, + * ::cuMemcpyAtoH, ::cuMemcpyAtoHAsync, ::cuMemcpyDtoA, ::cuMemcpyDtoD, ::cuMemcpyDtoDAsync, + * ::cuMemcpyDtoH, ::cuMemcpyDtoHAsync, ::cuMemcpyHtoA, ::cuMemcpyHtoAAsync, + * ::cuMemcpyHtoD, ::cuMemcpyHtoDAsync, ::cuMemFree, ::cuMemFreeHost, + * ::cuMemGetAddressRange, ::cuMemGetInfo, ::cuMemHostAlloc, + * ::cuMemHostGetDevicePointer, ::cuMemsetD2D8, ::cuMemsetD2D16, + * ::cuMemsetD2D32, ::cuMemsetD8, ::cuMemsetD16, ::cuMemsetD32, + * ::cudaArrayGetInfo + */ +CUresult CUDAAPI cuArray3DGetDescriptor(CUDA_ARRAY3D_DESCRIPTOR *pArrayDescriptor, CUarray hArray); + +/** + * \brief Creates a CUDA mipmapped array + * + * Creates a CUDA mipmapped array according to the ::CUDA_ARRAY3D_DESCRIPTOR structure + * \p pMipmappedArrayDesc and returns a handle to the new CUDA mipmapped array in \p *pHandle. + * \p numMipmapLevels specifies the number of mipmap levels to be allocated. This value is + * clamped to the range [1, 1 + floor(log2(max(width, height, depth)))]. + * + * The ::CUDA_ARRAY3D_DESCRIPTOR is defined as: + * + * \code + typedef struct { + unsigned int Width; + unsigned int Height; + unsigned int Depth; + CUarray_format Format; + unsigned int NumChannels; + unsigned int Flags; + } CUDA_ARRAY3D_DESCRIPTOR; + * \endcode + * where: + * + * - \p Width, \p Height, and \p Depth are the width, height, and depth of the + * CUDA array (in elements); the following types of CUDA arrays can be allocated: + * - A 1D mipmapped array is allocated if \p Height and \p Depth extents are both zero. + * - A 2D mipmapped array is allocated if only \p Depth extent is zero. + * - A 3D mipmapped array is allocated if all three extents are non-zero. + * - A 1D layered CUDA mipmapped array is allocated if only \p Height is zero and the + * ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 1D array. The number + * of layers is determined by the depth extent. + * - A 2D layered CUDA mipmapped array is allocated if all three extents are non-zero and + * the ::CUDA_ARRAY3D_LAYERED flag is set. Each layer is a 2D array. The number + * of layers is determined by the depth extent. + * - A cubemap CUDA mipmapped array is allocated if all three extents are non-zero and the + * ::CUDA_ARRAY3D_CUBEMAP flag is set. \p Width must be equal to \p Height, and + * \p Depth must be six. A cubemap is a special type of 2D layered CUDA array, + * where the six layers represent the six faces of a cube. The order of the six + * layers in memory is the same as that listed in ::CUarray_cubemap_face. + * - A cubemap layered CUDA mipmapped array is allocated if all three extents are non-zero, + * and both, ::CUDA_ARRAY3D_CUBEMAP and ::CUDA_ARRAY3D_LAYERED flags are set. + * \p Width must be equal to \p Height, and \p Depth must be a multiple of six. + * A cubemap layered CUDA array is a special type of 2D layered CUDA array that + * consists of a collection of cubemaps. 
The first six layers represent the first + * cubemap, the next six layers form the second cubemap, and so on. + * + * - ::Format specifies the format of the elements; ::CUarray_format is + * defined as: + * \code + typedef enum CUarray_format_enum { + CU_AD_FORMAT_UNSIGNED_INT8 = 0x01, + CU_AD_FORMAT_UNSIGNED_INT16 = 0x02, + CU_AD_FORMAT_UNSIGNED_INT32 = 0x03, + CU_AD_FORMAT_SIGNED_INT8 = 0x08, + CU_AD_FORMAT_SIGNED_INT16 = 0x09, + CU_AD_FORMAT_SIGNED_INT32 = 0x0a, + CU_AD_FORMAT_HALF = 0x10, + CU_AD_FORMAT_FLOAT = 0x20 + } CUarray_format; + * \endcode + * + * - \p NumChannels specifies the number of packed components per CUDA array + * element; it may be 1, 2, or 4; + * + * - ::Flags may be set to + * - ::CUDA_ARRAY3D_LAYERED to enable creation of layered CUDA mipmapped arrays. If this flag is set, + * \p Depth specifies the number of layers, not the depth of a 3D array. + * - ::CUDA_ARRAY3D_SURFACE_LDST to enable surface references to be bound to individual mipmap levels of + * the CUDA mipmapped array. If this flag is not set, ::cuSurfRefSetArray will fail when attempting to + * bind a mipmap level of the CUDA mipmapped array to a surface reference. + * - ::CUDA_ARRAY3D_CUBEMAP to enable creation of mipmapped cubemaps. If this flag is set, \p Width must be + * equal to \p Height, and \p Depth must be six. If the ::CUDA_ARRAY3D_LAYERED flag is also set, + * then \p Depth must be a multiple of six. + * - ::CUDA_ARRAY3D_TEXTURE_GATHER to indicate that the CUDA mipmapped array will be used for texture gather. + * Texture gather can only be performed on 2D CUDA mipmapped arrays. + * + * \p Width, \p Height and \p Depth must meet certain size requirements as listed in the following table. + * All values are specified in elements. Note that for brevity's sake, the full name of the device attribute + * is not specified. For ex., TEXTURE1D_MIPMAPPED_WIDTH refers to the device attribute + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_MIPMAPPED_WIDTH. + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
+ * <table>
+ * <tr><td><b>CUDA array type</b></td>
+ * <td><b>Valid extents that must always be met
+ * {(width range in elements), (height range), (depth range)}</b></td>
+ * <td><b>Valid extents with CUDA_ARRAY3D_SURFACE_LDST set
+ * {(width range in elements), (height range), (depth range)}</b></td></tr>
+ * <tr><td>1D</td>
+ * <td>{ (1,TEXTURE1D_MIPMAPPED_WIDTH), 0, 0 }</td>
+ * <td>{ (1,SURFACE1D_WIDTH), 0, 0 }</td></tr>
+ * <tr><td>2D</td>
+ * <td>{ (1,TEXTURE2D_MIPMAPPED_WIDTH), (1,TEXTURE2D_MIPMAPPED_HEIGHT), 0 }</td>
+ * <td>{ (1,SURFACE2D_WIDTH), (1,SURFACE2D_HEIGHT), 0 }</td></tr>
+ * <tr><td>3D</td>
+ * <td>{ (1,TEXTURE3D_WIDTH), (1,TEXTURE3D_HEIGHT), (1,TEXTURE3D_DEPTH) }
+ * OR
+ * { (1,TEXTURE3D_WIDTH_ALTERNATE), (1,TEXTURE3D_HEIGHT_ALTERNATE),
+ * (1,TEXTURE3D_DEPTH_ALTERNATE) }</td>
+ * <td>{ (1,SURFACE3D_WIDTH), (1,SURFACE3D_HEIGHT), (1,SURFACE3D_DEPTH) }</td></tr>
+ * <tr><td>1D Layered</td>
+ * <td>{ (1,TEXTURE1D_LAYERED_WIDTH), 0, (1,TEXTURE1D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE1D_LAYERED_WIDTH), 0, (1,SURFACE1D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>2D Layered</td>
+ * <td>{ (1,TEXTURE2D_LAYERED_WIDTH), (1,TEXTURE2D_LAYERED_HEIGHT), (1,TEXTURE2D_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACE2D_LAYERED_WIDTH), (1,SURFACE2D_LAYERED_HEIGHT), (1,SURFACE2D_LAYERED_LAYERS) }</td></tr>
+ * <tr><td>Cubemap</td>
+ * <td>{ (1,TEXTURECUBEMAP_WIDTH), (1,TEXTURECUBEMAP_WIDTH), 6 }</td>
+ * <td>{ (1,SURFACECUBEMAP_WIDTH), (1,SURFACECUBEMAP_WIDTH), 6 }</td></tr>
+ * <tr><td>Cubemap Layered</td>
+ * <td>{ (1,TEXTURECUBEMAP_LAYERED_WIDTH), (1,TEXTURECUBEMAP_LAYERED_WIDTH), (1,TEXTURECUBEMAP_LAYERED_LAYERS) }</td>
+ * <td>{ (1,SURFACECUBEMAP_LAYERED_WIDTH), (1,SURFACECUBEMAP_LAYERED_WIDTH), (1,SURFACECUBEMAP_LAYERED_LAYERS) }</td></tr>
+ * </table>
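+ *
+ * Here is a sketch of creating a mipmapped 2D array of floats with a full mip
+ * chain and retrieving its base level (illustrative only; error checking is
+ * omitted and \p width and \p height are assumed to exist):
+ * \code
+    CUDA_ARRAY3D_DESCRIPTOR desc = {0};
+    desc.Format      = CU_AD_FORMAT_FLOAT;
+    desc.NumChannels = 1;
+    desc.Width       = width;
+    desc.Height      = height;
+    desc.Depth       = 0;
+
+    // 1 + floor(log2(max(width, height))) mip levels; larger requests are clamped.
+    unsigned int levels = 1;
+    for (size_t d = (width > height ? width : height); d > 1; d >>= 1) ++levels;
+
+    CUmipmappedArray mipmap;
+    cuMipmappedArrayCreate(&mipmap, &desc, levels);
+    CUarray level0;
+    cuMipmappedArrayGetLevel(&level0, mipmap, 0);
+    // ... use the individual levels ...
+    cuMipmappedArrayDestroy(mipmap);
+ * \endcode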
+ * + * + * \param pHandle - Returned mipmapped array + * \param pMipmappedArrayDesc - mipmapped array descriptor + * \param numMipmapLevels - Number of mipmap levels + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cuMipmappedArrayDestroy, + * ::cuMipmappedArrayGetLevel, + * ::cuArrayCreate, + * ::cudaMallocMipmappedArray + */ +CUresult CUDAAPI cuMipmappedArrayCreate(CUmipmappedArray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR *pMipmappedArrayDesc, unsigned int numMipmapLevels); + +/** + * \brief Gets a mipmap level of a CUDA mipmapped array + * + * Returns in \p *pLevelArray a CUDA array that represents a single mipmap level + * of the CUDA mipmapped array \p hMipmappedArray. + * + * If \p level is greater than the maximum number of levels in this mipmapped array, + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \param pLevelArray - Returned mipmap level CUDA array + * \param hMipmappedArray - CUDA mipmapped array + * \param level - Mipmap level + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::cuMipmappedArrayCreate, + * ::cuMipmappedArrayDestroy, + * ::cuArrayCreate, + * ::cudaGetMipmappedArrayLevel + */ +CUresult CUDAAPI cuMipmappedArrayGetLevel(CUarray *pLevelArray, CUmipmappedArray hMipmappedArray, unsigned int level); + +/** + * \brief Destroys a CUDA mipmapped array + * + * Destroys the CUDA mipmapped array \p hMipmappedArray. + * + * \param hMipmappedArray - Mipmapped array to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ARRAY_IS_MAPPED, + * ::CUDA_ERROR_CONTEXT_IS_DESTROYED + * \notefnerr + * + * \sa + * ::cuMipmappedArrayCreate, + * ::cuMipmappedArrayGetLevel, + * ::cuArrayCreate, + * ::cudaFreeMipmappedArray + */ +CUresult CUDAAPI cuMipmappedArrayDestroy(CUmipmappedArray hMipmappedArray); + +/** @} */ /* END CUDA_MEM */ + +/** + * \defgroup CUDA_VA Virtual Memory Management + * + * ___MANBRIEF___ virtual memory management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the virtual memory management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** +* \brief Allocate an address range reservation. +* +* Reserves a virtual address range based on the given parameters, giving +* the starting address of the range in \p ptr. This API requires a system that +* supports UVA. The size and address parameters must be a multiple of the +* host page size and the alignment must be a power of two or zero for default +* alignment. 
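+*
+* A minimal sketch of reserving and later releasing an address range
+* (illustrative only; the 2 MiB size is an assumption and must in practice be
+* a multiple of the host page size, typically also of the granularity reported
+* by ::cuMemGetAllocationGranularity):
+* \code
+    size_t size = 2 * 1024 * 1024;
+    CUdeviceptr va = 0;
+    cuMemAddressReserve(&va, size, 0 /* default alignment */, 0 /* any address */, 0);
+    // ... back [va, va + size) with physical memory via cuMemCreate and cuMemMap ...
+    cuMemAddressFree(va, size);
+* \endcode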
+* +* \param[out] ptr - Resulting pointer to start of virtual address range allocated +* \param[in] size - Size of the reserved virtual address range requested +* \param[in] alignment - Alignment of the reserved virtual address range requested +* \param[in] addr - Fixed starting address range requested +* \param[in] flags - Currently unused, must be zero +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_OUT_OF_MEMORY, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemAddressFree +*/ +CUresult CUDAAPI cuMemAddressReserve(CUdeviceptr *ptr, size_t size, size_t alignment, CUdeviceptr addr, unsigned long long flags); + +/** +* \brief Free an address range reservation. +* +* Frees a virtual address range reserved by cuMemAddressReserve. The size +* must match what was given to memAddressReserve and the ptr given must +* match what was returned from memAddressReserve. +* +* \param[in] ptr - Starting address of the virtual address range to free +* \param[in] size - Size of the virtual address region to free +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemAddressReserve +*/ +CUresult CUDAAPI cuMemAddressFree(CUdeviceptr ptr, size_t size); + +/** +* \brief Create a CUDA memory handle representing a memory allocation of a given size described by the given properties +* +* This creates a memory allocation on the target device specified through the +* \p prop strcuture. The created allocation will not have any device or host +* mappings. The generic memory \p handle for the allocation can be +* mapped to the address space of calling process via ::cuMemMap. This handle +* cannot be transmitted directly to other processes (see +* ::cuMemExportToShareableHandle). On Windows, the caller must also pass +* an LPSECURITYATTRIBUTE in \p prop to be associated with this handle which +* limits or allows access to this handle for a recepient process (see +* ::CUmemAllocationProp::win32HandleMetaData for more). The \p size of this +* allocation must be a multiple of the the value given via +* ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM +* flag. +* If ::CUmemAllocationProp::allocFlags::usage contains ::CU_MEM_CREATE_USAGE_TILE_POOL flag then +* the memory allocation is intended only to be used as backing tile pool for sparse CUDA arrays +* and sparse CUDA mipmapped arrays. +* (see ::cuMemMapArrayAsync). +* +* \param[out] handle - Value of handle returned. All operations on this allocation are to be performed using this handle. +* \param[in] size - Size of the allocation requested +* \param[in] prop - Properties of the allocation to create. +* \param[in] flags - flags for future use, must be zero now. +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_OUT_OF_MEMORY, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* +* \sa ::cuMemRelease, ::cuMemExportToShareableHandle, ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemCreate(CUmemGenericAllocationHandle *handle, size_t size, const CUmemAllocationProp *prop, unsigned long long flags); + +/** +* \brief Release a memory handle representing a memory allocation which was previously allocated through cuMemCreate. 
+* +* Frees the memory that was allocated on a device through cuMemCreate. +* +* The memory allocation will be freed when all outstanding mappings to the memory +* are unmapped and when all outstanding references to the handle (including it's +* shareable counterparts) are also released. The generic memory handle can be +* freed when there are still outstanding mappings made with this handle. Each +* time a recepient process imports a shareable handle, it needs to pair it with +* ::cuMemRelease for the handle to be freed. If \p handle is not a valid handle +* the behavior is undefined. +* +* \param[in] handle Value of handle which was returned previously by cuMemCreate. +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* +* \sa ::cuMemCreate +*/ +CUresult CUDAAPI cuMemRelease(CUmemGenericAllocationHandle handle); + +/** +* \brief Maps an allocation handle to a reserved virtual address range. +* +* Maps bytes of memory represented by \p handle starting from byte \p offset to +* \p size to address range [\p addr, \p addr + \p size]. This range must be an +* address reservation previously reserved with ::cuMemAddressReserve, and +* \p offset + \p size must be less than the size of the memory allocation. +* Both \p ptr, \p size, and \p offset must be a multiple of the value given via +* ::cuMemGetAllocationGranularity with the ::CU_MEM_ALLOC_GRANULARITY_MINIMUM flag. +* +* Please note calling ::cuMemMap does not make the address accessible, +* the caller needs to update accessibility of a contiguous mapped VA +* range by calling ::cuMemSetAccess. +* +* Once a recipient process obtains a shareable memory handle +* from ::cuMemImportFromShareableHandle, the process must +* use ::cuMemMap to map the memory into its address ranges before +* setting accessibility with ::cuMemSetAccess. +* +* ::cuMemMap can only create mappings on VA range reservations +* that are not currently mapped. +* +* \param[in] ptr - Address where memory will be mapped. +* \param[in] size - Size of the memory mapping. +* \param[in] offset - Offset into the memory represented by +* - \p handle from which to start mapping +* - Note: currently must be zero. +* \param[in] handle - Handle to a shareable memory +* \param[in] flags - flags for future use, must be zero now. +* \return +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_OUT_OF_MEMORY, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* +* \sa ::cuMemUnmap, ::cuMemSetAccess, ::cuMemCreate, ::cuMemAddressReserve, ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemMap(CUdeviceptr ptr, size_t size, size_t offset, CUmemGenericAllocationHandle handle, unsigned long long flags); + +/** + * \brief Maps or unmaps subregions of sparse CUDA arrays and sparse CUDA mipmapped arrays + * + * Performs map or unmap operations on subregions of sparse CUDA arrays and sparse CUDA mipmapped arrays. + * Each operation is specified by a ::CUarrayMapInfo entry in the \p mapInfoList array of size \p count. 
+ * The structure ::CUarrayMapInfo is defined as follow: + \code + typedef struct CUarrayMapInfo_st { + CUresourcetype resourceType; + union { + CUmipmappedArray mipmap; + CUarray array; + } resource; + + CUarraySparseSubresourceType subresourceType; + union { + struct { + unsigned int level; + unsigned int layer; + unsigned int offsetX; + unsigned int offsetY; + unsigned int offsetZ; + unsigned int extentWidth; + unsigned int extentHeight; + unsigned int extentDepth; + } sparseLevel; + struct { + unsigned int layer; + unsigned long long offset; + unsigned long long size; + } miptail; + } subresource; + + CUmemOperationType memOperationType; + + CUmemHandleType memHandleType; + union { + CUmemGenericAllocationHandle memHandle; + } memHandle; + + unsigned long long offset; + unsigned int deviceBitMask; + unsigned int flags; + unsigned int reserved[2]; + } CUarrayMapInfo; + \endcode + * + * where ::CUarrayMapInfo::resourceType specifies the type of resource to be operated on. + * If ::CUarrayMapInfo::resourceType is set to ::CUresourcetype::CU_RESOURCE_TYPE_ARRAY then + * ::CUarrayMapInfo::resource::array must be set to a valid sparse CUDA array handle. + * The CUDA array must be either a 2D, 2D layered or 3D CUDA array and must have been allocated using + * ::cuArrayCreate or ::cuArray3DCreate with the flag ::CUDA_ARRAY3D_SPARSE + + * or ::CUDA_ARRAY3D_DEFERRED_MAPPING. + + * For CUDA arrays obtained using ::cuMipmappedArrayGetLevel, ::CUDA_ERROR_INVALID_VALUE will be returned. + * If ::CUarrayMapInfo::resourceType is set to ::CUresourcetype::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY + * then ::CUarrayMapInfo::resource::mipmap must be set to a valid sparse CUDA mipmapped array handle. + * The CUDA mipmapped array must be either a 2D, 2D layered or 3D CUDA mipmapped array and must have been + * allocated using ::cuMipmappedArrayCreate with the flag ::CUDA_ARRAY3D_SPARSE + + * or ::CUDA_ARRAY3D_DEFERRED_MAPPING. + + * + * ::CUarrayMapInfo::subresourceType specifies the type of subresource within the resource. + * ::CUarraySparseSubresourceType_enum is defined as: + \code + typedef enum CUarraySparseSubresourceType_enum { + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL = 0, + CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL = 1 + } CUarraySparseSubresourceType; + \endcode + * + * where ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL indicates a + * sparse-miplevel which spans at least one tile in every dimension. The remaining miplevels which + * are too small to span at least one tile in any dimension constitute the mip tail region as indicated by + * ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL subresource type. + * + * If ::CUarrayMapInfo::subresourceType is set to ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL + * then ::CUarrayMapInfo::subresource::sparseLevel struct must contain valid array subregion offsets and extents. + * The ::CUarrayMapInfo::subresource::sparseLevel::offsetX, ::CUarrayMapInfo::subresource::sparseLevel::offsetY + * and ::CUarrayMapInfo::subresource::sparseLevel::offsetZ must specify valid X, Y and Z offsets respectively. + * The ::CUarrayMapInfo::subresource::sparseLevel::extentWidth, ::CUarrayMapInfo::subresource::sparseLevel::extentHeight + * and ::CUarrayMapInfo::subresource::sparseLevel::extentDepth must specify valid width, height and depth extents respectively. + * These offsets and extents must be aligned to the corresponding tile dimension. 
+ * For CUDA mipmapped arrays ::CUarrayMapInfo::subresource::sparseLevel::level must specify a valid mip level index. Otherwise, + * must be zero. + * For layered CUDA arrays and layered CUDA mipmapped arrays ::CUarrayMapInfo::subresource::sparseLevel::layer must specify a valid layer index. Otherwise, + * must be zero. + * ::CUarrayMapInfo::subresource::sparseLevel::offsetZ must be zero and ::CUarrayMapInfo::subresource::sparseLevel::extentDepth + * must be set to 1 for 2D and 2D layered CUDA arrays and CUDA mipmapped arrays. + * Tile extents can be obtained by calling ::cuArrayGetSparseProperties and ::cuMipmappedArrayGetSparseProperties + * + * If ::CUarrayMapInfo::subresourceType is set to ::CUarraySparseSubresourceType::CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_MIPTAIL + * then ::CUarrayMapInfo::subresource::miptail struct must contain valid mip tail offset in + * ::CUarrayMapInfo::subresource::miptail::offset and size in ::CUarrayMapInfo::subresource::miptail::size. + * Both, mip tail offset and mip tail size must be aligned to the tile size. + * For layered CUDA mipmapped arrays which don't have the flag ::CU_ARRAY_SPARSE_PROPERTIES_SINGLE_MIPTAIL set in ::CUDA_ARRAY_SPARSE_PROPERTIES::flags + * as returned by ::cuMipmappedArrayGetSparseProperties, ::CUarrayMapInfo::subresource::miptail::layer must specify a valid layer index. + * Otherwise, must be zero. + * + + * If ::CUarrayMapInfo::resource::array or ::CUarrayMapInfo::resource::mipmap was created with ::CUDA_ARRAY3D_DEFERRED_MAPPING + * flag set the ::CUarrayMapInfo::subresourceType and the contents of ::CUarrayMapInfo::subresource will be ignored. + * + + * ::CUarrayMapInfo::memOperationType specifies the type of operation. ::CUmemOperationType is defined as: + \code + typedef enum CUmemOperationType_enum { + CU_MEM_OPERATION_TYPE_MAP = 1, + CU_MEM_OPERATION_TYPE_UNMAP = 2 + } CUmemOperationType; + \endcode + * If ::CUarrayMapInfo::memOperationType is set to ::CUmemOperationType::CU_MEM_OPERATION_TYPE_MAP then the subresource + * will be mapped onto the tile pool memory specified by ::CUarrayMapInfo::memHandle at offset ::CUarrayMapInfo::offset. + * The tile pool allocation has to be created by specifying the ::CU_MEM_CREATE_USAGE_TILE_POOL flag when calling ::cuMemCreate. Also, + * ::CUarrayMapInfo::memHandleType must be set to ::CUmemHandleType::CU_MEM_HANDLE_TYPE_GENERIC. + * + * If ::CUarrayMapInfo::memOperationType is set to ::CUmemOperationType::CU_MEM_OPERATION_TYPE_UNMAP then an unmapping operation + * is performed. ::CUarrayMapInfo::memHandle must be NULL. + * + * ::CUarrayMapInfo::deviceBitMask specifies the list of devices that must map or unmap physical memory. + * Currently, this mask must have exactly one bit set, and the corresponding device must match the device associated with the stream. + * If ::CUarrayMapInfo::memOperationType is set to ::CUmemOperationType::CU_MEM_OPERATION_TYPE_MAP, the device must also match + * the device associated with the tile pool memory allocation as specified by ::CUarrayMapInfo::memHandle. + * + * ::CUarrayMapInfo::flags and ::CUarrayMapInfo::reserved[] are unused and must be set to zero. 
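+ *
+ * As an illustrative sketch (not a complete recipe), mapping the first tile of
+ * mip level 0 of a sparse CUDA array onto a tile pool allocation could look
+ * like the following; \p hArray, \p tilePool, \p tileWidth, \p tileHeight and
+ * \p stream are assumed to exist, with the tile extents obtained from
+ * ::cuArrayGetSparseProperties:
+ * \code
+    CUarrayMapInfo mapInfo = {0};
+    mapInfo.resourceType    = CU_RESOURCE_TYPE_ARRAY;
+    mapInfo.resource.array  = hArray;
+    mapInfo.subresourceType = CU_ARRAY_SPARSE_SUBRESOURCE_TYPE_SPARSE_LEVEL;
+    mapInfo.subresource.sparseLevel.level        = 0;
+    mapInfo.subresource.sparseLevel.offsetX      = 0;
+    mapInfo.subresource.sparseLevel.offsetY      = 0;
+    mapInfo.subresource.sparseLevel.extentWidth  = tileWidth;
+    mapInfo.subresource.sparseLevel.extentHeight = tileHeight;
+    mapInfo.subresource.sparseLevel.extentDepth  = 1;   // 2D array
+    mapInfo.memOperationType    = CU_MEM_OPERATION_TYPE_MAP;
+    mapInfo.memHandleType       = CU_MEM_HANDLE_TYPE_GENERIC;
+    mapInfo.memHandle.memHandle = tilePool;  // created with CU_MEM_CREATE_USAGE_TILE_POOL
+    mapInfo.offset        = 0;               // offset into the tile pool
+    mapInfo.deviceBitMask = 1;               // device 0, matching the stream's device
+    cuMemMapArrayAsync(&mapInfo, 1, stream);
+ * \endcode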
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * + * \param[in] mapInfoList - List of ::CUarrayMapInfo + * \param[in] count - Count of ::CUarrayMapInfo in \p mapInfoList + * \param[in] hStream - Stream identifier for the stream to use for map or unmap operations + * + * \sa ::cuMipmappedArrayCreate, ::cuArrayCreate, ::cuArray3DCreate, ::cuMemCreate, ::cuArrayGetSparseProperties, ::cuMipmappedArrayGetSparseProperties + */ +CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, unsigned int count, CUstream hStream); + +/** +* \brief Unmap the backing memory of a given address range. +* +* The range must be the entire contiguous address range that was mapped to. In +* other words, ::cuMemUnmap cannot unmap a sub-range of an address range mapped +* by ::cuMemCreate / ::cuMemMap. Any backing memory allocations will be freed +* if there are no existing mappings and there are no unreleased memory handles. +* +* When ::cuMemUnmap returns successfully the address range is converted to an +* address reservation and can be used for a future calls to ::cuMemMap. Any new +* mapping to this virtual address will need to have access granted through +* ::cuMemSetAccess, as all mappings start with no accessibility setup. +* +* \param[in] ptr - Starting address for the virtual address range to unmap +* \param[in] size - Size of the virtual address range to unmap +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* \note_sync +* +* \sa ::cuMemCreate, ::cuMemAddressReserve +*/ +CUresult CUDAAPI cuMemUnmap(CUdeviceptr ptr, size_t size); + +/** +* \brief Set the access flags for each location specified in \p desc for the given virtual address range +* +* Given the virtual address range via \p ptr and \p size, and the locations +* in the array given by \p desc and \p count, set the access flags for the +* target locations. The range must be a fully mapped address range +* containing all allocations created by ::cuMemMap / ::cuMemCreate. 
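+*
+* Putting the virtual memory management calls together, a typical sequence is
+* sketched below (illustrative only; error checking is omitted and device 0 is
+* assumed):
+* \code
+    CUmemAllocationProp prop = {0};
+    prop.type          = CU_MEM_ALLOCATION_TYPE_PINNED;
+    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+    prop.location.id   = 0;
+
+    size_t granularity;
+    cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM);
+    size_t size = granularity;               // request a single granule
+
+    CUmemGenericAllocationHandle handle;
+    CUdeviceptr va;
+    cuMemCreate(&handle, size, &prop, 0);
+    cuMemAddressReserve(&va, size, 0, 0, 0);
+    cuMemMap(va, size, 0, handle, 0);
+
+    CUmemAccessDesc access = {0};
+    access.location = prop.location;
+    access.flags    = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+    cuMemSetAccess(va, size, &access, 1);
+
+    // ... the memory at va is now accessible from device 0 ...
+
+    cuMemUnmap(va, size);
+    cuMemAddressFree(va, size);
+    cuMemRelease(handle);
+* \endcode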
+* +* \param[in] ptr - Starting address for the virtual address range +* \param[in] size - Length of the virtual address range +* \param[in] desc - Array of ::CUmemAccessDesc that describe how to change the +* - mapping for each location specified +* \param[in] count - Number of ::CUmemAccessDesc in \p desc +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_NOT_SUPPORTED +* \notefnerr +* \note_sync +* +* \sa ::cuMemSetAccess, ::cuMemCreate, :cuMemMap +*/ +CUresult CUDAAPI cuMemSetAccess(CUdeviceptr ptr, size_t size, const CUmemAccessDesc *desc, size_t count); + +/** +* \brief Get the access \p flags set for the given \p location and \p ptr +* +* \param[out] flags - Flags set for this location +* \param[in] location - Location in which to check the flags for +* \param[in] ptr - Address in which to check the access flags for +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_INVALID_DEVICE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemSetAccess +*/ +CUresult CUDAAPI cuMemGetAccess(unsigned long long *flags, const CUmemLocation *location, CUdeviceptr ptr); + +/** +* \brief Exports an allocation to a requested shareable handle type +* +* Given a CUDA memory handle, create a shareable memory +* allocation handle that can be used to share the memory with other +* processes. The recipient process can convert the shareable handle back into a +* CUDA memory handle using ::cuMemImportFromShareableHandle and map +* it with ::cuMemMap. The implementation of what this handle is and how it +* can be transferred is defined by the requested handle type in \p handleType +* +* Once all shareable handles are closed and the allocation is released, the allocated +* memory referenced will be released back to the OS and uses of the CUDA handle afterward +* will lead to undefined behavior. +* +* This API can also be used in conjunction with other APIs (e.g. Vulkan, OpenGL) +* that support importing memory from the shareable type +* +* \param[out] shareableHandle - Pointer to the location in which to store the requested handle type +* \param[in] handle - CUDA handle for the memory allocation +* \param[in] handleType - Type of shareable handle requested (defines type and size of the \p shareableHandle output parameter) +* \param[in] flags - Reserved, must be zero +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemExportToShareableHandle(void *shareableHandle, CUmemGenericAllocationHandle handle, CUmemAllocationHandleType handleType, unsigned long long flags); + +/** +* \brief Imports an allocation from a requested shareable handle type. +* +* If the current process cannot support the memory described by this shareable +* handle, this API will error as CUDA_ERROR_NOT_SUPPORTED. +* +* \note Importing shareable handles exported from some graphics APIs(VUlkan, OpenGL, etc) +* created on devices under an SLI group may not be supported, and thus this API will +* return CUDA_ERROR_NOT_SUPPORTED. +* There is no guarantee that the contents of \p handle will be the same CUDA memory handle +* for the same given OS shareable handle, or the same underlying allocation. +* +* \param[out] handle - CUDA Memory handle for the memory allocation. 
+* \param[in] osHandle - Shareable Handle representing the memory allocation that is to be imported. +* \param[in] shHandleType - handle type of the exported handle ::CUmemAllocationHandleType. +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemExportToShareableHandle, ::cuMemMap, ::cuMemRelease +*/ +CUresult CUDAAPI cuMemImportFromShareableHandle(CUmemGenericAllocationHandle *handle, void *osHandle, CUmemAllocationHandleType shHandleType); + +/** +* \brief Calculates either the minimal or recommended granularity +* +* Calculates either the minimal or recommended granularity +* for a given allocation specification and returns it in granularity. This +* granularity can be used as a multiple for alignment, size, or address mapping. +* +* \param[out] granularity Returned granularity. +* \param[in] prop Property for which to determine the granularity for +* \param[in] option Determines which granularity to return +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemCreate, ::cuMemMap +*/ +CUresult CUDAAPI cuMemGetAllocationGranularity(size_t *granularity, const CUmemAllocationProp *prop, CUmemAllocationGranularity_flags option); + +/** +* \brief Retrieve the contents of the property structure defining properties for this handle +* +* \param[out] prop - Pointer to a properties structure which will hold the information about this handle +* \param[in] handle - Handle which to perform the query on +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemCreate, ::cuMemImportFromShareableHandle +*/ +CUresult CUDAAPI cuMemGetAllocationPropertiesFromHandle(CUmemAllocationProp *prop, CUmemGenericAllocationHandle handle); + +/** +* \brief Given an address \p addr, returns the allocation handle of the backing memory allocation. +* +* The handle is guaranteed to be the same handle value used to map the memory. If the address +* requested is not mapped, the function will fail. The returned handle must be released with +* corresponding number of calls to ::cuMemRelease. +* +* \note The address \p addr, can be any address in a range previously mapped +* by ::cuMemMap, and not necessarily the start address. +* +* \param[out] handle CUDA Memory handle for the backing memory allocation. +* \param[in] addr Memory address to query, that has been mapped previously. +* \returns +* ::CUDA_SUCCESS, +* ::CUDA_ERROR_INVALID_VALUE, +* ::CUDA_ERROR_NOT_INITIALIZED, +* ::CUDA_ERROR_DEINITIALIZED, +* ::CUDA_ERROR_NOT_PERMITTED, +* ::CUDA_ERROR_NOT_SUPPORTED +* +* \sa ::cuMemCreate, ::cuMemRelease, ::cuMemMap +*/ +CUresult CUDAAPI cuMemRetainAllocationHandle(CUmemGenericAllocationHandle *handle, void *addr); + +/** @} */ /* END CUDA_VA */ + +/** + * \defgroup CUDA_MALLOC_ASYNC Stream Ordered Memory Allocator + * + * ___MANBRIEF___ Functions for performing allocation and free operations in stream order. + * Functions for controlling the behavior of the underlying allocator. + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream ordered memory allocator exposed by the + * low-level CUDA driver application programming interface. 
+ * + * @{ + * + * \section CUDA_MALLOC_ASYNC_overview overview + * + * The asynchronous allocator allows the user to allocate and free in stream order. + * All asynchronous accesses of the allocation must happen between + * the stream executions of the allocation and the free. If the memory is accessed + * outside of the promised stream order, a use before allocation / use after free error + * will cause undefined behavior. + * + * The allocator is free to reallocate the memory as long as it can guarantee + * that compliant memory accesses will not overlap temporally. + * The allocator may refer to internal stream ordering as well as inter-stream dependencies + * (such as CUDA events and null stream dependencies) when establishing the temporal guarantee. + * The allocator may also insert inter-stream dependencies to establish the temporal guarantee. + * + * \section CUDA_MALLOC_ASYNC_support Supported Platforms + * + * Whether or not a device supports the integrated stream ordered memory allocator + * may be queried by calling ::cuDeviceGetAttribute() with the device attribute + * ::CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED + */ + +/** + * \brief Frees memory with stream ordered semantics + * + * Inserts a free operation into \p hStream. + * The allocation must not be accessed after stream execution reaches the free. + * After this API returns, accessing the memory from any subsequent work launched on the GPU + * or querying its pointer attributes results in undefined behavior. + * + * \note During stream capture, this function results in the creation of a free node and + * must therefore be passed the address of a graph allocation. + * + * \param dptr - memory to free + * \param hStream - The stream establishing the stream ordering contract. + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current context), + * ::CUDA_ERROR_NOT_SUPPORTED + */ +CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); + +/** + * \brief Allocates memory with stream ordered semantics + * + * Inserts an allocation operation into \p hStream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the the allocation operation completes. + * The allocation comes from the memory pool current to the stream's device. + * + * \note The default memory pool of a device contains device memory from that device. + * \note Basic stream ordering allows future work submitted into the same stream to use the allocation. + * Stream query, stream synchronize, and CUDA events can be used to guarantee that the allocation + * operation completes before work submitted in a separate stream runs. + * \note During stream capture, this function results in the creation of an allocation node. In this case, + * the allocation is owned by the graph instead of the memory pool. The memory pool's properties + * are used to set the node's creation parameters. 
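+ *
+ * A minimal sketch of stream-ordered allocation and free (illustrative only;
+ * \p stream is assumed to exist and the size shown is arbitrary):
+ * \code
+    CUdeviceptr dptr;
+    size_t bytesize = 1 << 20;
+    cuMemAllocAsync(&dptr, bytesize, stream);
+    // ... enqueue work on `stream` that uses dptr ...
+    cuMemFreeAsync(dptr, stream);
+    cuStreamSynchronize(stream);  // after this the free is complete from the host's view as well
+ * \endcode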
+ * + * \param[out] dptr - Returned device pointer + * \param[in] bytesize - Number of bytes to allocate + * \param[in] hStream - The stream establishing the stream ordering contract and the memory pool to allocate from + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current context), + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemAllocFromPoolAsync, ::cuMemFreeAsync, ::cuDeviceSetMemPool, + * ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolSetAccess, ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemAllocAsync(CUdeviceptr *dptr, size_t bytesize, CUstream hStream); + +/** + * \brief Tries to release memory back to the OS + * + * Releases memory back to the OS until the pool contains fewer than minBytesToKeep + * reserved bytes, or there is no more memory that the allocator can safely release. + * The allocator cannot release OS allocations that back outstanding asynchronous allocations. + * The OS allocations may happen at different granularity from the user allocations. + * + * \note: Allocations that have not been freed count as outstanding. + * \note: Allocations that have been asynchronously freed but whose completion has + * not been observed on the host (eg. by a synchronize) can count as outstanding. + * + * \param[in] pool - The memory pool to trim + * \param[in] minBytesToKeep - If the pool has less than minBytesToKeep reserved, + * the TrimTo operation is a no-op. Otherwise the pool will be guaranteed to have + * at least minBytesToKeep bytes reserved after the operation. + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolTrimTo(CUmemoryPool pool, size_t minBytesToKeep); + +/** + * \brief Sets attributes of a memory pool + * + * Supported attributes are: + * - ::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + * - ::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to use memory asynchronously freed + * in another stream as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC: (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by ::cuMemFreeAsync (default enabled). + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: (value type = cuuint64_t) + * Reset the high watermark that tracks the amount of backing memory that was + * allocated for the memory pool. 
It is illegal to set this attribute to a non-zero value. + * - ::CU_MEMPOOL_ATTR_USED_MEM_HIGH: (value type = cuuint64_t) + * Reset the high watermark that tracks the amount of used memory that was + * allocated for the memory pool. + * + * \param[in] pool - The memory pool to modify + * \param[in] attr - The attribute to modify + * \param[in] value - Pointer to the value to assign + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolSetAttribute(CUmemoryPool pool, CUmemPool_attribute attr, void *value); + +/** + * \brief Gets attributes of a memory pool + * + * Supported attributes are: + * - ::CU_MEMPOOL_ATTR_RELEASE_THRESHOLD: (value type = cuuint64_t) + * Amount of reserved memory in bytes to hold onto before trying + * to release memory back to the OS. When more than the release + * threshold bytes of memory are held by the memory pool, the + * allocator will try to release memory back to the OS on the + * next call to stream, event or context synchronize. (default 0) + * - ::CU_MEMPOOL_ATTR_REUSE_FOLLOW_EVENT_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to use memory asynchronously freed + * in another stream as long as a stream ordering dependency + * of the allocating stream on the free action exists. + * Cuda events and null stream interactions can create the required + * stream ordered dependencies. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC: (value type = int) + * Allow reuse of already completed frees when there is no dependency + * between the free and allocation. (default enabled) + * - ::CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES: (value type = int) + * Allow ::cuMemAllocAsync to insert new stream dependencies + * in order to establish the stream ordering required to reuse + * a piece of memory released by ::cuMemFreeAsync (default enabled). + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT: (value type = cuuint64_t) + * Amount of backing memory currently allocated for the mempool + * - ::CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH: (value type = cuuint64_t) + * High watermark of backing memory allocated for the mempool since the + * last time it was reset. + * - ::CU_MEMPOOL_ATTR_USED_MEM_CURRENT: (value type = cuuint64_t) + * Amount of memory from the pool that is currently in use by the application. + * - ::CU_MEMPOOL_ATTR_USED_MEM_HIGH: (value type = cuuint64_t) + * High watermark of the amount of memory from the pool that was in use by the application. + * + * \param[in] pool - The memory pool to get attributes of + * \param[in] attr - The attribute to get + * \param[out] value - Retrieved value + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolGetAttribute(CUmemoryPool pool, CUmemPool_attribute attr, void *value); + +/** + * \brief Controls visibility of pools between devices + * + * \param[in] pool - The pool being modified + * \param[in] map - Array of access descriptors. Each descriptor instructs the access to enable for a single gpu. + * \param[in] count - Number of descriptors in the map array. 
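+ *
+ * For example, a pool's allocations might be made readable and writable from a peer
+ * device roughly as follows (\p pool and the peer device ordinal \p peerDev are assumed
+ * to already exist; error checking omitted): \code
+ CUmemAccessDesc accessDesc;
+ accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;   // grant access to a device
+ accessDesc.location.id   = peerDev;                       // the peer device ordinal
+ accessDesc.flags         = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+ cuMemPoolSetAccess(pool, &accessDesc, 1);
+ * \endcode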
+ * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolSetAccess(CUmemoryPool pool, const CUmemAccessDesc *map, size_t count); + +/** + * \brief Returns the accessibility of a pool from a device + * + * Returns the accessibility of the pool's memory from the specified location. + * + * \param[out] flags - the accessibility of the pool from the specified location + * \param[in] memPool - the pool being queried + * \param[in] location - the location accessing the pool + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolGetAccess(CUmemAccess_flags *flags, CUmemoryPool memPool, CUmemLocation *location); + +/** + * \brief Creates a memory pool + * + * Creates a CUDA memory pool and returns the handle in \p pool. The \p poolProps determines + * the properties of the pool such as the backing device and IPC capabilities. + * + * By default, the pool's memory will be accessible from the device it is allocated on. + * + * \note Specifying CU_MEM_HANDLE_TYPE_NONE creates a memory pool that will not support IPC. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY, + * ::CUDA_ERROR_NOT_SUPPORTED + * + * \sa ::cuDeviceSetMemPool, ::cuDeviceGetMemPool, ::cuDeviceGetDefaultMemPool, + * ::cuMemAllocFromPoolAsync, ::cuMemPoolExportToShareableHandle + */ +CUresult CUDAAPI cuMemPoolCreate(CUmemoryPool *pool, const CUmemPoolProps *poolProps); + +/** + * \brief Destroys the specified memory pool + * + * If any pointers obtained from this pool haven't been freed or + * the pool has free operations that haven't completed + * when ::cuMemPoolDestroy is invoked, the function will return immediately and the + * resources associated with the pool will be released automatically + * once there are no more outstanding allocations. + * + * Destroying the current mempool of a device sets the default mempool of + * that device as the current mempool for that device. + * + * \note A device's default memory pool cannot be destroyed. + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuMemFreeAsync, ::cuDeviceSetMemPool, ::cuDeviceGetMemPool, + * ::cuDeviceGetDefaultMemPool, ::cuMemPoolCreate + */ +CUresult CUDAAPI cuMemPoolDestroy(CUmemoryPool pool); + +/** + * \brief Allocates memory from a specified pool with stream ordered semantics. + * + * Inserts an allocation operation into \p hStream. + * A pointer to the allocated memory is returned immediately in *dptr. + * The allocation must not be accessed until the the allocation operation completes. + * The allocation comes from the specified memory pool. + * + * \note + * - The specified memory pool may be from a device different than that of the specified \p hStream. + * + * - Basic stream ordering allows future work submitted into the same stream to use the allocation. + * Stream query, stream synchronize, and CUDA events can be used to guarantee that the allocation + * operation completes before work submitted in a separate stream runs. + * + * \note During stream capture, this function results in the creation of an allocation node. In this case, + * the allocation is owned by the graph instead of the memory pool. 
The memory pool's properties + * are used to set the node's creation parameters. + * + * \param[out] dptr - Returned device pointer + * \param[in] bytesize - Number of bytes to allocate + * \param[in] pool - The pool to allocate from + * \param[in] hStream - The stream establishing the stream ordering semantic + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT (default stream specified with no current context), + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemAllocAsync, ::cuMemFreeAsync, ::cuDeviceGetDefaultMemPool, + * ::cuDeviceGetMemPool, ::cuMemPoolCreate, ::cuMemPoolSetAccess, + * ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, CUmemoryPool pool, CUstream hStream); + +/** + * \brief Exports a memory pool to the requested handle type. + * + * Given an IPC capable mempool, create an OS handle to share the pool with another process. + * A recipient process can convert the shareable handle into a mempool with ::cuMemPoolImportFromShareableHandle. + * Individual pointers can then be shared with the ::cuMemPoolExportPointer and ::cuMemPoolImportPointer APIs. + * The implementation of what the shareable handle is and how it can be transferred is defined by the requested + * handle type. + * + * \note: To create an IPC capable mempool, create a mempool with a CUmemAllocationHandleType other than CU_MEM_HANDLE_TYPE_NONE. + * + * \param[out] handle_out - Returned OS handle + * \param[in] pool - pool to export + * \param[in] handleType - the type of handle to create + * \param[in] flags - must be 0 + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolImportFromShareableHandle, ::cuMemPoolExportPointer, + * ::cuMemPoolImportPointer, ::cuMemAllocAsync, ::cuMemFreeAsync, + * ::cuDeviceGetDefaultMemPool, ::cuDeviceGetMemPool, ::cuMemPoolCreate, + * ::cuMemPoolSetAccess, ::cuMemPoolSetAttribute + */ +CUresult CUDAAPI cuMemPoolExportToShareableHandle(void *handle_out, CUmemoryPool pool, CUmemAllocationHandleType handleType, unsigned long long flags); + +/** + * \brief imports a memory pool from a shared handle. + * + * Specific allocations can be imported from the imported pool with cuMemPoolImportPointer. + * + * \note Imported memory pools do not support creating new allocations. + * As such imported memory pools may not be used in cuDeviceSetMemPool + * or ::cuMemAllocFromPoolAsync calls. + * + * \param[out] pool_out - Returned memory pool + * \param[in] handle - OS handle of the pool to open + * \param[in] handleType - The type of handle being imported + * \param[in] flags - must be 0 + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolExportPointer, ::cuMemPoolImportPointer + */ +CUresult CUDAAPI cuMemPoolImportFromShareableHandle( + CUmemoryPool *pool_out, + void *handle, + CUmemAllocationHandleType handleType, + unsigned long long flags); + +/** + * \brief Export data to share a memory pool allocation between processes. + * + * Constructs \p shareData_out for sharing a specific allocation from an already shared memory pool. + * The recipient process can import the allocation with the ::cuMemPoolImportPointer api. 
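+ *
+ * A rough sketch of the exporting side, assuming the pool was created with
+ * ::CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR and \p ptr was allocated from it
+ * (transport of the handle and export data to the peer process is left to the
+ * application; error checking omitted): \code
+ int fd;
+ cuMemPoolExportToShareableHandle(&fd, pool, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0);
+ CUmemPoolPtrExportData shareData;
+ cuMemPoolExportPointer(&shareData, ptr);
+ // send fd (e.g. over a UNIX domain socket) and the bytes of shareData to the importer
+ * \endcode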
+ * The data is not a handle and may be shared through any IPC mechanism. + * + * \param[out] shareData_out - Returned export data + * \param[in] ptr - pointer to memory being exported + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolImportFromShareableHandle, ::cuMemPoolImportPointer + */ +CUresult CUDAAPI cuMemPoolExportPointer(CUmemPoolPtrExportData *shareData_out, CUdeviceptr ptr); + +/** + * \brief Import a memory pool allocation from another process. + * + * Returns in \p ptr_out a pointer to the imported memory. + * The imported memory must not be accessed before the allocation operation completes + * in the exporting process. The imported memory must be freed from all importing processes before + * being freed in the exporting process. The pointer may be freed with cuMemFree + * or cuMemFreeAsync. If cuMemFreeAsync is used, the free must be completed + * on the importing process before the free operation on the exporting process. + * + * \note The cuMemFreeAsync api may be used in the exporting process before + * the cuMemFreeAsync operation completes in its stream as long as the + * cuMemFreeAsync in the exporting process specifies a stream with + * a stream dependency on the importing process's cuMemFreeAsync. + * + * \param[out] ptr_out - pointer to imported memory + * \param[in] pool - pool from which to import + * \param[in] shareData - data specifying the memory to import + * + * \returns + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_OUT_OF_MEMORY + * + * \sa ::cuMemPoolExportToShareableHandle, ::cuMemPoolImportFromShareableHandle, ::cuMemPoolExportPointer + */ +CUresult CUDAAPI cuMemPoolImportPointer(CUdeviceptr *ptr_out, CUmemoryPool pool, CUmemPoolPtrExportData *shareData); + +/** @} */ /* END CUDA_MALLOC_ASYNC */ + +/** + * \defgroup CUDA_UNIFIED Unified Addressing + * + * ___MANBRIEF___ unified addressing functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the unified addressing functions of the + * low-level CUDA driver application programming interface. + * + * @{ + * + * \section CUDA_UNIFIED_overview Overview + * + * CUDA devices can share a unified address space with the host. + * For these devices there is no distinction between a device + * pointer and a host pointer -- the same pointer value may be + * used to access memory from the host program and from a kernel + * running on the device (with exceptions enumerated below). + * + * \section CUDA_UNIFIED_support Supported Platforms + * + * Whether or not a device supports unified addressing may be + * queried by calling ::cuDeviceGetAttribute() with the device + * attribute ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING. + * + * Unified addressing is automatically enabled in 64-bit processes + * + * \section CUDA_UNIFIED_lookup Looking Up Information from Pointer Values + * + * It is possible to look up information about the memory which backs a + * pointer value. For instance, one may want to know if a pointer points + * to host or device memory. As another example, in the case of device + * memory, one may want to know on which CUDA device the memory + * resides. 
These properties may be queried using the function + * ::cuPointerGetAttribute() + * + * Since pointers are unique, it is not necessary to specify information + * about the pointers specified to the various copy functions in the + * CUDA API. The function ::cuMemcpy() may be used to perform a copy + * between two pointers, ignoring whether they point to host or device + * memory (making ::cuMemcpyHtoD(), ::cuMemcpyDtoD(), and ::cuMemcpyDtoH() + * unnecessary for devices supporting unified addressing). For + * multidimensional copies, the memory type ::CU_MEMORYTYPE_UNIFIED may be + * used to specify that the CUDA driver should infer the location of the + * pointer from its value. + * + * \section CUDA_UNIFIED_automaphost Automatic Mapping of Host Allocated Host Memory + * + * All host memory allocated in all contexts using ::cuMemAllocHost() and + * ::cuMemHostAlloc() is always directly accessible from all contexts on + * all devices that support unified addressing. This is the case regardless + * of whether or not the flags ::CU_MEMHOSTALLOC_PORTABLE and + * ::CU_MEMHOSTALLOC_DEVICEMAP are specified. + * + * The pointer value through which allocated host memory may be accessed + * in kernels on all devices that support unified addressing is the same + * as the pointer value through which that memory is accessed on the host, + * so it is not necessary to call ::cuMemHostGetDevicePointer() to get the device + * pointer for these allocations. + * + * Note that this is not the case for memory allocated using the flag + * ::CU_MEMHOSTALLOC_WRITECOMBINED, as discussed below. + * + * \section CUDA_UNIFIED_autopeerregister Automatic Registration of Peer Memory + * + * Upon enabling direct access from a context that supports unified addressing + * to another peer context that supports unified addressing using + * ::cuCtxEnablePeerAccess() all memory allocated in the peer context using + * ::cuMemAlloc() and ::cuMemAllocPitch() will immediately be accessible + * by the current context. The device pointer value through + * which any peer memory may be accessed in the current context + * is the same pointer value through which that memory may be + * accessed in the peer context. + * + * \section CUDA_UNIFIED_exceptions Exceptions, Disjoint Addressing + * + * Not all memory may be accessed on devices through the same pointer + * value through which they are accessed on the host. These exceptions + * are host memory registered using ::cuMemHostRegister() and host memory + * allocated using the flag ::CU_MEMHOSTALLOC_WRITECOMBINED. For these + * exceptions, there exists a distinct host and device address for the + * memory. The device address is guaranteed to not overlap any valid host + * pointer range and is guaranteed to have the same value across all + * contexts that support unified addressing. + * + * This device address may be queried using ::cuMemHostGetDevicePointer() + * when a context using unified addressing is current. Either the host + * or the unified device pointer value may be used to refer to this memory + * through ::cuMemcpy() and similar functions using the + * ::CU_MEMORYTYPE_UNIFIED memory type. + * + */ + +/** + * \brief Returns information about a pointer + * + * The supported attributes are: + * + * - ::CU_POINTER_ATTRIBUTE_CONTEXT: + * + * Returns in \p *data the ::CUcontext in which \p ptr was allocated or + * registered. + * The type of \p data must be ::CUcontext *. 
+ * + * If \p ptr was not allocated by, mapped by, or registered with + * a ::CUcontext which uses unified virtual addressing then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * - ::CU_POINTER_ATTRIBUTE_MEMORY_TYPE: + * + * Returns in \p *data the physical memory type of the memory that + * \p ptr addresses as a ::CUmemorytype enumerated value. + * The type of \p data must be unsigned int. + * + * If \p ptr addresses device memory then \p *data is set to + * ::CU_MEMORYTYPE_DEVICE. The particular ::CUdevice on which the + * memory resides is the ::CUdevice of the ::CUcontext returned by the + * ::CU_POINTER_ATTRIBUTE_CONTEXT attribute of \p ptr. + * + * If \p ptr addresses host memory then \p *data is set to + * ::CU_MEMORYTYPE_HOST. + * + * If \p ptr was not allocated by, mapped by, or registered with + * a ::CUcontext which uses unified virtual addressing then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * If the current ::CUcontext does not support unified virtual + * addressing then ::CUDA_ERROR_INVALID_CONTEXT is returned. + * + * - ::CU_POINTER_ATTRIBUTE_DEVICE_POINTER: + * + * Returns in \p *data the device pointer value through which + * \p ptr may be accessed by kernels running in the current + * ::CUcontext. + * The type of \p data must be CUdeviceptr *. + * + * If there exists no device pointer value through which + * kernels running in the current ::CUcontext may access + * \p ptr then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * If there is no current ::CUcontext then + * ::CUDA_ERROR_INVALID_CONTEXT is returned. + * + * Except in the exceptional disjoint addressing cases discussed + * below, the value returned in \p *data will equal the input + * value \p ptr. + * + * - ::CU_POINTER_ATTRIBUTE_HOST_POINTER: + * + * Returns in \p *data the host pointer value through which + * \p ptr may be accessed by by the host program. + * The type of \p data must be void **. + * If there exists no host pointer value through which + * the host program may directly access \p ptr then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Except in the exceptional disjoint addressing cases discussed + * below, the value returned in \p *data will equal the input + * value \p ptr. + * + * - ::CU_POINTER_ATTRIBUTE_P2P_TOKENS: + * + * Returns in \p *data two tokens for use with the nv-p2p.h Linux + * kernel interface. \p data must be a struct of type + * CUDA_POINTER_ATTRIBUTE_P2P_TOKENS. + * + * \p ptr must be a pointer to memory obtained from :cuMemAlloc(). + * Note that p2pToken and vaSpaceToken are only valid for the + * lifetime of the source allocation. A subsequent allocation at + * the same address may return completely different tokens. + * Querying this attribute has a side effect of setting the attribute + * ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS for the region of memory that + * \p ptr points to. + * + * - ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS: + * + * A boolean attribute which when set, ensures that synchronous memory operations + * initiated on the region of memory that \p ptr points to will always synchronize. + * See further documentation in the section titled "API synchronization behavior" + * to learn more about cases when synchronous memory operations can + * exhibit asynchronous behavior. + * + * - ::CU_POINTER_ATTRIBUTE_BUFFER_ID: + * + * Returns in \p *data a buffer ID which is guaranteed to be unique within the process. + * \p data must point to an unsigned long long. + * + * \p ptr must be a pointer to memory obtained from a CUDA memory allocation API. 
+ * Every memory allocation from any of the CUDA memory allocation APIs will + * have a unique ID over a process lifetime. Subsequent allocations do not reuse IDs + * from previous freed allocations. IDs are only unique within a single process. + * + * + * - ::CU_POINTER_ATTRIBUTE_IS_MANAGED: + * + * Returns in \p *data a boolean that indicates whether the pointer points to + * managed memory or not. + * + * If \p ptr is not a valid CUDA pointer then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * - ::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL: + * + * Returns in \p *data an integer representing a device ordinal of a device against + * which the memory was allocated or registered. + * + * - ::CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE: + * + * Returns in \p *data a boolean that indicates if this pointer maps to + * an allocation that is suitable for ::cudaIpcGetMemHandle. + * + * - ::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR: + * + * Returns in \p *data the starting address for the allocation referenced + * by the device pointer \p ptr. Note that this is not necessarily the + * address of the mapped region, but the address of the mappable address + * range \p ptr references (e.g. from ::cuMemAddressReserve). + * + * - ::CU_POINTER_ATTRIBUTE_RANGE_SIZE: + * + * Returns in \p *data the size for the allocation referenced by the device + * pointer \p ptr. Note that this is not necessarily the size of the mapped + * region, but the size of the mappable address range \p ptr references + * (e.g. from ::cuMemAddressReserve). To retrieve the size of the mapped + * region, see ::cuMemGetAddressRange + * + * - ::CU_POINTER_ATTRIBUTE_MAPPED: + * + * Returns in \p *data a boolean that indicates if this pointer is in a + * valid address range that is mapped to a backing allocation. + * + * - ::CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES: + * + * Returns a bitmask of the allowed handle types for an allocation that may + * be passed to ::cuMemExportToShareableHandle. + * + * - ::CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE: + * + * Returns in \p *data the handle to the mempool that the allocation was obtained from. + * + * \par + * + * Note that for most allocations in the unified virtual address space + * the host and device pointer for accessing the allocation will be the + * same. The exceptions to this are + * - user memory registered using ::cuMemHostRegister + * - host memory allocated using ::cuMemHostAlloc with the + * ::CU_MEMHOSTALLOC_WRITECOMBINED flag + * For these types of allocation there will exist separate, disjoint host + * and device addresses for accessing the allocation. In particular + * - The host address will correspond to an invalid unmapped device address + * (which will result in an exception if accessed from the device) + * - The device address will correspond to an invalid unmapped host address + * (which will result in an exception if accessed from the host). + * For these types of allocations, querying ::CU_POINTER_ATTRIBUTE_HOST_POINTER + * and ::CU_POINTER_ATTRIBUTE_DEVICE_POINTER may be used to retrieve the host + * and device addresses from either address. 
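+ *
+ * For instance, whether a pointer refers to host or device memory might be determined
+ * with a query such as the following (\p ptr is assumed to come from a CUDA allocation
+ * API; error checking omitted): \code
+ unsigned int memType = 0;
+ cuPointerGetAttribute(&memType, CU_POINTER_ATTRIBUTE_MEMORY_TYPE, ptr);
+ if (memType == CU_MEMORYTYPE_DEVICE) {
+     // ptr addresses device memory
+ }
+ * \endcode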
+ * + * \param data - Returned pointer attribute value + * \param attribute - Pointer attribute to query + * \param ptr - Pointer + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuPointerSetAttribute, + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuMemAllocHost, + * ::cuMemFreeHost, + * ::cuMemHostAlloc, + * ::cuMemHostRegister, + * ::cuMemHostUnregister, + * ::cudaPointerGetAttributes + */ +CUresult CUDAAPI cuPointerGetAttribute(void *data, CUpointer_attribute attribute, CUdeviceptr ptr); + +/** + * \brief Prefetches memory to the specified destination device + * + * Prefetches memory to the specified destination device. \p devPtr is the + * base device pointer of the memory to be prefetched and \p dstDevice is the + * destination device. \p count specifies the number of bytes to copy. \p hStream + * is the stream in which the operation is enqueued. The memory range must refer + * to managed memory allocated via ::cuMemAllocManaged or declared via __managed__ variables. + * + * Passing in CU_DEVICE_CPU for \p dstDevice will prefetch the data to host memory. If + * \p dstDevice is a GPU, then the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS + * must be non-zero. Additionally, \p hStream must be associated with a device that has a + * non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. + * + * The start address and end address of the memory range will be rounded down and rounded up + * respectively to be aligned to CPU page size before the prefetch operation is enqueued + * in the stream. + * + * If no physical memory has been allocated for this region, then this memory region + * will be populated and mapped on the destination device. If there's insufficient + * memory to prefetch the desired region, the Unified Memory driver may evict pages from other + * ::cuMemAllocManaged allocations to host memory in order to make room. Device memory + * allocated using ::cuMemAlloc or ::cuArrayCreate will not be evicted. + * + * By default, any mappings to the previous location of the migrated pages are removed and + * mappings for the new location are only setup on \p dstDevice. The exact behavior however + * also depends on the settings applied to this memory range via ::cuMemAdvise as described + * below: + * + * If ::CU_MEM_ADVISE_SET_READ_MOSTLY was set on any subset of this memory range, + * then that subset will create a read-only copy of the pages on \p dstDevice. + * + * If ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION was called on any subset of this memory + * range, then the pages will be migrated to \p dstDevice even if \p dstDevice is not the + * preferred location of any pages in the memory range. + * + * If ::CU_MEM_ADVISE_SET_ACCESSED_BY was called on any subset of this memory range, + * then mappings to those pages from all the appropriate processors are updated to + * refer to the new location if establishing such a mapping is possible. Otherwise, + * those mappings are cleared. + * + * Note that this API is not required for functionality and only serves to improve performance + * by allowing the application to migrate data to a suitable location before it is accessed. + * Memory accesses to this range are always coherent and are allowed even when the data is + * actively being migrated. 
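+ *
+ * A minimal usage sketch (assuming a ::CUdevice \p dev, a stream \p hStream and a size
+ * \p bytes; error checking omitted): \code
+ CUdeviceptr managed;
+ cuMemAllocManaged(&managed, bytes, CU_MEM_ATTACH_GLOBAL);
+ cuMemPrefetchAsync(managed, bytes, dev, hStream);           // migrate to the GPU before use
+ // ... launch kernels in hStream that access the range ...
+ cuMemPrefetchAsync(managed, bytes, CU_DEVICE_CPU, hStream); // bring it back for host access
+ cuStreamSynchronize(hStream);
+ * \endcode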
+ * + * Note that this function is asynchronous with respect to the host and all work + * on other devices. + * + * \param devPtr - Pointer to be prefetched + * \param count - Size in bytes + * \param dstDevice - Destination device to prefetch to + * \param hStream - Stream to enqueue prefetch operation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpy, ::cuMemcpyPeer, ::cuMemcpyAsync, + * ::cuMemcpy3DPeerAsync, ::cuMemAdvise, + * ::cudaMemPrefetchAsync + */ +CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, CUdevice dstDevice, CUstream hStream); + +/** + * \brief Advise about the usage of a given memory range + * + * Advise the Unified Memory subsystem about the usage pattern for the memory range + * starting at \p devPtr with a size of \p count bytes. The start address and end address of the memory + * range will be rounded down and rounded up respectively to be aligned to CPU page size before the + * advice is applied. The memory range must refer to managed memory allocated via ::cuMemAllocManaged + * or declared via __managed__ variables. The memory range could also refer to system-allocated pageable + * memory provided it represents a valid, host-accessible region of memory and all additional constraints + * imposed by \p advice as outlined below are also satisfied. Specifying an invalid system-allocated pageable + * memory range results in an error being returned. + * + * The \p advice parameter can take the following values: + * - ::CU_MEM_ADVISE_SET_READ_MOSTLY: This implies that the data is mostly going to be read + * from and only occasionally written to. Any read accesses from any processor to this region will create a + * read-only copy of at least the accessed pages in that processor's memory. Additionally, if ::cuMemPrefetchAsync + * is called on this region, it will create a read-only copy of the data on the destination processor. + * If any processor writes to this region, all copies of the corresponding page will be invalidated + * except for the one where the write occurred. The \p device argument is ignored for this advice. + * Note that for a page to be read-duplicated, the accessing processor must either be the CPU or a GPU + * that has a non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. + * Also, if a context is created on a device that does not have the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS set, then read-duplication will not occur until + * all such contexts are destroyed. + * If the memory region refers to valid system-allocated pageable memory, then the accessing device must + * have a non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS for a read-only + * copy to be created on that device. Note however that if the accessing device also has a non-zero value for the + * device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, then setting this advice + * will not create a read-only copy when that device accesses this memory region. + * + * - ::CU_MEM_ADVISE_UNSET_READ_MOSTLY: Undoes the effect of ::CU_MEM_ADVISE_SET_READ_MOSTLY and also prevents the + * Unified Memory driver from attempting heuristic read-duplication on the memory range. Any read-duplicated + * copies of the data will be collapsed into a single copy. 
The location for the collapsed + * copy will be the preferred location if the page has a preferred location and one of the read-duplicated + * copies was resident at that location. Otherwise, the location chosen is arbitrary. + * + * - ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION: This advice sets the preferred location for the + * data to be the memory belonging to \p device. Passing in CU_DEVICE_CPU for \p device sets the + * preferred location as host memory. If \p device is a GPU, then it must have a non-zero value for the + * device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. Setting the preferred location + * does not cause data to migrate to that location immediately. Instead, it guides the migration policy + * when a fault occurs on that memory region. If the data is already in its preferred location and the + * faulting processor can establish a mapping without requiring the data to be migrated, then + * data migration will be avoided. On the other hand, if the data is not in its preferred location + * or if a direct mapping cannot be established, then it will be migrated to the processor accessing + * it. It is important to note that setting the preferred location does not prevent data prefetching + * done using ::cuMemPrefetchAsync. + * Having a preferred location can override the page thrash detection and resolution logic in the Unified + * Memory driver. Normally, if a page is detected to be constantly thrashing between for example host and device + * memory, the page may eventually be pinned to host memory by the Unified Memory driver. But + * if the preferred location is set as device memory, then the page will continue to thrash indefinitely. + * If ::CU_MEM_ADVISE_SET_READ_MOSTLY is also set on this memory region or any subset of it, then the + * policies associated with that advice will override the policies of this advice, unless read accesses from + * \p device will not result in a read-only copy being created on that device as outlined in description for + * the advice ::CU_MEM_ADVISE_SET_READ_MOSTLY. + * If the memory region refers to valid system-allocated pageable memory, then \p device must have a non-zero + * value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. Additionally, if \p device has + * a non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, + * then this call has no effect. Note however that this behavior may change in the future. + * + * - ::CU_MEM_ADVISE_UNSET_PREFERRED_LOCATION: Undoes the effect of ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION + * and changes the preferred location to none. + * + * - ::CU_MEM_ADVISE_SET_ACCESSED_BY: This advice implies that the data will be accessed by \p device. + * Passing in ::CU_DEVICE_CPU for \p device will set the advice for the CPU. If \p device is a GPU, then + * the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS must be non-zero. + * This advice does not cause data migration and has no impact on the location of the data per se. Instead, + * it causes the data to always be mapped in the specified processor's page tables, as long as the + * location of the data permits a mapping to be established. If the data gets migrated for any reason, + * the mappings are updated accordingly. + * This advice is recommended in scenarios where data locality is not important, but avoiding faults is. 
+ * Consider for example a system containing multiple GPUs with peer-to-peer access enabled, where the + * data located on one GPU is occasionally accessed by peer GPUs. In such scenarios, migrating data + * over to the other GPUs is not as important because the accesses are infrequent and the overhead of + * migration may be too high. But preventing faults can still help improve performance, and so having + * a mapping set up in advance is useful. Note that on CPU access of this data, the data may be migrated + * to host memory because the CPU typically cannot access device memory directly. Any GPU that had the + * ::CU_MEM_ADVISE_SET_ACCESSED_BY flag set for this data will now have its mapping updated to point to the + * page in host memory. + * If ::CU_MEM_ADVISE_SET_READ_MOSTLY is also set on this memory region or any subset of it, then the + * policies associated with that advice will override the policies of this advice. Additionally, if the + * preferred location of this memory region or any subset of it is also \p device, then the policies + * associated with ::CU_MEM_ADVISE_SET_PREFERRED_LOCATION will override the policies of this advice. + * If the memory region refers to valid system-allocated pageable memory, then \p device must have a non-zero + * value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. Additionally, if \p device has + * a non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, + * then this call has no effect. + * + * - ::CU_MEM_ADVISE_UNSET_ACCESSED_BY: Undoes the effect of ::CU_MEM_ADVISE_SET_ACCESSED_BY. Any mappings to + * the data from \p device may be removed at any time causing accesses to result in non-fatal page faults. + * If the memory region refers to valid system-allocated pageable memory, then \p device must have a non-zero + * value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. Additionally, if \p device has + * a non-zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS_USES_HOST_PAGE_TABLES, + * then this call has no effect. + * + * \param devPtr - Pointer to memory to set the advice for + * \param count - Size in bytes of the memory range + * \param advice - Advice to be applied for the specified memory range + * \param device - Device to apply the advice for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * \note_async + * \note_null_stream + * + * \sa ::cuMemcpy, ::cuMemcpyPeer, ::cuMemcpyAsync, + * ::cuMemcpy3DPeerAsync, ::cuMemPrefetchAsync, + * ::cudaMemAdvise + */ +CUresult CUDAAPI cuMemAdvise(CUdeviceptr devPtr, size_t count, CUmem_advise advice, CUdevice device); + +/** + * \brief Query an attribute of a given memory range + * + * Query an attribute about the memory range starting at \p devPtr with a size of \p count bytes. The + * memory range must refer to managed memory allocated via ::cuMemAllocManaged or declared via + * __managed__ variables. + * + * The \p attribute parameter can take the following values: + * - ::CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY: If this attribute is specified, \p data will be interpreted + * as a 32-bit integer, and \p dataSize must be 4. The result returned will be 1 if all pages in the given + * memory range have read-duplication enabled, or 0 otherwise. + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION: If this attribute is specified, \p data will be + * interpreted as a 32-bit integer, and \p dataSize must be 4. 
The result returned will be a GPU device
+ * id if all pages in the memory range have that GPU as their preferred location, or it will be CU_DEVICE_CPU
+ * if all pages in the memory range have the CPU as their preferred location, or it will be CU_DEVICE_INVALID
+ * if either all the pages don't have the same preferred location or some of the pages don't have a
+ * preferred location at all. Note that the actual location of the pages in the memory range at the time of
+ * the query may be different from the preferred location.
+ * - ::CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY: If this attribute is specified, \p data will be interpreted
+ * as an array of 32-bit integers, and \p dataSize must be a non-zero multiple of 4. The result returned
+ * will be a list of device ids that had ::CU_MEM_ADVISE_SET_ACCESSED_BY set for that entire memory range.
+ * If any device does not have that advice set for the entire memory range, that device will not be included.
+ * If \p data is larger than the number of devices that have that advice set for that memory range,
+ * CU_DEVICE_INVALID will be returned in all the extra space provided. For example, if \p dataSize is 12
+ * (i.e. \p data has 3 elements) and only device 0 has the advice set, then the result returned will be
+ * { 0, CU_DEVICE_INVALID, CU_DEVICE_INVALID }. If \p data is smaller than the number of devices that have
+ * that advice set, then only as many devices will be returned as can fit in the array. There is no
+ * guarantee on which specific devices will be returned, however.
+ * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION: If this attribute is specified, \p data will be
+ * interpreted as a 32-bit integer, and \p dataSize must be 4. The result returned will be the last location
+ * to which all pages in the memory range were prefetched explicitly via ::cuMemPrefetchAsync. This will either be
+ * a GPU id or CU_DEVICE_CPU depending on whether the last location for prefetch was a GPU or the CPU
+ * respectively. If any page in the memory range was never explicitly prefetched or if all pages were not
+ * prefetched to the same location, CU_DEVICE_INVALID will be returned. Note that this simply returns the
+ * last location that the application requested to prefetch the memory range to. It gives no indication as to
+ * whether the prefetch operation to that location has completed or even begun.
+ *
+ * \param data - A pointer to a memory location where the result
+ * of the attribute query will be written to.
+ * \param dataSize - The size of \p data, in bytes
+ * \param attribute - The attribute to query
+ * \param devPtr - Start of the range to query
+ * \param count - Size of the range to query
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_INVALID_DEVICE
+ * \notefnerr
+ * \note_async
+ * \note_null_stream
+ *
+ * \sa ::cuMemRangeGetAttributes, ::cuMemPrefetchAsync,
+ * ::cuMemAdvise,
+ * ::cudaMemRangeGetAttribute
+ */
+CUresult CUDAAPI cuMemRangeGetAttribute(void *data, size_t dataSize, CUmem_range_attribute attribute, CUdeviceptr devPtr, size_t count);
+
+/**
+ * \brief Query attributes of a given memory range.
+ *
+ * Query attributes of the memory range starting at \p devPtr with a size of \p count bytes. The
+ * memory range must refer to managed memory allocated via ::cuMemAllocManaged or declared via
+ * __managed__ variables. The \p attributes array will be interpreted to have \p numAttributes
+ * entries. The \p dataSizes array will also be interpreted to have \p numAttributes entries.
+ * The results of the query will be stored in \p data. + * + * The list of supported attributes are given below. Please refer to ::cuMemRangeGetAttribute for + * attribute descriptions and restrictions. + * + * - ::CU_MEM_RANGE_ATTRIBUTE_READ_MOSTLY + * - ::CU_MEM_RANGE_ATTRIBUTE_PREFERRED_LOCATION + * - ::CU_MEM_RANGE_ATTRIBUTE_ACCESSED_BY + * - ::CU_MEM_RANGE_ATTRIBUTE_LAST_PREFETCH_LOCATION + * + * \param data - A two-dimensional array containing pointers to memory + * locations where the result of each attribute query will be written to. + * \param dataSizes - Array containing the sizes of each result + * \param attributes - An array of attributes to query + * (numAttributes and the number of attributes in this array should match) + * \param numAttributes - Number of attributes to query + * \param devPtr - Start of the range to query + * \param count - Size of the range to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa ::cuMemRangeGetAttribute, ::cuMemAdvise, + * ::cuMemPrefetchAsync, + * ::cudaMemRangeGetAttributes + */ +CUresult CUDAAPI cuMemRangeGetAttributes(void **data, size_t *dataSizes, CUmem_range_attribute *attributes, size_t numAttributes, CUdeviceptr devPtr, size_t count); + +/** + * \brief Set attributes on a previously allocated memory region + * + * The supported attributes are: + * + * - ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS: + * + * A boolean attribute that can either be set (1) or unset (0). When set, + * the region of memory that \p ptr points to is guaranteed to always synchronize + * memory operations that are synchronous. If there are some previously initiated + * synchronous memory operations that are pending when this attribute is set, the + * function does not return until those memory operations are complete. + * See further documentation in the section titled "API synchronization behavior" + * to learn more about cases when synchronous memory operations can + * exhibit asynchronous behavior. + * \p value will be considered as a pointer to an unsigned integer to which this attribute is to be set. + * + * \param value - Pointer to memory containing the value to be set + * \param attribute - Pointer attribute to set + * \param ptr - Pointer to a memory region allocated using CUDA memory allocation APIs + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa ::cuPointerGetAttribute, + * ::cuPointerGetAttributes, + * ::cuMemAlloc, + * ::cuMemFree, + * ::cuMemAllocHost, + * ::cuMemFreeHost, + * ::cuMemHostAlloc, + * ::cuMemHostRegister, + * ::cuMemHostUnregister + */ +CUresult CUDAAPI cuPointerSetAttribute(const void *value, CUpointer_attribute attribute, CUdeviceptr ptr); + +/** + * \brief Returns information about a pointer. 
+ * + * The supported attributes are (refer to ::cuPointerGetAttribute for attribute descriptions and restrictions): + * + * - ::CU_POINTER_ATTRIBUTE_CONTEXT + * - ::CU_POINTER_ATTRIBUTE_MEMORY_TYPE + * - ::CU_POINTER_ATTRIBUTE_DEVICE_POINTER + * - ::CU_POINTER_ATTRIBUTE_HOST_POINTER + * - ::CU_POINTER_ATTRIBUTE_SYNC_MEMOPS + * - ::CU_POINTER_ATTRIBUTE_BUFFER_ID + * - ::CU_POINTER_ATTRIBUTE_IS_MANAGED + * - ::CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL + * - ::CU_POINTER_ATTRIBUTE_RANGE_START_ADDR + * - ::CU_POINTER_ATTRIBUTE_RANGE_SIZE + * - ::CU_POINTER_ATTRIBUTE_MAPPED + * - ::CU_POINTER_ATTRIBUTE_IS_LEGACY_CUDA_IPC_CAPABLE + * - ::CU_POINTER_ATTRIBUTE_ALLOWED_HANDLE_TYPES + * - ::CU_POINTER_ATTRIBUTE_MEMPOOL_HANDLE + * + * \param numAttributes - Number of attributes to query + * \param attributes - An array of attributes to query + * (numAttributes and the number of attributes in this array should match) + * \param data - A two-dimensional array containing pointers to memory + * locations where the result of each attribute query will be written to. + * \param ptr - Pointer to query + * + * Unlike ::cuPointerGetAttribute, this function will not return an error when the \p ptr + * encountered is not a valid CUDA pointer. Instead, the attributes are assigned default NULL values + * and CUDA_SUCCESS is returned. + * + * If \p ptr was not allocated by, mapped by, or registered with a ::CUcontext which uses UVA + * (Unified Virtual Addressing), ::CUDA_ERROR_INVALID_CONTEXT is returned. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuPointerGetAttribute, + * ::cuPointerSetAttribute, + * ::cudaPointerGetAttributes + */ +CUresult CUDAAPI cuPointerGetAttributes(unsigned int numAttributes, CUpointer_attribute *attributes, void **data, CUdeviceptr ptr); + +/** @} */ /* END CUDA_UNIFIED */ + +/** + * \defgroup CUDA_STREAM Stream Management + * + * ___MANBRIEF___ stream management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Create a stream + * + * Creates a stream and returns a handle in \p phStream. The \p Flags argument + * determines behaviors of the stream. + * + * Valid values for \p Flags are: + * - ::CU_STREAM_DEFAULT: Default stream creation flag. + * - ::CU_STREAM_NON_BLOCKING: Specifies that work running in the created + * stream may run concurrently with work in stream 0 (the NULL stream), and that + * the created stream should perform no implicit synchronization with stream 0. 
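+ *
+ * For example, a non-blocking stream might be created and torn down as follows
+ * (error checking omitted): \code
+ CUstream stream;
+ cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING);
+ // ... submit work to the stream ...
+ cuStreamSynchronize(stream);
+ cuStreamDestroy(stream);
+ * \endcode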
+ * + * \param phStream - Returned newly created stream + * \param Flags - Parameters for stream creation + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreate, + * ::cudaStreamCreateWithFlags + */ +CUresult CUDAAPI cuStreamCreate(CUstream *phStream, unsigned int Flags); + +/** + * \brief Create a stream with the given priority + * + * Creates a stream with the specified priority and returns a handle in \p phStream. + * This API alters the scheduler priority of work in the stream. Work in a higher + * priority stream may preempt work already executing in a low priority stream. + * + * \p priority follows a convention where lower numbers represent higher priorities. + * '0' represents default priority. The range of meaningful numerical priorities can + * be queried using ::cuCtxGetStreamPriorityRange. If the specified priority is + * outside the numerical range returned by ::cuCtxGetStreamPriorityRange, + * it will automatically be clamped to the lowest or the highest number in the range. + * + * \param phStream - Returned newly created stream + * \param flags - Flags for stream creation. See ::cuStreamCreate for a list of + * valid flags + * \param priority - Stream priority. Lower numbers represent higher priorities. + * See ::cuCtxGetStreamPriorityRange for more information about + * meaningful stream priorities that can be passed. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \note Stream priorities are supported only on GPUs + * with compute capability 3.5 or higher. + * + * \note In the current implementation, only compute kernels launched in + * priority streams are affected by the stream's priority. Stream priorities have + * no effect on host-to-device and device-to-host memory operations. + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamGetPriority, + * ::cuCtxGetStreamPriorityRange, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreateWithPriority + */ +CUresult CUDAAPI cuStreamCreateWithPriority(CUstream *phStream, unsigned int flags, int priority); + + +/** + * \brief Query the priority of a given stream + * + * Query the priority of a stream created using ::cuStreamCreate or ::cuStreamCreateWithPriority + * and return the priority in \p priority. Note that if the stream was created with a + * priority outside the numerical range returned by ::cuCtxGetStreamPriorityRange, + * this function returns the clamped priority. + * See ::cuStreamCreateWithPriority for details about priority clamping. 
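+ *
+ * For example, a stream at the highest available priority might be created and queried
+ * as follows (error checking omitted): \code
+ int leastPriority, greatestPriority, priority;
+ CUstream stream;
+ cuCtxGetStreamPriorityRange(&leastPriority, &greatestPriority);
+ cuStreamCreateWithPriority(&stream, CU_STREAM_NON_BLOCKING, greatestPriority);
+ cuStreamGetPriority(stream, &priority);   // returns the (possibly clamped) priority
+ * \endcode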
+ * + * \param hStream - Handle to the stream to be queried + * \param priority - Pointer to a signed integer in which the stream's priority is returned + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamCreateWithPriority, + * ::cuCtxGetStreamPriorityRange, + * ::cuStreamGetFlags, + * ::cudaStreamGetPriority + */ +CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); + +/** + * \brief Query the flags of a given stream + * + * Query the flags of a stream created using ::cuStreamCreate or ::cuStreamCreateWithPriority + * and return the flags in \p flags. + * + * \param hStream - Handle to the stream to be queried + * \param flags - Pointer to an unsigned integer in which the stream's flags are returned + * The value returned in \p flags is a logical 'OR' of all flags that + * were used while creating this stream. See ::cuStreamCreate for the list + * of valid flags + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreate, + * ::cuStreamGetPriority, + * ::cudaStreamGetFlags + */ +CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); + +/** + * \brief Query the context associated with a stream + * + * Returns the CUDA context that the stream is associated with. + * + * The stream handle \p hStream can refer to any of the following: + *
+ * - a stream created via any of the CUDA driver APIs such as ::cuStreamCreate
+ *   and ::cuStreamCreateWithPriority, or their runtime API equivalents such as
+ *   ::cudaStreamCreate, ::cudaStreamCreateWithFlags and ::cudaStreamCreateWithPriority.
+ *   The returned context is the context that was active in the calling thread when the
+ *   stream was created. Passing an invalid handle will result in undefined behavior.
+ * - any of the special streams such as the NULL stream, ::CU_STREAM_LEGACY and
+ *   ::CU_STREAM_PER_THREAD. The runtime API equivalents of these are also accepted,
+ *   which are NULL, ::cudaStreamLegacy and ::cudaStreamPerThread respectively.
+ *   Specifying any of the special handles will return the context current to the
+ *   calling thread. If no context is current to the calling thread,
+ *   ::CUDA_ERROR_INVALID_CONTEXT is returned.
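+ *
+ * As a brief sketch (\p hStream is assumed to be an application-created stream;
+ * error checking omitted): \code
+ CUcontext ctx;
+ cuStreamGetCtx(hStream, &ctx);           // context active when hStream was created
+ cuStreamGetCtx(CU_STREAM_LEGACY, &ctx);  // context current to the calling thread
+ * \endcode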
+ * + * \param hStream - Handle to the stream to be queried + * \param pctx - Returned context associated with the stream + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * \notefnerr + * + * \sa ::cuStreamDestroy, + * ::cuStreamCreateWithPriority, + * ::cuStreamGetPriority, + * ::cuStreamGetFlags, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamCreate, + * ::cudaStreamCreateWithFlags + */ +CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); + +/** + * \brief Make a compute stream wait on an event + * + * Makes all future work submitted to \p hStream wait for all work captured in + * \p hEvent. See ::cuEventRecord() for details on what is captured by an event. + * The synchronization will be performed efficiently on the device when applicable. + * \p hEvent may be from a different context or device than \p hStream. + * + * flags include: + * - ::CU_EVENT_WAIT_DEFAULT: Default event creation flag. + * - ::CU_EVENT_WAIT_EXTERNAL: Event is captured in the graph as an external + * event node when performing stream capture. This flag is invalid outside + * of stream capture. + * + * \param hStream - Stream to wait + * \param hEvent - Event to wait on (may not be NULL) + * \param Flags - See ::CUevent_capture_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuEventRecord, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cuStreamDestroy, + * ::cudaStreamWaitEvent + */ +CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, unsigned int Flags); + +/** + * \brief Add a callback to a compute stream + * + * \note This function is slated for eventual deprecation and removal. If + * you do not require the callback to execute in case of a device error, + * consider using ::cuLaunchHostFunc. Additionally, this function is not + * supported with ::cuStreamBeginCapture and ::cuStreamEndCapture, unlike + * ::cuLaunchHostFunc. + * + * Adds a callback to be called on the host after all currently enqueued + * items in the stream have completed. For each + * cuStreamAddCallback call, the callback will be executed exactly once. + * The callback will block later work in the stream until it is finished. + * + * The callback may be passed ::CUDA_SUCCESS or an error code. In the event + * of a device error, all subsequently executed callbacks will receive an + * appropriate ::CUresult. + * + * Callbacks must not make any CUDA API calls. Attempting to use a CUDA API + * will result in ::CUDA_ERROR_NOT_PERMITTED. Callbacks must not perform any + * synchronization that may depend on outstanding device work or other callbacks + * that are not mandated to run earlier. Callbacks without a mandated order + * (in independent streams) execute in undefined order and may be serialized. + * + * For the purposes of Unified Memory, callback execution makes a number of + * guarantees: + *
  • The callback stream is considered idle for the duration of the + * callback. Thus, for example, a callback may always use memory attached + * to the callback stream.
  • + *
  • The start of execution of a callback has the same effect as + * synchronizing an event recorded in the same stream immediately prior to + * the callback. It thus synchronizes streams which have been "joined" + * prior to the callback.
  • + *
  • Adding device work to any stream does not have the effect of making + * the stream active until all preceding host functions and stream callbacks + * have executed. Thus, for + * example, a callback might use global attached memory even if work has + * been added to another stream, if the work has been ordered behind the + * callback with an event.
  • + *
  • Completion of a callback does not cause a stream to become + * active except as described above. The callback stream will remain idle + * if no device work follows the callback, and will remain idle across + * consecutive callbacks without device work in between. Thus, for example, + * stream synchronization can be done by signaling from a callback at the + * end of the stream.
  • + *
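+ *
+ * A minimal usage sketch (illustrative only, not part of the original
+ * documentation; \p hStream is a placeholder handle and error checking is
+ * omitted). The callback runs on a driver thread and must not call into the
+ * CUDA API:
+ *
+ * \code
+ void CUDA_CB myHostCallback(CUstream hStream, CUresult status, void *userData)
+ {
+     // All work enqueued in hStream before cuStreamAddCallback() has completed.
+     int *flag = (int *)userData;
+     *flag = (status == CUDA_SUCCESS) ? 1 : -1;  // no CUDA calls allowed here
+ }
+
+ static int done = 0;
+ // ... enqueue kernels / copies into hStream ...
+ cuStreamAddCallback(hStream, myHostCallback, &done, 0);
+ * \endcode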
+ * + * \param hStream - Stream to add callback to + * \param callback - The function to call once preceding stream operations are complete + * \param userData - User specified data to be passed to the callback function + * \param flags - Reserved for future use, must be 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuMemAllocManaged, + * ::cuStreamAttachMemAsync, + * ::cuStreamLaunchHostFunc, + * ::cudaStreamAddCallback + */ +CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, CUstreamCallback callback, void *userData, unsigned int flags); + +/** + * \brief Begins graph capture on a stream + * + * Begin graph capture on \p hStream. When a stream is in capture mode, all operations + * pushed into the stream will not be executed, but will instead be captured into + * a graph, which will be returned via ::cuStreamEndCapture. Capture may not be initiated + * if \p stream is CU_STREAM_LEGACY. Capture must be ended on the same stream in which + * it was initiated, and it may only be initiated if the stream is not already in capture + * mode. The capture mode may be queried via ::cuStreamIsCapturing. A unique id + * representing the capture sequence may be queried via ::cuStreamGetCaptureInfo. + * + * If \p mode is not ::CU_STREAM_CAPTURE_MODE_RELAXED, ::cuStreamEndCapture must be + * called on this stream from the same thread. + * + * \param hStream - Stream in which to initiate capture + * \param mode - Controls the interaction of this capture sequence with other API + * calls that are potentially unsafe. For more details see + * ::cuThreadExchangeStreamCaptureMode. + * + * \note Kernels captured using this API must not use texture and surface references. + * Reading or writing through any texture or surface reference is undefined + * behavior. This restriction does not apply to texture and surface objects. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuStreamCreate, + * ::cuStreamIsCapturing, + * ::cuStreamEndCapture, + * ::cuThreadExchangeStreamCaptureMode + */ +CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream, CUstreamCaptureMode mode); + +/** + * \brief Swaps the stream capture interaction mode for a thread + * + * Sets the calling thread's stream capture interaction mode to the value contained + * in \p *mode, and overwrites \p *mode with the previous mode for the thread. To + * facilitate deterministic behavior across function or module boundaries, callers + * are encouraged to use this API in a push-pop fashion: \code + CUstreamCaptureMode mode = desiredMode; + cuThreadExchangeStreamCaptureMode(&mode); + ... + cuThreadExchangeStreamCaptureMode(&mode); // restore previous mode + * \endcode + * + * During stream capture (see ::cuStreamBeginCapture), some actions, such as a call + * to ::cudaMalloc, may be unsafe. In the case of ::cudaMalloc, the operation is + * not enqueued asynchronously to a stream, and is not observed by stream capture. + * Therefore, if the sequence of operations captured via ::cuStreamBeginCapture + * depended on the allocation being replayed whenever the graph is launched, the + * captured graph would be invalid. 
+ * + * Therefore, stream capture places restrictions on API calls that can be made within + * or concurrently to a ::cuStreamBeginCapture-::cuStreamEndCapture sequence. This + * behavior can be controlled via this API and flags to ::cuStreamBeginCapture. + * + * A thread's mode is one of the following: + * - \p CU_STREAM_CAPTURE_MODE_GLOBAL: This is the default mode. If the local thread has + * an ongoing capture sequence that was not initiated with + * \p CU_STREAM_CAPTURE_MODE_RELAXED at \p cuStreamBeginCapture, or if any other thread + * has a concurrent capture sequence initiated with \p CU_STREAM_CAPTURE_MODE_GLOBAL, + * this thread is prohibited from potentially unsafe API calls. + * - \p CU_STREAM_CAPTURE_MODE_THREAD_LOCAL: If the local thread has an ongoing capture + * sequence not initiated with \p CU_STREAM_CAPTURE_MODE_RELAXED, it is prohibited + * from potentially unsafe API calls. Concurrent capture sequences in other threads + * are ignored. + * - \p CU_STREAM_CAPTURE_MODE_RELAXED: The local thread is not prohibited from potentially + * unsafe API calls. Note that the thread is still prohibited from API calls which + * necessarily conflict with stream capture, for example, attempting ::cuEventQuery + * on an event that was last recorded inside a capture sequence. + * + * \param mode - Pointer to mode value to swap with the current mode + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuStreamBeginCapture + */ +CUresult CUDAAPI cuThreadExchangeStreamCaptureMode(CUstreamCaptureMode *mode); + +/** + * \brief Ends capture on a stream, returning the captured graph + * + * End capture on \p hStream, returning the captured graph via \p phGraph. + * Capture must have been initiated on \p hStream via a call to ::cuStreamBeginCapture. + * If capture was invalidated, due to a violation of the rules of stream capture, then + * a NULL graph will be returned. + * + * If the \p mode argument to ::cuStreamBeginCapture was not + * ::CU_STREAM_CAPTURE_MODE_RELAXED, this call must be from the same thread as + * ::cuStreamBeginCapture. + * + * \param hStream - Stream to query + * \param phGraph - The captured graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_WRONG_THREAD + * \notefnerr + * + * \sa + * ::cuStreamCreate, + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing + */ +CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); + +/** + * \brief Returns a stream's capture status + * + * Return the capture status of \p hStream via \p captureStatus. After a successful + * call, \p *captureStatus will contain one of the following: + * - ::CU_STREAM_CAPTURE_STATUS_NONE: The stream is not capturing. + * - ::CU_STREAM_CAPTURE_STATUS_ACTIVE: The stream is capturing. + * - ::CU_STREAM_CAPTURE_STATUS_INVALIDATED: The stream was capturing but an error + * has invalidated the capture sequence. The capture sequence must be terminated + * with ::cuStreamEndCapture on the stream where it was initiated in order to + * continue using \p hStream. + * + * Note that, if this is called on ::CU_STREAM_LEGACY (the "null stream") while + * a blocking stream in the same context is capturing, it will return + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT and \p *captureStatus is unspecified + * after the call. The blocking stream capture is not invalidated. 
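+ *
+ * As an illustrative sketch (not part of the original documentation;
+ * \p hStream is a placeholder handle, error checking omitted), an application
+ * can guard work that is unsafe during capture by querying the status first:
+ *
+ * \code
+ CUstreamCaptureStatus status;
+ cuStreamIsCapturing(hStream, &status);
+ if (status == CU_STREAM_CAPTURE_STATUS_NONE) {
+     // hStream is not capturing; it is safe to submit work that must
+     // execute immediately rather than be recorded into a graph.
+ }
+ * \endcode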
+ * + * When a blocking stream is capturing, the legacy stream is in an + * unusable state until the blocking stream capture is terminated. The legacy + * stream is not supported for stream capture, but attempted use would have an + * implicit dependency on the capturing stream(s). + * + * \param hStream - Stream to query + * \param captureStatus - Returns the stream's capture status + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \notefnerr + * + * \sa + * ::cuStreamCreate, + * ::cuStreamBeginCapture, + * ::cuStreamEndCapture + */ +CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *captureStatus); + +/** + * \brief Query capture status of a stream + * + * Note there is a later version of this API, ::cuStreamGetCaptureInfo_v2. It will + * supplant this version in 12.0, which is retained for minor version compatibility. + * + * Query the capture status of a stream and and get an id for + * the capture sequence, which is unique over the lifetime of the process. + * + * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not created + * with ::CU_STREAM_NON_BLOCKING is capturing, returns ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. + * + * A valid id is returned only if both of the following are true: + * - the call returns CUDA_SUCCESS + * - captureStatus is set to ::CU_STREAM_CAPTURE_STATUS_ACTIVE + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \notefnerr + * + * \sa + * ::cuStreamGetCaptureInfo_v2, + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing + */ +CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); + +/** + * \brief Query a stream's capture state (11.3+) + * + * Query stream state related to stream capture. + * + * If called on ::CU_STREAM_LEGACY (the "null stream") while a stream not created + * with ::CU_STREAM_NON_BLOCKING is capturing, returns ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT. + * + * Valid data (other than capture status) is returned only if both of the following are true: + * - the call returns CUDA_SUCCESS + * - the returned capture status is ::CU_STREAM_CAPTURE_STATUS_ACTIVE + * + * This version of cuStreamGetCaptureInfo is introduced in CUDA 11.3 and will supplant the + * previous version in 12.0. Developers requiring compatibility across minor versions to + * CUDA 11.0 (driver version 445) should use ::cuStreamGetCaptureInfo or include a fallback + * path. + * + * \param hStream - The stream to query + * \param captureStatus_out - Location to return the capture status of the stream; required + * \param id_out - Optional location to return an id for the capture sequence, which is + * unique over the lifetime of the process + * \param graph_out - Optional location to return the graph being captured into. All + * operations other than destroy and node removal are permitted on the graph + * while the capture sequence is in progress. This API does not transfer + * ownership of the graph, which is transferred or destroyed at + * ::cuStreamEndCapture. Note that the graph handle may be invalidated before + * end of capture for certain errors. Nodes that are or become + * unreachable from the original stream at ::cuStreamEndCapture due to direct + * actions on the graph do not trigger ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED. + * \param dependencies_out - Optional location to store a pointer to an array of nodes. 
+ * The next node to be captured in the stream will depend on this set of nodes, + * absent operations such as event wait which modify this set. The array pointer + * is valid until the next API call which operates on the stream or until end of + * capture. The node handles may be copied out and are valid until they or the + * graph is destroyed. The driver-owned array may also be passed directly to + * APIs that operate on the graph (not the stream) without copying. + * \param numDependencies_out - Optional location to store the size of the array + * returned in dependencies_out. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_STREAM_CAPTURE_IMPLICIT + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuStreamGetCaptureInfo, + * ::cuStreamBeginCapture, + * ::cuStreamIsCapturing, + * ::cuStreamUpdateCaptureDependencies + */ +CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, + cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out); + +/** + * \brief Update the set of dependencies in a capturing stream (11.3+) + * + * Modifies the dependency set of a capturing stream. The dependency set is the set + * of nodes that the next captured node in the stream will depend on. + * + * Valid flags are ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES and + * ::CU_STREAM_SET_CAPTURE_DEPENDENCIES. These control whether the set passed to + * the API is added to the existing set or replaces it. A flags value of 0 defaults + * to ::CU_STREAM_ADD_CAPTURE_DEPENDENCIES. + * + * Nodes that are removed from the dependency set via this API do not result in + * ::CUDA_ERROR_STREAM_CAPTURE_UNJOINED if they are unreachable from the stream at + * ::cuStreamEndCapture. + * + * Returns ::CUDA_ERROR_ILLEGAL_STATE if the stream is not capturing. + * + * This API is new in CUDA 11.3. Developers requiring compatibility across minor + * versions to CUDA 11.0 should not use this API or provide a fallback. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_ILLEGAL_STATE + * + * \sa + * ::cuStreamBeginCapture, + * ::cuStreamGetCaptureInfo, + * ::cuStreamGetCaptureInfo_v2 + */ +CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags); + +/** + * \brief Attach memory to a stream asynchronously + * + * Enqueues an operation in \p hStream to specify stream association of + * \p length bytes of memory starting from \p dptr. This function is a + * stream-ordered operation, meaning that it is dependent on, and will + * only take effect when, previous work in stream has completed. Any + * previous association is automatically replaced. + * + * \p dptr must point to one of the following types of memories: + * - managed memory declared using the __managed__ keyword or allocated with + * ::cuMemAllocManaged. + * - a valid host-accessible region of system-allocated pageable memory. This + * type of memory may only be specified if the device associated with the + * stream reports a non-zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_PAGEABLE_MEMORY_ACCESS. + * + * For managed allocations, \p length must be either zero or the entire + * allocation's size. Both indicate that the entire allocation's stream + * association is being changed. Currently, it is not possible to change stream + * association for a portion of a managed allocation. 
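+ *
+ * For example (an illustrative sketch, not part of the original
+ * documentation; \p hStream is a placeholder handle and error checking is
+ * omitted), a managed allocation can be attached to a single stream so that
+ * the CPU may access it while other streams remain busy:
+ *
+ * \code
+ CUdeviceptr dptr;
+ size_t bytes = 1 << 20;
+ cuMemAllocManaged(&dptr, bytes, CU_MEM_ATTACH_GLOBAL);
+ // Length 0 attaches the whole allocation to hStream.
+ cuStreamAttachMemAsync(hStream, dptr, 0, CU_MEM_ATTACH_SINGLE);
+ cuStreamSynchronize(hStream);
+ // CPU access is now legal whenever hStream has completed its work.
+ * \endcode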
+ * + * For pageable host allocations, \p length must be non-zero. + * + * The stream association is specified using \p flags which must be + * one of ::CUmemAttach_flags. + * If the ::CU_MEM_ATTACH_GLOBAL flag is specified, the memory can be accessed + * by any stream on any device. + * If the ::CU_MEM_ATTACH_HOST flag is specified, the program makes a guarantee + * that it won't access the memory on the device from any stream on a device that + * has a zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. + * If the ::CU_MEM_ATTACH_SINGLE flag is specified and \p hStream is associated with + * a device that has a zero value for the device attribute ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS, + * the program makes a guarantee that it will only access the memory on the device + * from \p hStream. It is illegal to attach singly to the NULL stream, because the + * NULL stream is a virtual global stream and not a specific stream. An error will + * be returned in this case. + * + * When memory is associated with a single stream, the Unified Memory system will + * allow CPU access to this memory region so long as all operations in \p hStream + * have completed, regardless of whether other streams are active. In effect, + * this constrains exclusive ownership of the managed memory region by + * an active GPU to per-stream activity instead of whole-GPU activity. + * + * Accessing memory on the device from streams that are not associated with + * it will produce undefined results. No error checking is performed by the + * Unified Memory system to ensure that kernels launched into other streams + * do not access this region. + * + * It is a program's responsibility to order calls to ::cuStreamAttachMemAsync + * via events, synchronization or other means to ensure legal access to memory + * at all times. Data visibility and coherency will be changed appropriately + * for all kernels which follow a stream-association change. + * + * If \p hStream is destroyed while data is associated with it, the association is + * removed and the association reverts to the default visibility of the allocation + * as specified at ::cuMemAllocManaged. For __managed__ variables, the default + * association is always ::CU_MEM_ATTACH_GLOBAL. Note that destroying a stream is an + * asynchronous operation, and as a result, the change to default association won't + * happen until all work in the stream has completed. + * + * \param hStream - Stream in which to enqueue the attach operation + * \param dptr - Pointer to memory (must be a pointer to managed memory or + * to a valid host-accessible region of system-allocated + * pageable memory) + * \param length - Length of memory + * \param flags - Must be one of ::CUmemAttach_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuMemAllocManaged, + * ::cudaStreamAttachMemAsync + */ +CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, size_t length, unsigned int flags); + +/** + * \brief Determine status of a compute stream + * + * Returns ::CUDA_SUCCESS if all operations in the stream specified by + * \p hStream have completed, or ::CUDA_ERROR_NOT_READY if not. 
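+ *
+ * A minimal polling sketch (illustrative, not part of the original
+ * documentation; \p hStream is a placeholder handle and doSomeHostWork()
+ * stands in for arbitrary host-side work):
+ *
+ * \code
+ while (cuStreamQuery(hStream) == CUDA_ERROR_NOT_READY) {
+     // Overlap useful CPU work instead of blocking in cuStreamSynchronize().
+     doSomeHostWork();
+ }
+ * \endcode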
+ * + * For the purposes of Unified Memory, a return value of ::CUDA_SUCCESS + * is equivalent to having called ::cuStreamSynchronize(). + * + * \param hStream - Stream to query status of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_READY + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamQuery + */ +CUresult CUDAAPI cuStreamQuery(CUstream hStream); + +/** + * \brief Wait until a stream's tasks are completed + * + * Waits until the device has completed all operations in the stream specified + * by \p hStream. If the context was created with the + * ::CU_CTX_SCHED_BLOCKING_SYNC flag, the CPU thread will block until the + * stream is finished with all of its tasks. + * + * \param hStream - Stream to wait for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamDestroy, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamAddCallback, + * ::cudaStreamSynchronize + */ +CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); + +/** + * \brief Destroys a stream + * + * Destroys the stream specified by \p hStream. + * + * In case the device is still doing work in the stream \p hStream + * when ::cuStreamDestroy() is called, the function will return immediately + * and the resources associated with \p hStream will be released automatically + * once the device has completed all work in \p hStream. + * + * \param hStream - Stream to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamWaitEvent, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamAddCallback, + * ::cudaStreamDestroy + */ +CUresult CUDAAPI cuStreamDestroy(CUstream hStream); + +/** + * \brief Copies attributes from source stream to destination stream. + * + * Copies attributes from source stream \p src to destination stream \p dst. + * Both streams must have the same context. + * + * \param[out] dst Destination stream + * \param[in] src Source stream + * For list of attributes see ::CUstreamAttrID + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamCopyAttributes(CUstream dst, CUstream src); + +/** + * \brief Queries stream attribute. + * + * Queries attribute \p attr from \p hStream and stores it in corresponding + * member of \p value_out. + * + * \param[in] hStream + * \param[in] attr + * \param[out] value_out + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, + CUstreamAttrValue *value_out); + +/** + * \brief Sets stream attribute. + * + * Sets attribute \p attr on \p hStream from corresponding attribute of + * \p value. The updated attribute will be applied to subsequent work + * submitted to the stream. It will not affect previously submitted work. 
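+ *
+ * An illustrative sketch (not part of the original documentation) that sets
+ * an access policy window on a stream. The attribute and field names follow
+ * the ::CUstreamAttrID and ::CUaccessPolicyWindow definitions; \p hStream,
+ * \p devBuffer and \p windowBytes are placeholders and error checking is
+ * omitted:
+ *
+ * \code
+ CUstreamAttrValue attr;
+ memset(&attr, 0, sizeof(attr));
+ attr.accessPolicyWindow.base_ptr  = (void *)devBuffer;
+ attr.accessPolicyWindow.num_bytes = windowBytes;
+ attr.accessPolicyWindow.hitRatio  = 0.6f;
+ attr.accessPolicyWindow.hitProp   = CU_ACCESS_PROPERTY_PERSISTING;
+ attr.accessPolicyWindow.missProp  = CU_ACCESS_PROPERTY_STREAMING;
+ cuStreamSetAttribute(hStream, CU_STREAM_ATTRIBUTE_ACCESS_POLICY_WINDOW, &attr);
+ * \endcode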
+ * + * \param[out] hStream + * \param[in] attr + * \param[in] value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, + const CUstreamAttrValue *value); + +/** @} */ /* END CUDA_STREAM */ + + +/** + * \defgroup CUDA_EVENT Event Management + * + * ___MANBRIEF___ event management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the event management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Creates an event + * + * Creates an event *phEvent for the current context with the flags specified via + * \p Flags. Valid flags include: + * - ::CU_EVENT_DEFAULT: Default event creation flag. + * - ::CU_EVENT_BLOCKING_SYNC: Specifies that the created event should use blocking + * synchronization. A CPU thread that uses ::cuEventSynchronize() to wait on + * an event created with this flag will block until the event has actually + * been recorded. + * - ::CU_EVENT_DISABLE_TIMING: Specifies that the created event does not need + * to record timing data. Events created with this flag specified and + * the ::CU_EVENT_BLOCKING_SYNC flag not specified will provide the best + * performance when used with ::cuStreamWaitEvent() and ::cuEventQuery(). + * - ::CU_EVENT_INTERPROCESS: Specifies that the created event may be used as an + * interprocess event by ::cuIpcGetEventHandle(). ::CU_EVENT_INTERPROCESS must + * be specified along with ::CU_EVENT_DISABLE_TIMING. + * + * \param phEvent - Returns newly created event + * \param Flags - Event creation flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \notefnerr + * + * \sa + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventCreate, + * ::cudaEventCreateWithFlags + */ +CUresult CUDAAPI cuEventCreate(CUevent *phEvent, unsigned int Flags); + +/** + * \brief Records an event + * + * Captures in \p hEvent the contents of \p hStream at the time of this call. + * \p hEvent and \p hStream must be from the same context. + * Calls such as ::cuEventQuery() or ::cuStreamWaitEvent() will then + * examine or wait for completion of the work that was captured. Uses of + * \p hStream after this call do not modify \p hEvent. See note on default + * stream behavior for what is captured in the default case. + * + * ::cuEventRecord() can be called multiple times on the same event and + * will overwrite the previously captured state. Other APIs such as + * ::cuStreamWaitEvent() use the most recently captured state at the time + * of the API call, and are not affected by later calls to + * ::cuEventRecord(). Before the first call to ::cuEventRecord(), an + * event represents an empty set of work, so for example ::cuEventQuery() + * would return ::CUDA_SUCCESS. 
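+ *
+ * A minimal cross-stream ordering sketch (illustrative, not part of the
+ * original documentation; \p streamA and \p streamB are placeholder handles
+ * and error checking is omitted):
+ *
+ * \code
+ CUevent ev;
+ cuEventCreate(&ev, CU_EVENT_DISABLE_TIMING);
+ // ... enqueue producer work into streamA ...
+ cuEventRecord(ev, streamA);
+ // Work enqueued in streamB after this point waits for the recorded work.
+ cuStreamWaitEvent(streamB, ev, CU_EVENT_WAIT_DEFAULT);
+ // ... enqueue consumer work into streamB ...
+ * \endcode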
+ * + * \param hEvent - Event to record + * \param hStream - Stream to record event for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \note_null_stream + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuStreamWaitEvent, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventRecord, + * ::cuEventRecordWithFlags + */ +CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); + +/** + * \brief Records an event + * + * Captures in \p hEvent the contents of \p hStream at the time of this call. + * \p hEvent and \p hStream must be from the same context. + * Calls such as ::cuEventQuery() or ::cuStreamWaitEvent() will then + * examine or wait for completion of the work that was captured. Uses of + * \p hStream after this call do not modify \p hEvent. See note on default + * stream behavior for what is captured in the default case. + * + * ::cuEventRecordWithFlags() can be called multiple times on the same event and + * will overwrite the previously captured state. Other APIs such as + * ::cuStreamWaitEvent() use the most recently captured state at the time + * of the API call, and are not affected by later calls to + * ::cuEventRecordWithFlags(). Before the first call to ::cuEventRecordWithFlags(), an + * event represents an empty set of work, so for example ::cuEventQuery() + * would return ::CUDA_SUCCESS. + * + * flags include: + * - ::CU_EVENT_RECORD_DEFAULT: Default event creation flag. + * - ::CU_EVENT_RECORD_EXTERNAL: Event is captured in the graph as an external + * event node when performing stream capture. This flag is invalid outside + * of stream capture. + * + * \param hEvent - Event to record + * \param hStream - Stream to record event for + * \param flags - See ::CUevent_capture_flags + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \note_null_stream + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuStreamWaitEvent, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cuEventRecord, + * ::cudaEventRecord + */ +CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, unsigned int flags); + +/** + * \brief Queries an event's status + * + * Queries the status of all work currently captured by \p hEvent. See + * ::cuEventRecord() for details on what is captured by an event. + * + * Returns ::CUDA_SUCCESS if all captured work has been completed, or + * ::CUDA_ERROR_NOT_READY if any captured work is incomplete. + * + * For the purposes of Unified Memory, a return value of ::CUDA_SUCCESS + * is equivalent to having called ::cuEventSynchronize(). + * + * \param hEvent - Event to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_READY + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventSynchronize, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventQuery + */ +CUresult CUDAAPI cuEventQuery(CUevent hEvent); + +/** + * \brief Waits for an event to complete + * + * Waits until the completion of all work currently captured in \p hEvent. 
+ * See ::cuEventRecord() for details on what is captured by an event. + * + * Waiting for an event that was created with the ::CU_EVENT_BLOCKING_SYNC + * flag will cause the calling CPU thread to block until the event has + * been completed by the device. If the ::CU_EVENT_BLOCKING_SYNC flag has + * not been set, then the CPU thread will busy-wait until the event has + * been completed by the device. + * + * \param hEvent - Event to wait for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventDestroy, + * ::cuEventElapsedTime, + * ::cudaEventSynchronize + */ +CUresult CUDAAPI cuEventSynchronize(CUevent hEvent); + +/** + * \brief Destroys an event + * + * Destroys the event specified by \p hEvent. + * + * An event may be destroyed before it is complete (i.e., while + * ::cuEventQuery() would return ::CUDA_ERROR_NOT_READY). In this case, the + * call does not block on completion of the event, and any associated + * resources will automatically be released asynchronously at completion. + * + * \param hEvent - Event to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuEventElapsedTime, + * ::cudaEventDestroy + */ +CUresult CUDAAPI cuEventDestroy(CUevent hEvent); + +/** + * \brief Computes the elapsed time between two events + * + * Computes the elapsed time between two events (in milliseconds with a + * resolution of around 0.5 microseconds). + * + * If either event was last recorded in a non-NULL stream, the resulting time + * may be greater than expected (even if both used the same stream handle). This + * happens because the ::cuEventRecord() operation takes place asynchronously + * and there is no guarantee that the measured latency is actually just between + * the two events. Any number of other different stream operations could execute + * in between the two measured events, thus altering the timing in a significant + * way. + * + * If ::cuEventRecord() has not been called on either event then + * ::CUDA_ERROR_INVALID_HANDLE is returned. If ::cuEventRecord() has been called + * on both events but one or both of them has not yet been completed (that is, + * ::cuEventQuery() would return ::CUDA_ERROR_NOT_READY on at least one of the + * events), ::CUDA_ERROR_NOT_READY is returned. If either event was created with + * the ::CU_EVENT_DISABLE_TIMING flag, then this function will return + * ::CUDA_ERROR_INVALID_HANDLE. 
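+ *
+ * A minimal timing sketch (illustrative, not part of the original
+ * documentation; \p hStream is a placeholder handle and error checking is
+ * omitted):
+ *
+ * \code
+ CUevent start, stop;
+ float ms = 0.0f;
+ cuEventCreate(&start, CU_EVENT_DEFAULT);
+ cuEventCreate(&stop,  CU_EVENT_DEFAULT);
+ cuEventRecord(start, hStream);
+ // ... enqueue the work to be timed into hStream ...
+ cuEventRecord(stop, hStream);
+ cuEventSynchronize(stop);              // wait until 'stop' has completed
+ cuEventElapsedTime(&ms, start, stop);  // elapsed time in milliseconds
+ cuEventDestroy(start);
+ cuEventDestroy(stop);
+ * \endcode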
+ * + * \param pMilliseconds - Time between \p hStart and \p hEnd in ms + * \param hStart - Starting event + * \param hEnd - Ending event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_READY + * \notefnerr + * + * \sa ::cuEventCreate, + * ::cuEventRecord, + * ::cuEventQuery, + * ::cuEventSynchronize, + * ::cuEventDestroy, + * ::cudaEventElapsedTime + */ +CUresult CUDAAPI cuEventElapsedTime(float *pMilliseconds, CUevent hStart, CUevent hEnd); + +/** @} */ /* END CUDA_EVENT */ + +/** + * \defgroup CUDA_EXTRES_INTEROP External Resource Interoperability + * + * ___MANBRIEF___ External resource interoperability functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the external resource interoperability functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + + /** + * \brief Imports an external memory object + * + * Imports an externally allocated memory object and returns + * a handle to that in \p extMem_out. + * + * The properties of the handle being imported must be described in + * \p memHandleDesc. The ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC structure + * is defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_MEMORY_HANDLE_DESC_st { + CUexternalMemoryHandleType type; + union { + int fd; + struct { + void *handle; + const void *name; + } win32; + const void *nvSciBufObject; + } handle; + unsigned long long size; + unsigned int flags; + } CUDA_EXTERNAL_MEMORY_HANDLE_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type specifies the type + * of handle being imported. ::CUexternalMemoryHandleType is + * defined as: + * + * \code + typedef enum CUexternalMemoryHandleType_enum { + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP = 4, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE = 5, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE = 6, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT = 7, + CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF = 8 + } CUexternalMemoryHandleType; + * \endcode + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::fd must be a valid + * file descriptor referencing a memory object. Ownership of + * the file descriptor is transferred to the CUDA driver when the + * handle is imported successfully. Performing any operations on the + * file descriptor after it is imported results in undefined behavior. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32, then exactly one + * of ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be + * NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * references a memory object. Ownership of this handle is + * not transferred to CUDA after the import operation, so the + * application must release the handle using the appropriate system + * call. 
If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * is not NULL, then it must point to a NULL-terminated array of + * UTF-16 characters that refers to a memory object. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_WIN32_KMT, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must + * be non-NULL and + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * must be NULL. The handle specified must be a globally shared KMT + * handle. This handle does not hold a reference to the underlying + * object, and thus will be invalid when all references to the + * memory object are destroyed. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_HEAP, then exactly one + * of ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be + * NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * is returned by ID3D12Device::CreateSharedHandle when referring to a + * ID3D12Heap object. This handle holds a reference to the underlying + * object. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * is not NULL, then it must point to a NULL-terminated array of + * UTF-16 characters that refers to a ID3D12Heap object. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE, then exactly one + * of ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name must not be + * NULL. If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * is returned by ID3D12Device::CreateSharedHandle when referring to a + * ID3D12Resource object. This handle holds a reference to the + * underlying object. If + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * is not NULL, then it must point to a NULL-terminated array of + * UTF-16 characters that refers to a ID3D12Resource object. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must + * represent a valid shared NT handle that is returned by + * IDXGIResource1::CreateSharedHandle when referring to a + * ID3D11Resource object. If + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * is not NULL, then it must point to a NULL-terminated array of + * UTF-16 characters that refers to a ID3D11Resource object. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::handle must + * represent a valid shared KMT handle that is returned by + * IDXGIResource::GetSharedHandle when referring to a + * ID3D11Resource object and + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::win32::name + * must be NULL. + * + * If ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type is + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, then + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::handle::nvSciBufObject must be non-NULL + * and reference a valid NvSciBuf object. 
+ * If the NvSciBuf object imported into CUDA is also mapped by other drivers, then the + * application must use ::cuWaitExternalSemaphoresAsync or ::cuSignalExternalSemaphoresAsync + * as appropriate barriers to maintain coherence between CUDA and the other drivers. + * See ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC and ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC + * for memory synchronization. + * + * + * The size of the memory object must be specified in + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::size. + * + * Specifying the flag ::CUDA_EXTERNAL_MEMORY_DEDICATED in + * ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::flags indicates that the + * resource is a dedicated resource. The definition of what a + * dedicated resource is outside the scope of this extension. + * This flag must be set if ::CUDA_EXTERNAL_MEMORY_HANDLE_DESC::type + * is one of the following: + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_D3D11_RESOURCE_KMT + * + * \param extMem_out - Returned handle to an external memory object + * \param memHandleDesc - Memory import handle descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \note If the Vulkan memory imported into CUDA is mapped on the CPU then the + * application must use vkInvalidateMappedMemoryRanges/vkFlushMappedMemoryRanges + * as well as appropriate Vulkan pipeline barriers to maintain coherence between + * CPU and GPU. For more information on these APIs, please refer to "Synchronization + * and Cache Control" chapter from Vulkan specification. + * + * \sa ::cuDestroyExternalMemory, + * ::cuExternalMemoryGetMappedBuffer, + * ::cuExternalMemoryGetMappedMipmappedArray + */ +CUresult CUDAAPI cuImportExternalMemory(CUexternalMemory *extMem_out, const CUDA_EXTERNAL_MEMORY_HANDLE_DESC *memHandleDesc); + +/** + * \brief Maps a buffer onto an imported memory object + * + * Maps a buffer onto an imported memory object and returns a device + * pointer in \p devPtr. + * + * The properties of the buffer being mapped must be described in + * \p bufferDesc. The ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC structure is + * defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_MEMORY_BUFFER_DESC_st { + unsigned long long offset; + unsigned long long size; + unsigned int flags; + } CUDA_EXTERNAL_MEMORY_BUFFER_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC::offset is the offset in + * the memory object where the buffer's base address is. + * ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC::size is the size of the buffer. + * ::CUDA_EXTERNAL_MEMORY_BUFFER_DESC::flags must be zero. + * + * The offset and size have to be suitably aligned to match the + * requirements of the external API. Mapping two buffers whose ranges + * overlap may or may not result in the same virtual address being + * returned for the overlapped portion. In such cases, the application + * must ensure that all accesses to that region from the GPU are + * volatile. Otherwise writes made via one address are not guaranteed + * to be visible via the other address, even if they're issued by the + * same thread. It is recommended that applications map the combined + * range instead of mapping separate buffers and then apply the + * appropriate offsets to the returned pointer to derive the + * individual buffers. + * + * The returned pointer \p devPtr must be freed using ::cuMemFree. 
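+ *
+ * A minimal import-and-map sketch for an opaque file descriptor
+ * (illustrative, not part of the original documentation; \p importedFd and
+ * \p importedSizeBytes are placeholder values provided by the exporting API
+ * and error checking is omitted):
+ *
+ * \code
+ CUexternalMemory extMem;
+ CUDA_EXTERNAL_MEMORY_HANDLE_DESC memDesc = {0};
+ memDesc.type      = CU_EXTERNAL_MEMORY_HANDLE_TYPE_OPAQUE_FD;
+ memDesc.handle.fd = importedFd;         // fd exported by the other API
+ memDesc.size      = importedSizeBytes;
+ cuImportExternalMemory(&extMem, &memDesc);
+
+ CUdeviceptr devPtr;
+ CUDA_EXTERNAL_MEMORY_BUFFER_DESC bufDesc = {0};
+ bufDesc.offset = 0;
+ bufDesc.size   = importedSizeBytes;
+ cuExternalMemoryGetMappedBuffer(&devPtr, extMem, &bufDesc);
+ * \endcode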
+ * + * \param devPtr - Returned device pointer to buffer + * \param extMem - Handle to external memory object + * \param bufferDesc - Buffer descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalMemory, + * ::cuDestroyExternalMemory, + * ::cuExternalMemoryGetMappedMipmappedArray + */ +CUresult CUDAAPI cuExternalMemoryGetMappedBuffer(CUdeviceptr *devPtr, CUexternalMemory extMem, const CUDA_EXTERNAL_MEMORY_BUFFER_DESC *bufferDesc); + +/** + * \brief Maps a CUDA mipmapped array onto an external memory object + * + * Maps a CUDA mipmapped array onto an external object and returns a + * handle to it in \p mipmap. + * + * The properties of the CUDA mipmapped array being mapped must be + * described in \p mipmapDesc. The structure + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC is defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC_st { + unsigned long long offset; + CUDA_ARRAY3D_DESCRIPTOR arrayDesc; + unsigned int numLevels; + } CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::offset is the + * offset in the memory object where the base level of the mipmap + * chain is. + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::arrayDesc describes + * the format, dimensions and type of the base level of the mipmap + * chain. For further details on these parameters, please refer to the + * documentation for ::cuMipmappedArrayCreate. Note that if the mipmapped + * array is bound as a color target in the graphics API, then the flag + * ::CUDA_ARRAY3D_COLOR_ATTACHMENT must be specified in + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::arrayDesc::Flags. + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::numLevels specifies + * the total number of levels in the mipmap chain. + * + * If \p extMem was imported from a handle of type ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF, then + * ::CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC::numLevels must be equal to 1. + * + * The returned CUDA mipmapped array must be freed using ::cuMipmappedArrayDestroy. + * + * \param mipmap - Returned CUDA mipmapped array + * \param extMem - Handle to external memory object + * \param mipmapDesc - CUDA array descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalMemory, + * ::cuDestroyExternalMemory, + * ::cuExternalMemoryGetMappedBuffer + */ +CUresult CUDAAPI cuExternalMemoryGetMappedMipmappedArray(CUmipmappedArray *mipmap, CUexternalMemory extMem, const CUDA_EXTERNAL_MEMORY_MIPMAPPED_ARRAY_DESC *mipmapDesc); + +/** + * \brief Destroys an external memory object. + * + * Destroys the specified external memory object. Any existing buffers + * and CUDA mipmapped arrays mapped onto this object must no longer be + * used and must be explicitly freed using ::cuMemFree and + * ::cuMipmappedArrayDestroy respectively. 
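+ *
+ * An illustrative teardown sketch (not part of the original documentation),
+ * assuming \p devPtr was obtained from ::cuExternalMemoryGetMappedBuffer:
+ *
+ * \code
+ cuMemFree(devPtr);               // release the mapped buffer first
+ cuDestroyExternalMemory(extMem); // then drop the imported object
+ * \endcode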
+ * + * \param extMem - External memory object to be destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalMemory, + * ::cuExternalMemoryGetMappedBuffer, + * ::cuExternalMemoryGetMappedMipmappedArray + */ +CUresult CUDAAPI cuDestroyExternalMemory(CUexternalMemory extMem); + +/** + * \brief Imports an external semaphore + * + * Imports an externally allocated synchronization object and returns + * a handle to that in \p extSem_out. + * + * The properties of the handle being imported must be described in + * \p semHandleDesc. The ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC is + * defined as follows: + * + * \code + typedef struct CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC_st { + CUexternalSemaphoreHandleType type; + union { + int fd; + struct { + void *handle; + const void *name; + } win32; + const void* NvSciSyncObj; + } handle; + unsigned int flags; + } CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC; + * \endcode + * + * where ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type specifies the type of + * handle being imported. ::CUexternalSemaphoreHandleType is defined + * as: + * + * \code + typedef enum CUexternalSemaphoreHandleType_enum { + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD = 1, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32 = 2, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT = 3, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE = 4, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE = 5, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC = 6, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX = 7, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT = 8, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD = 9, + CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 = 10 + } CUexternalSemaphoreHandleType; + * \endcode + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::fd must be a valid + * file descriptor referencing a synchronization object. Ownership of + * the file descriptor is transferred to the CUDA driver when the + * handle is imported successfully. Performing any operations on the + * file descriptor after it is imported results in undefined behavior. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, then exactly one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * references a synchronization object. Ownership of this handle is + * not transferred to CUDA after the import operation, so the + * application must release the handle using the appropriate system + * call. If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle must + * be non-NULL and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * must be NULL. The handle specified must be a globally shared KMT + * handle. 
This handle does not hold a reference to the underlying + * object, and thus will be invalid when all references to the + * synchronization object are destroyed. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, then exactly one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * is returned by ID3D12Device::CreateSharedHandle when referring to a + * ID3D12Fence object. This handle holds a reference to the underlying + * object. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid ID3D12Fence object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared NT handle that is returned by + * ID3D11Fence::CreateSharedHandle. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid ID3D11Fence object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::nvSciSyncObj + * represents a valid NvSciSyncObj. + * + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared NT handle that + * is returned by IDXGIResource1::CreateSharedHandle when referring to + * a IDXGIKeyedMutex object. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object that + * refers to a valid IDXGIKeyedMutex object. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * represents a valid shared KMT handle that + * is returned by IDXGIResource::GetSharedHandle when referring to + * a IDXGIKeyedMutex object and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must be NULL. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, then + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::fd must be a valid + * file descriptor referencing a synchronization object. Ownership of + * the file descriptor is transferred to the CUDA driver when the + * handle is imported successfully. Performing any operations on the + * file descriptor after it is imported results in undefined behavior. + * + * If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::type is + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32, then exactly one + * of ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle and + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name must not be + * NULL. If + * ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::handle + * is not NULL, then it must represent a valid shared NT handle that + * references a synchronization object. 
Ownership of this handle is + * not transferred to CUDA after the import operation, so the + * application must release the handle using the appropriate system + * call. If ::CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC::handle::win32::name + * is not NULL, then it must name a valid synchronization object. + * + * \param extSem_out - Returned handle to an external semaphore + * \param semHandleDesc - Semaphore import handle descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuImportExternalSemaphore(CUexternalSemaphore *extSem_out, const CUDA_EXTERNAL_SEMAPHORE_HANDLE_DESC *semHandleDesc); + +/** + * \brief Signals a set of external semaphore objects + * + * Enqueues a signal operation on a set of externally allocated + * semaphore object in the specified stream. The operations will be + * executed when all prior operations in the stream complete. + * + * The exact semantics of signaling a semaphore depends on the type of + * the object. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * then signaling the semaphore will set it to the signaled state. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * then the semaphore will be set to the value specified in + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::fence::value. + * + * If the semaphore object is of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC + * this API sets ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence + * to a value that can be used by subsequent waiters of the same NvSciSync object + * to order operations with those currently submitted in \p stream. Such an update + * will overwrite previous contents of + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence. By default, + * signaling such an external semaphore object causes appropriate memory synchronization + * operations to be performed over all external memory objects that are imported as + * ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. This ensures that any subsequent accesses + * made by other importers of the same set of NvSciBuf memory object(s) are coherent. + * These operations can be skipped by specifying the flag + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_SKIP_NVSCIBUF_MEMSYNC, which can be used as a + * performance optimization when data coherency is not required. But specifying this + * flag in scenarios where data coherency is required results in undefined behavior. + * Also, for semaphore object of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, + * if the NvSciSyncAttrList used to create the NvSciSyncObj had not set the flags in + * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_SIGNAL, this API will return + * CUDA_ERROR_NOT_SUPPORTED. 
+ * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then the keyed mutex will be released with the key specified in + * ::CUDA_EXTERNAL_SEMAPHORE_PARAMS::params::keyedmutex::key. + * + * \param extSemArray - Set of external semaphores to be signaled + * \param paramsArray - Array of semaphore parameters + * \param numExtSems - Number of semaphores to signal + * \param stream - Stream to enqueue the signal operations in + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuSignalExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); + +/** + * \brief Waits on a set of external semaphore objects + * + * Enqueues a wait operation on a set of externally allocated + * semaphore object in the specified stream. The operations will be + * executed when all prior operations in the stream complete. + * + * The exact semantics of waiting on a semaphore depends on the type + * of the object. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_OPAQUE_WIN32_KMT + * then waiting on the semaphore will wait until the semaphore reaches + * the signaled state. The semaphore will then be reset to the + * unsignaled state. Therefore for every signal operation, there can + * only be one wait operation. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D12_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_FENCE, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_FD, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_TIMELINE_SEMAPHORE_WIN32 + * then waiting on the semaphore will wait until the value of the + * semaphore is greater than or equal to + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::fence::value. + * + * If the semaphore object is of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC + * then, waiting on the semaphore will wait until the + * ::CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS::params::nvSciSync::fence is signaled by the + * signaler of the NvSciSyncObj that was associated with this semaphore object. + * By default, waiting on such an external semaphore object causes appropriate + * memory synchronization operations to be performed over all external memory objects + * that are imported as ::CU_EXTERNAL_MEMORY_HANDLE_TYPE_NVSCIBUF. This ensures that + * any subsequent accesses made by other importers of the same set of NvSciBuf memory + * object(s) are coherent. These operations can be skipped by specifying the flag + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_SKIP_NVSCIBUF_MEMSYNC, which can be used as a + * performance optimization when data coherency is not required. But specifying this + * flag in scenarios where data coherency is required results in undefined behavior. 
+ * Also, for semaphore object of the type ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_NVSCISYNC, + * if the NvSciSyncAttrList used to create the NvSciSyncObj had not set the flags in + * ::cuDeviceGetNvSciSyncAttributes to CUDA_NVSCISYNC_ATTR_WAIT, this API will return + * CUDA_ERROR_NOT_SUPPORTED. + * + * If the semaphore object is any one of the following types: + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX, + * ::CU_EXTERNAL_SEMAPHORE_HANDLE_TYPE_D3D11_KEYED_MUTEX_KMT + * then the keyed mutex will be acquired when it is released with the key + * specified in ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::keyedmutex::key + * or until the timeout specified by + * ::CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS::params::keyedmutex::timeoutMs + * has lapsed. The timeout interval can either be a finite value + * specified in milliseconds or an infinite value. In case an infinite + * value is specified the timeout never elapses. The windows INFINITE + * macro must be used to specify infinite timeout. + * + * \param extSemArray - External semaphores to be waited on + * \param paramsArray - Array of semaphore parameters + * \param numExtSems - Number of semaphores to wait on + * \param stream - Stream to enqueue the wait operations in + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_TIMEOUT + * \notefnerr + * + * \sa ::cuImportExternalSemaphore, + * ::cuDestroyExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync + */ +CUresult CUDAAPI cuWaitExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); + +/** + * \brief Destroys an external semaphore + * + * Destroys an external semaphore object and releases any references + * to the underlying resource. Any outstanding signals or waits must + * have completed before the semaphore is destroyed. + * + * \param extSem - External semaphore to be destroyed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuDestroyExternalSemaphore(CUexternalSemaphore extSem); + +/** @} */ /* END CUDA_EXTRES_INTEROP */ + +/** + * \defgroup CUDA_MEMOP Stream memory operations + * + * ___MANBRIEF___ Stream memory operations of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the stream memory operations of the low-level CUDA + * driver application programming interface. + * + * The whole set of operations is disabled by default. Users are required + * to explicitly enable them, e.g. on Linux by passing the kernel module + * parameter shown below: + * modprobe nvidia NVreg_EnableStreamMemOPs=1 + * There is currently no way to enable these operations on other operating + * systems. + * + * Users can programmatically query whether the device supports these + * operations with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. + * + * Support for the ::CU_STREAM_WAIT_VALUE_NOR flag can be queried with + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR. 
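+ *
+ * A brief sketch of the intended pattern (illustrative only; error checking
+ * omitted), where \p dev, \p devPtr and the two streams are assumed to exist
+ * already:
+ * \code
+    int supported = 0;
+    cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS, dev);
+    if (supported) {
+        // 'consumer' blocks until the 32-bit value at devPtr is >= 1 ...
+        cuStreamWaitValue32(consumer, devPtr, 1, CU_STREAM_WAIT_VALUE_GEQ);
+        // ... which 'producer' satisfies once its preceding work has completed.
+        cuStreamWriteValue32(producer, devPtr, 1, CU_STREAM_WRITE_VALUE_DEFAULT);
+    }
+ * \endcode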
+ * + * Support for the ::cuStreamWriteValue64() and ::cuStreamWaitValue64() + * functions, as well as for the ::CU_STREAM_MEM_OP_WAIT_VALUE_64 and + * ::CU_STREAM_MEM_OP_WRITE_VALUE_64 flags, can be queried with + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. + * + * Support for both ::CU_STREAM_WAIT_VALUE_FLUSH and + * ::CU_STREAM_MEM_OP_FLUSH_REMOTE_WRITES requires dedicated platform + * hardware features and can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_FLUSH_REMOTE_WRITES. + * + * Note that all memory pointers passed as parameters to these operations + * are device pointers. Where necessary a device pointer should be + * obtained, for example with ::cuMemHostGetDevicePointer(). + * + * None of the operations accepts pointers to managed memory buffers + * (::cuMemAllocManaged). + * + * @{ + */ + +/** + * \brief Wait on a memory location + * + * Enqueues a synchronization of the stream on the given memory location. Work + * ordered after the operation will block until the given condition on the + * memory is satisfied. By default, the condition is to wait for + * (int32_t)(*addr - value) >= 0, a cyclic greater-or-equal. + * Other condition types can be specified via \p flags. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). This function cannot + * be used with managed memory (::cuMemAllocManaged). + * + * Support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. + * + * Support for CU_STREAM_WAIT_VALUE_NOR can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_WAIT_VALUE_NOR. + * + * \param stream The stream to synchronize on the memory location. + * \param addr The memory location to wait on. + * \param value The value to compare with the memory location. + * \param flags See ::CUstreamWaitValue_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWaitValue64, + * ::cuStreamWriteValue32, + * ::cuStreamWriteValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + +/** + * \brief Wait on a memory location + * + * Enqueues a synchronization of the stream on the given memory location. Work + * ordered after the operation will block until the given condition on the + * memory is satisfied. By default, the condition is to wait for + * (int64_t)(*addr - value) >= 0, a cyclic greater-or-equal. + * Other condition types can be specified via \p flags. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). + * + * Support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. + * + * \param stream The stream to synchronize on the memory location. + * \param addr The memory location to wait on. + * \param value The value to compare with the memory location. + * \param flags See ::CUstreamWaitValue_flags. 
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWaitValue32, + * ::cuStreamWriteValue32, + * ::cuStreamWriteValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + +/** + * \brief Write a value to memory + * + * Write a value to memory. Unless the ::CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER + * flag is passed, the write is preceded by a system-wide memory fence, + * equivalent to a __threadfence_system() but scoped to the stream + * rather than a CUDA thread. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). This function cannot + * be used with managed memory (::cuMemAllocManaged). + * + * Support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. + * + * \param stream The stream to do the write in. + * \param addr The device address to write to. + * \param value The value to write. + * \param flags See ::CUstreamWriteValue_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWriteValue64, + * ::cuStreamWaitValue32, + * ::cuStreamWaitValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuEventRecord + */ +CUresult CUDAAPI cuStreamWriteValue32(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + +/** + * \brief Write a value to memory + * + * Write a value to memory. Unless the ::CU_STREAM_WRITE_VALUE_NO_MEMORY_BARRIER + * flag is passed, the write is preceded by a system-wide memory fence, + * equivalent to a __threadfence_system() but scoped to the stream + * rather than a CUDA thread. + * + * If the memory was registered via ::cuMemHostRegister(), the device pointer + * should be obtained with ::cuMemHostGetDevicePointer(). + * + * Support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_64_BIT_STREAM_MEM_OPS. + * + * \param stream The stream to do the write in. + * \param addr The device address to write to. + * \param value The value to write. + * \param flags See ::CUstreamWriteValue_flags. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWriteValue32, + * ::cuStreamWaitValue32, + * ::cuStreamWaitValue64, + * ::cuStreamBatchMemOp, + * ::cuMemHostRegister, + * ::cuEventRecord + */ +CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + +/** + * \brief Batch operations to synchronize the stream via memory operations + * + * This is a batch version of ::cuStreamWaitValue32() and ::cuStreamWriteValue32(). + * Batching operations may avoid some performance overhead in both the API call + * and the device execution versus adding them to the stream in separate API + * calls. The operations are enqueued in the order they appear in the array. + * + * See ::CUstreamBatchMemOpType for the full set of supported operations, and + * ::cuStreamWaitValue32(), ::cuStreamWaitValue64(), ::cuStreamWriteValue32(), + * and ::cuStreamWriteValue64() for details of specific operations. + * + * Basic support for this can be queried with ::cuDeviceGetAttribute() and + * ::CU_DEVICE_ATTRIBUTE_CAN_USE_STREAM_MEM_OPS. 
See related APIs for details + * on querying support for specific operations. + * + * \param stream The stream to enqueue the operations in. + * \param count The number of operations in the array. Must be less than 256. + * \param paramArray The types and parameters of the individual operations. + * \param flags Reserved for future expansion; must be 0. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \notefnerr + * + * \sa ::cuStreamWaitValue32, + * ::cuStreamWaitValue64, + * ::cuStreamWriteValue32, + * ::cuStreamWriteValue64, + * ::cuMemHostRegister + */ +CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstreamBatchMemOpParams *paramArray, unsigned int flags); + +/** @} */ /* END CUDA_MEMOP */ + +/** + * \defgroup CUDA_EXEC Execution Control + * + * ___MANBRIEF___ execution control functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the execution control functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns information about a function + * + * Returns in \p *pi the integer value of the attribute \p attrib on the kernel + * given by \p hfunc. The supported attributes are: + * - ::CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK: The maximum number of threads + * per block, beyond which a launch of the function would fail. This number + * depends on both the function and the device on which the function is + * currently loaded. + * - ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES: The size in bytes of + * statically-allocated shared memory per block required by this function. + * This does not include dynamically-allocated shared memory requested by + * the user at runtime. + * - ::CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES: The size in bytes of user-allocated + * constant memory required by this function. + * - ::CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES: The size in bytes of local memory + * used by each thread of this function. + * - ::CU_FUNC_ATTRIBUTE_NUM_REGS: The number of registers used by each thread + * of this function. + * - ::CU_FUNC_ATTRIBUTE_PTX_VERSION: The PTX virtual architecture version for + * which the function was compiled. This value is the major PTX version * 10 + * + the minor PTX version, so a PTX version 1.3 function would return the + * value 13. Note that this may return the undefined value of 0 for cubins + * compiled prior to CUDA 3.0. + * - ::CU_FUNC_ATTRIBUTE_BINARY_VERSION: The binary architecture version for + * which the function was compiled. This value is the major binary + * version * 10 + the minor binary version, so a binary version 1.3 function + * would return the value 13. Note that this will return a value of 10 for + * legacy cubins that do not have a properly-encoded binary architecture + * version. + * - ::CU_FUNC_CACHE_MODE_CA: The attribute to indicate whether the function has + * been compiled with user specified option "-Xptxas --dlcm=ca" set . + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: The maximum size in bytes of + * dynamically-allocated shared memory. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: Preferred shared memory-L1 + * cache split ratio in percent of total shared memory. 
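+ *
+ * For example (illustrative sketch; error checking omitted), the launch limit
+ * and register usage of a kernel \p hfunc can be queried as follows:
+ * \code
+    int maxThreadsPerBlock = 0, numRegs = 0;
+    cuFuncGetAttribute(&maxThreadsPerBlock, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, hfunc);
+    cuFuncGetAttribute(&numRegs, CU_FUNC_ATTRIBUTE_NUM_REGS, hfunc);
+ * \endcode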
+ * + * \param pi - Returned attribute value + * \param attrib - Attribute requested + * \param hfunc - Function to query attribute of + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuLaunchKernel, + * ::cudaFuncGetAttributes, + * ::cudaFuncSetAttribute + */ +CUresult CUDAAPI cuFuncGetAttribute(int *pi, CUfunction_attribute attrib, CUfunction hfunc); + +/** + * \brief Sets information about a function + * + * This call sets the value of a specified attribute \p attrib on the kernel given + * by \p hfunc to an integer value specified by \p val + * This function returns CUDA_SUCCESS if the new value of the attribute could be + * successfully set. If the set fails, this call will return an error. + * Not all attributes can have values set. Attempting to set a value on a read-only + * attribute will result in an error (CUDA_ERROR_INVALID_VALUE) + * + * Supported attributes for the cuFuncSetAttribute call are: + * - ::CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES: This maximum size in bytes of + * dynamically-allocated shared memory. The value should contain the requested + * maximum size of dynamically-allocated shared memory. The sum of this value and + * the function attribute ::CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES cannot exceed the + * device attribute ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN. + * The maximal size of requestable dynamic shared memory may differ by GPU + * architecture. + * - ::CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT: On devices where the L1 + * cache and shared memory use the same hardware resources, this sets the shared memory + * carveout preference, in percent of the total shared memory. + * See ::CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR + * This is only a hint, and the driver can choose a different ratio if required to execute the function. + * + * \param hfunc - Function to query attribute of + * \param attrib - Attribute requested + * \param value - The value to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuLaunchKernel, + * ::cudaFuncGetAttributes, + * ::cudaFuncSetAttribute + */ +CUresult CUDAAPI cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value); + +/** + * \brief Sets the preferred cache configuration for a device function + * + * On devices where the L1 cache and shared memory use the same hardware + * resources, this sets through \p config the preferred cache configuration for + * the device function \p hfunc. This is only a preference. The driver will use + * the requested configuration if possible, but it is free to choose a different + * configuration if required to execute \p hfunc. Any context-wide preference + * set via ::cuCtxSetCacheConfig() will be overridden by this per-function + * setting unless the per-function setting is ::CU_FUNC_CACHE_PREFER_NONE. In + * that case, the current context-wide setting will be used. + * + * This setting does nothing on devices where the size of the L1 cache and + * shared memory are fixed. 
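+ *
+ * A minimal usage sketch (illustrative only): request a larger L1 cache for
+ * the kernel \p hfunc, keeping in mind that this is only a preference:
+ * \code
+    cuFuncSetCacheConfig(hfunc, CU_FUNC_CACHE_PREFER_L1);
+ * \endcode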
+ * + * Launching a kernel with a different preference than the most recent + * preference setting may insert a device-side synchronization point. + * + * + * The supported cache configurations are: + * - ::CU_FUNC_CACHE_PREFER_NONE: no preference for shared memory or L1 (default) + * - ::CU_FUNC_CACHE_PREFER_SHARED: prefer larger shared memory and smaller L1 cache + * - ::CU_FUNC_CACHE_PREFER_L1: prefer larger L1 cache and smaller shared memory + * - ::CU_FUNC_CACHE_PREFER_EQUAL: prefer equal sized L1 cache and shared memory + * + * \param hfunc - Kernel to configure cache for + * \param config - Requested cache configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchKernel, + * ::cudaFuncSetCacheConfig + */ +CUresult CUDAAPI cuFuncSetCacheConfig(CUfunction hfunc, CUfunc_cache config); + +/** + * \brief Sets the shared memory configuration for a device function. + * + * On devices with configurable shared memory banks, this function will + * force all subsequent launches of the specified device function to have + * the given shared memory bank size configuration. On any given launch of the + * function, the shared memory configuration of the device will be temporarily + * changed if needed to suit the function's preferred configuration. Changes in + * shared memory configuration between subsequent launches of functions, + * may introduce a device side synchronization point. + * + * Any per-function setting of shared memory bank size set via + * ::cuFuncSetSharedMemConfig will override the context wide setting set with + * ::cuCtxSetSharedMemConfig. + * + * Changing the shared memory bank size will not increase shared memory usage + * or affect occupancy of kernels, but may have major effects on performance. + * Larger bank sizes will allow for greater potential bandwidth to shared memory, + * but will change what kinds of accesses to shared memory will result in bank + * conflicts. + * + * This function will do nothing on devices with fixed shared memory bank size. + * + * The supported bank configurations are: + * - ::CU_SHARED_MEM_CONFIG_DEFAULT_BANK_SIZE: use the context's shared memory + * configuration when launching this function. + * - ::CU_SHARED_MEM_CONFIG_FOUR_BYTE_BANK_SIZE: set shared memory bank width to + * be natively four bytes when launching this function. + * - ::CU_SHARED_MEM_CONFIG_EIGHT_BYTE_BANK_SIZE: set shared memory bank width to + * be natively eight bytes when launching this function. + * + * \param hfunc - kernel to be given a shared memory config + * \param config - requested shared memory configuration + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuCtxGetSharedMemConfig, + * ::cuCtxSetSharedMemConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchKernel, + * ::cudaFuncSetSharedMemConfig + */ +CUresult CUDAAPI cuFuncSetSharedMemConfig(CUfunction hfunc, CUsharedconfig config); + +/** + * \brief Returns a module handle + * + * Returns in \p *hmod the handle of the module that function \p hfunc + * is located in. 
The lifetime of the module corresponds to the lifetime of + * the context it was loaded in or until the module is explicitly unloaded. + * + * The CUDA runtime manages its own modules loaded into the primary context. + * If the handle returned by this API refers to a module loaded by the CUDA runtime, + * calling ::cuModuleUnload() on that module will result in undefined behavior. + * + * \param hmod - Returned module handle + * \param hfunc - Function to retrieve module for + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_FOUND + * \notefnerr + * + */ +CUresult CUDAAPI cuFuncGetModule(CUmodule *hmod, CUfunction hfunc); + +/** + * \brief Launches a CUDA function + * + * Invokes the kernel \p f on a \p gridDimX x \p gridDimY x \p gridDimZ + * grid of blocks. Each block contains \p blockDimX x \p blockDimY x + * \p blockDimZ threads. + * + * \p sharedMemBytes sets the amount of dynamic shared memory that will be + * available to each thread block. + * + * Kernel parameters to \p f can be specified in one of two ways: + * + * 1) Kernel parameters can be specified via \p kernelParams. If \p f + * has N parameters, then \p kernelParams needs to be an array of N + * pointers. Each of \p kernelParams[0] through \p kernelParams[N-1] + * must point to a region of memory from which the actual kernel + * parameter will be copied. The number of kernel parameters and their + * offsets and sizes do not need to be specified as that information is + * retrieved directly from the kernel's image. + * + * 2) Kernel parameters can also be packaged by the application into + * a single buffer that is passed in via the \p extra parameter. + * This places the burden on the application of knowing each kernel + * parameter's size and alignment/padding within the buffer. Here is + * an example of using the \p extra parameter in this manner: + * \code + size_t argBufferSize; + char argBuffer[256]; + + // populate argBuffer and argBufferSize + + void *config[] = { + CU_LAUNCH_PARAM_BUFFER_POINTER, argBuffer, + CU_LAUNCH_PARAM_BUFFER_SIZE, &argBufferSize, + CU_LAUNCH_PARAM_END + }; + status = cuLaunchKernel(f, gx, gy, gz, bx, by, bz, sh, s, NULL, config); + * \endcode + * + * The \p extra parameter exists to allow ::cuLaunchKernel to take + * additional less commonly used arguments. \p extra specifies a list of + * names of extra settings and their corresponding values. Each extra + * setting name is immediately followed by the corresponding value. The + * list must be terminated with either NULL or ::CU_LAUNCH_PARAM_END. + * + * - ::CU_LAUNCH_PARAM_END, which indicates the end of the \p extra + * array; + * - ::CU_LAUNCH_PARAM_BUFFER_POINTER, which specifies that the next + * value in \p extra will be a pointer to a buffer containing all + * the kernel parameters for launching kernel \p f; + * - ::CU_LAUNCH_PARAM_BUFFER_SIZE, which specifies that the next + * value in \p extra will be a pointer to a size_t containing the + * size of the buffer specified with ::CU_LAUNCH_PARAM_BUFFER_POINTER; + * + * The error ::CUDA_ERROR_INVALID_VALUE will be returned if kernel + * parameters are specified with both \p kernelParams and \p extra + * (i.e. both \p kernelParams and \p extra are non-NULL). 
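+ *
+ * For the more common \p kernelParams path, launching a kernel that takes the
+ * arguments (float *out, int n) could be sketched as follows (illustrative
+ * only; \p f, \p out and \p n are assumed to have been set up already, and
+ * error checking is omitted):
+ * \code
+    // 'out' is a CUdeviceptr obtained from cuMemAlloc; 'n' is the element count.
+    void *params[] = { &out, &n };
+    // 256 threads per block, enough blocks to cover n elements,
+    // no dynamic shared memory, default stream, no 'extra' options.
+    cuLaunchKernel(f, (n + 255) / 256, 1, 1, 256, 1, 1, 0, NULL, params, NULL);
+ * \endcode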
+ * + * Calling ::cuLaunchKernel() invalidates the persistent function state + * set through the following deprecated APIs: + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), + * ::cuParamSetv(). + * + * Note that to use ::cuLaunchKernel(), the kernel \p f must either have + * been compiled with toolchain version 3.2 or later so that it will + * contain kernel parameter information, or have no kernel parameters. + * If either of these conditions is not met, then ::cuLaunchKernel() will + * return ::CUDA_ERROR_INVALID_IMAGE. + * + * \param f - Kernel to launch + * \param gridDimX - Width of grid in blocks + * \param gridDimY - Height of grid in blocks + * \param gridDimZ - Depth of grid in blocks + * \param blockDimX - X dimension of each thread block + * \param blockDimY - Y dimension of each thread block + * \param blockDimZ - Z dimension of each thread block + * \param sharedMemBytes - Dynamic shared-memory size per thread block in bytes + * \param hStream - Stream identifier + * \param kernelParams - Array of pointers to kernel parameters + * \param extra - Extra options + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cudaLaunchKernel + */ +CUresult CUDAAPI cuLaunchKernel(CUfunction f, + unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + CUstream hStream, + void **kernelParams, + void **extra); + + + + + + + + +/** + * \brief Launches a CUDA function where thread blocks can cooperate and synchronize as they execute + * + * Invokes the kernel \p f on a \p gridDimX x \p gridDimY x \p gridDimZ + * grid of blocks. Each block contains \p blockDimX x \p blockDimY x + * \p blockDimZ threads. + * + * \p sharedMemBytes sets the amount of dynamic shared memory that will be + * available to each thread block. + * + * The device on which this kernel is invoked must have a non-zero value for + * the device attribute ::CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH. + * + * The total number of blocks launched cannot exceed the maximum number of blocks per + * multiprocessor as returned by ::cuOccupancyMaxActiveBlocksPerMultiprocessor (or + * ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags) times the number of multiprocessors + * as specified by the device attribute ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT. + * + * The kernel cannot make use of CUDA dynamic parallelism. + * + * Kernel parameters must be specified via \p kernelParams. If \p f + * has N parameters, then \p kernelParams needs to be an array of N + * pointers. Each of \p kernelParams[0] through \p kernelParams[N-1] + * must point to a region of memory from which the actual kernel + * parameter will be copied. The number of kernel parameters and their + * offsets and sizes do not need to be specified as that information is + * retrieved directly from the kernel's image. 
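+ *
+ * A sketch of sizing a cooperative launch so that the total block count stays
+ * within the resident limit (illustrative only; \p dev, \p stream, \p f and
+ * the single kernel argument \p arg are assumed, and error checking is
+ * omitted):
+ * \code
+    int coop = 0, numSms = 0, blocksPerSm = 0;
+    cuDeviceGetAttribute(&coop, CU_DEVICE_ATTRIBUTE_COOPERATIVE_LAUNCH, dev);
+    cuDeviceGetAttribute(&numSms, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, dev);
+    cuOccupancyMaxActiveBlocksPerMultiprocessor(&blocksPerSm, f, 256, 0);
+    if (coop) {
+        void *params[] = { &arg };
+        cuLaunchCooperativeKernel(f, blocksPerSm * numSms, 1, 1, 256, 1, 1, 0, stream, params);
+    }
+ * \endcode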
+ * + * Calling ::cuLaunchCooperativeKernel() sets persistent function state that is + * the same as function state set through ::cuLaunchKernel API + * + * When the kernel \p f is launched via ::cuLaunchCooperativeKernel(), the previous + * block shape, shared size and parameter info associated with \p f + * is overwritten. + * + * Note that to use ::cuLaunchCooperativeKernel(), the kernel \p f must either have + * been compiled with toolchain version 3.2 or later so that it will + * contain kernel parameter information, or have no kernel parameters. + * If either of these conditions is not met, then ::cuLaunchCooperativeKernel() will + * return ::CUDA_ERROR_INVALID_IMAGE. + * + * \param f - Kernel to launch + * \param gridDimX - Width of grid in blocks + * \param gridDimY - Height of grid in blocks + * \param gridDimZ - Depth of grid in blocks + * \param blockDimX - X dimension of each thread block + * \param blockDimY - Y dimension of each thread block + * \param blockDimZ - Z dimension of each thread block + * \param sharedMemBytes - Dynamic shared-memory size per thread block in bytes + * \param hStream - Stream identifier + * \param kernelParams - Array of pointers to kernel parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchCooperativeKernelMultiDevice, + * ::cudaLaunchCooperativeKernel + */ +CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, + unsigned int gridDimX, + unsigned int gridDimY, + unsigned int gridDimZ, + unsigned int blockDimX, + unsigned int blockDimY, + unsigned int blockDimZ, + unsigned int sharedMemBytes, + CUstream hStream, + void **kernelParams); + +/** + * \brief Launches CUDA functions on multiple devices where thread blocks can cooperate and synchronize as they execute + * + * \deprecated This function is deprecated as of CUDA 11.3. + * + * Invokes kernels as specified in the \p launchParamsList array where each element + * of the array specifies all the parameters required to perform a single kernel launch. + * These kernels can cooperate and synchronize as they execute. The size of the array is + * specified by \p numDevices. + * + * No two kernels can be launched on the same device. All the devices targeted by this + * multi-device launch must be identical. All devices must have a non-zero value for the + * device attribute ::CU_DEVICE_ATTRIBUTE_COOPERATIVE_MULTI_DEVICE_LAUNCH. + * + * All kernels launched must be identical with respect to the compiled code. Note that + * any __device__, __constant__ or __managed__ variables present in the module that owns + * the kernel launched on each device, are independently instantiated on every device. + * It is the application's responsiblity to ensure these variables are initialized and + * used appropriately. + * + * The size of the grids as specified in blocks, the size of the blocks themselves + * and the amount of shared memory used by each thread block must also match across + * all launched kernels. 
+ * + * The streams used to launch these kernels must have been created via either ::cuStreamCreate + * or ::cuStreamCreateWithPriority. The NULL stream or ::CU_STREAM_LEGACY or ::CU_STREAM_PER_THREAD + * cannot be used. + * + * The total number of blocks launched per kernel cannot exceed the maximum number of blocks + * per multiprocessor as returned by ::cuOccupancyMaxActiveBlocksPerMultiprocessor (or + * ::cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags) times the number of multiprocessors + * as specified by the device attribute ::CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT. Since the + * total number of blocks launched per device has to match across all devices, the maximum + * number of blocks that can be launched per device will be limited by the device with the + * least number of multiprocessors. + * + * The kernels cannot make use of CUDA dynamic parallelism. + * + * The ::CUDA_LAUNCH_PARAMS structure is defined as: + * \code + typedef struct CUDA_LAUNCH_PARAMS_st + { + CUfunction function; + unsigned int gridDimX; + unsigned int gridDimY; + unsigned int gridDimZ; + unsigned int blockDimX; + unsigned int blockDimY; + unsigned int blockDimZ; + unsigned int sharedMemBytes; + CUstream hStream; + void **kernelParams; + } CUDA_LAUNCH_PARAMS; + * \endcode + * where: + * - ::CUDA_LAUNCH_PARAMS::function specifies the kernel to be launched. All functions must + * be identical with respect to the compiled code. + * - ::CUDA_LAUNCH_PARAMS::gridDimX is the width of the grid in blocks. This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::gridDimY is the height of the grid in blocks. This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::gridDimZ is the depth of the grid in blocks. This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::blockDimX is the X dimension of each thread block. This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::blockDimX is the Y dimension of each thread block. This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::blockDimZ is the Z dimension of each thread block. This must match across + * all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::sharedMemBytes is the dynamic shared-memory size per thread block in bytes. + * This must match across all kernels launched. + * - ::CUDA_LAUNCH_PARAMS::hStream is the handle to the stream to perform the launch in. This cannot + * be the NULL stream or ::CU_STREAM_LEGACY or ::CU_STREAM_PER_THREAD. The CUDA context associated + * with this stream must match that associated with ::CUDA_LAUNCH_PARAMS::function. + * - ::CUDA_LAUNCH_PARAMS::kernelParams is an array of pointers to kernel parameters. If + * ::CUDA_LAUNCH_PARAMS::function has N parameters, then ::CUDA_LAUNCH_PARAMS::kernelParams + * needs to be an array of N pointers. Each of ::CUDA_LAUNCH_PARAMS::kernelParams[0] through + * ::CUDA_LAUNCH_PARAMS::kernelParams[N-1] must point to a region of memory from which the actual + * kernel parameter will be copied. The number of kernel parameters and their offsets and sizes + * do not need to be specified as that information is retrieved directly from the kernel's image. + * + * By default, the kernel won't begin execution on any GPU until all prior work in all the specified + * streams has completed. This behavior can be overridden by specifying the flag + * ::CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_PRE_LAUNCH_SYNC. 
When this flag is specified, each kernel + * will only wait for prior work in the stream corresponding to that GPU to complete before it begins + * execution. + * + * Similarly, by default, any subsequent work pushed in any of the specified streams will not begin + * execution until the kernels on all GPUs have completed. This behavior can be overridden by specifying + * the flag ::CUDA_COOPERATIVE_LAUNCH_MULTI_DEVICE_NO_POST_LAUNCH_SYNC. When this flag is specified, + * any subsequent work pushed in any of the specified streams will only wait for the kernel launched + * on the GPU corresponding to that stream to complete before it begins execution. + * + * Calling ::cuLaunchCooperativeKernelMultiDevice() sets persistent function state that is + * the same as function state set through ::cuLaunchKernel API when called individually for each + * element in \p launchParamsList. + * + * When kernels are launched via ::cuLaunchCooperativeKernelMultiDevice(), the previous + * block shape, shared size and parameter info associated with each ::CUDA_LAUNCH_PARAMS::function + * in \p launchParamsList is overwritten. + * + * Note that to use ::cuLaunchCooperativeKernelMultiDevice(), the kernels must either have + * been compiled with toolchain version 3.2 or later so that it will + * contain kernel parameter information, or have no kernel parameters. + * If either of these conditions is not met, then ::cuLaunchCooperativeKernelMultiDevice() will + * return ::CUDA_ERROR_INVALID_IMAGE. + * + * \param launchParamsList - List of launch parameters, one per device + * \param numDevices - Size of the \p launchParamsList array + * \param flags - Flags to control launch behavior + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_IMAGE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_COOPERATIVE_LAUNCH_TOO_LARGE, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \note_null_stream + * \notefnerr + * + * \sa ::cuCtxGetCacheConfig, + * ::cuCtxSetCacheConfig, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuLaunchCooperativeKernel, + * ::cudaLaunchCooperativeKernelMultiDevice + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchCooperativeKernelMultiDevice(CUDA_LAUNCH_PARAMS *launchParamsList, unsigned int numDevices, unsigned int flags); + +/** + * \brief Enqueues a host function call in a stream + * + * Enqueues a host function to run in a stream. The function will be called + * after currently enqueued work and will block work added after it. + * + * The host function must not make any CUDA API calls. Attempting to use a + * CUDA API may result in ::CUDA_ERROR_NOT_PERMITTED, but this is not required. + * The host function must not perform any synchronization that may depend on + * outstanding CUDA work not mandated to run earlier. Host functions without a + * mandated order (such as in independent streams) execute in undefined order + * and may be serialized. + * + * For the purposes of Unified Memory, execution makes a number of guarantees: + *
+ * - The stream is considered idle for the duration of the function's
+ *   execution. Thus, for example, the function may always use memory attached
+ *   to the stream it was enqueued in.
+ * - The start of execution of the function has the same effect as
+ *   synchronizing an event recorded in the same stream immediately prior to
+ *   the function. It thus synchronizes streams which have been "joined"
+ *   prior to the function.
+ * - Adding device work to any stream does not have the effect of making
+ *   the stream active until all preceding host functions and stream callbacks
+ *   have executed. Thus, for example, a function might use global attached
+ *   memory even if work has been added to another stream, if the work has
+ *   been ordered behind the function call with an event.
+ * - Completion of the function does not cause a stream to become active
+ *   except as described above. The stream will remain idle if no device work
+ *   follows the function, and will remain idle across consecutive host
+ *   functions or stream callbacks without device work in between. Thus, for
+ *   example, stream synchronization can be done by signaling from a host
+ *   function at the end of the stream.
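+ *
+ * A minimal usage sketch (the callback and its payload are illustrative; the
+ * stream \p hStream is assumed to exist):
+ * \code
+    // Runs on a host thread once all preceding work in the stream has completed.
+    // It must not call into the CUDA API.
+    void CUDA_CB myHostNotify(void *userData) {
+        *(int *)userData = 1;
+    }
+
+    // At enqueue time:
+    static int done = 0;
+    cuLaunchHostFunc(hStream, myHostNotify, &done);
+ * \endcode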
+ * + * Note that, in contrast to ::cuStreamAddCallback, the function will not be + * called in the event of an error in the CUDA context. + * + * \param hStream - Stream to enqueue function call in + * \param fn - The function to call once preceding stream operations are complete + * \param userData - User-specified data to be passed to the function + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_SUPPORTED + * \note_null_stream + * \notefnerr + * + * \sa ::cuStreamCreate, + * ::cuStreamQuery, + * ::cuStreamSynchronize, + * ::cuStreamWaitEvent, + * ::cuStreamDestroy, + * ::cuMemAllocManaged, + * ::cuStreamAttachMemAsync, + * ::cuStreamAddCallback + */ +CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, void *userData); + +/** @} */ /* END CUDA_EXEC */ + +/** + * \defgroup CUDA_EXEC_DEPRECATED Execution Control [DEPRECATED] + * + * ___MANBRIEF___ deprecated execution control functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated execution control functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Sets the block-dimensions for the function + * + * \deprecated + * + * Specifies the \p x, \p y, and \p z dimensions of the thread blocks that are + * created when the kernel given by \p hfunc is launched. + * + * \param hfunc - Kernel to specify dimensions of + * \param x - X dimension + * \param y - Y dimension + * \param z - Z dimension + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetSharedSize, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSeti, + * ::cuParamSetf, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetBlockShape(CUfunction hfunc, int x, int y, int z); + +/** + * \brief Sets the dynamic shared-memory size for the function + * + * \deprecated + * + * Sets through \p bytes the amount of dynamic shared memory that will be + * available to each thread block when the kernel given by \p hfunc is launched. + * + * \param hfunc - Kernel to specify dynamic shared-memory size for + * \param bytes - Dynamic shared-memory size per thread in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetCacheConfig, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSeti, + * ::cuParamSetf, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuFuncSetSharedSize(CUfunction hfunc, unsigned int bytes); + +/** + * \brief Sets the parameter size for the function + * + * \deprecated + * + * Sets through \p numbytes the total size in bytes needed by the function + * parameters of the kernel corresponding to \p hfunc. 
+ * + * \param hfunc - Kernel to set parameter size for + * \param numbytes - Size of parameter list in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetSize(CUfunction hfunc, unsigned int numbytes); + +/** + * \brief Adds an integer parameter to the function's argument list + * + * \deprecated + * + * Sets an integer parameter that will be specified the next time the + * kernel corresponding to \p hfunc will be invoked. \p offset is a byte offset. + * + * \param hfunc - Kernel to add parameter to + * \param offset - Offset to add parameter to argument list + * \param value - Value of parameter + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSeti(CUfunction hfunc, int offset, unsigned int value); + +/** + * \brief Adds a floating-point parameter to the function's argument list + * + * \deprecated + * + * Sets a floating-point parameter that will be specified the next time the + * kernel corresponding to \p hfunc will be invoked. \p offset is a byte offset. + * + * \param hfunc - Kernel to add parameter to + * \param offset - Offset to add parameter to argument list + * \param value - Value of parameter + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetf(CUfunction hfunc, int offset, float value); + +/** + * \brief Adds arbitrary data to the function's argument list + * + * \deprecated + * + * Copies an arbitrary amount of data (specified in \p numbytes) from \p ptr + * into the parameter space of the kernel corresponding to \p hfunc. \p offset + * is a byte offset. 
+ * + * \param hfunc - Kernel to add data to + * \param offset - Offset to add data to argument list + * \param ptr - Pointer to arbitrary data + * \param numbytes - Size of data to copy in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetv(CUfunction hfunc, int offset, void *ptr, unsigned int numbytes); + +/** + * \brief Launches a CUDA function + * + * \deprecated + * + * Invokes the kernel \p f on a 1 x 1 x 1 grid of blocks. The block + * contains the number of threads specified by a previous call to + * ::cuFuncSetBlockShape(). + * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. + * + * \param f - Kernel to launch + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunchGrid, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunch(CUfunction f); + +/** + * \brief Launches a CUDA function + * + * \deprecated + * + * Invokes the kernel \p f on a \p grid_width x \p grid_height grid of + * blocks. Each block contains the number of threads specified by a previous + * call to ::cuFuncSetBlockShape(). + * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. 
+ * + * \param f - Kernel to launch + * \param grid_width - Width of grid in blocks + * \param grid_height - Height of grid in blocks + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGridAsync, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGrid(CUfunction f, int grid_width, int grid_height); + +/** + * \brief Launches a CUDA function + * + * \deprecated + * + * Invokes the kernel \p f on a \p grid_width x \p grid_height grid of + * blocks. Each block contains the number of threads specified by a previous + * call to ::cuFuncSetBlockShape(). + * + * The block shape, dynamic shared memory size, and parameter information + * must be set using + * ::cuFuncSetBlockShape(), + * ::cuFuncSetSharedSize(), + * ::cuParamSetSize(), + * ::cuParamSeti(), + * ::cuParamSetf(), and + * ::cuParamSetv() + * prior to calling this function. + * + * Launching a function via ::cuLaunchKernel() invalidates the function's + * block shape, dynamic shared memory size, and parameter information. After + * launching via cuLaunchKernel, this state must be re-initialized prior to + * calling this function. Failure to do so results in undefined behavior. + * + * \param f - Kernel to launch + * \param grid_width - Width of grid in blocks + * \param grid_height - Height of grid in blocks + * \param hStream - Stream identifier + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_LAUNCH_FAILED, + * ::CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES, + * ::CUDA_ERROR_LAUNCH_TIMEOUT, + * ::CUDA_ERROR_LAUNCH_INCOMPATIBLE_TEXTURING, + * ::CUDA_ERROR_SHARED_OBJECT_INIT_FAILED + * + * \note In certain cases where cubins are created with no ABI (i.e., using \p ptxas \p --abi-compile \p no), + * this function may serialize kernel launches. The CUDA driver retains asynchronous behavior by + * growing the per-thread stack as needed per launch and not shrinking it afterwards. + * + * \note_null_stream + * \notefnerr + * + * \sa ::cuFuncSetBlockShape, + * ::cuFuncSetSharedSize, + * ::cuFuncGetAttribute, + * ::cuParamSetSize, + * ::cuParamSetf, + * ::cuParamSeti, + * ::cuParamSetv, + * ::cuLaunch, + * ::cuLaunchGrid, + * ::cuLaunchKernel + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuLaunchGridAsync(CUfunction f, int grid_width, int grid_height, CUstream hStream); + + +/** + * \brief Adds a texture-reference to the function's argument list + * + * \deprecated + * + * Makes the CUDA array or linear memory bound to the texture reference + * \p hTexRef available to a device program as a texture. In this version of + * CUDA, the texture-reference must be obtained via ::cuModuleGetTexRef() and + * the \p texunit parameter must be set to ::CU_PARAM_TR_DEFAULT. 
+ * + * \param hfunc - Kernel to add texture-reference to + * \param texunit - Texture unit (must be ::CU_PARAM_TR_DEFAULT) + * \param hTexRef - Texture-reference to add to argument list + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuParamSetTexRef(CUfunction hfunc, int texunit, CUtexref hTexRef); +/** @} */ /* END CUDA_EXEC_DEPRECATED */ + +/** + * \defgroup CUDA_GRAPH Graph Management + * + * ___MANBRIEF___ graph management functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the graph management functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Creates a graph + * + * Creates an empty graph, which is returned via \p phGraph. + * + * \param phGraph - Returns newly created graph + * \param flags - Graph creation flags, must be 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + * ::cuGraphInstantiate, + * ::cuGraphDestroy, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphClone + */ +CUresult CUDAAPI cuGraphCreate(CUgraph *phGraph, unsigned int flags); + +/** + * \brief Creates a kernel execution node and adds it to a graph + * + * Creates a new kernel execution node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * The CUDA_KERNEL_NODE_PARAMS structure is defined as: + * + * \code + * typedef struct CUDA_KERNEL_NODE_PARAMS_st { + * CUfunction func; + * unsigned int gridDimX; + * unsigned int gridDimY; + * unsigned int gridDimZ; + * unsigned int blockDimX; + * unsigned int blockDimY; + * unsigned int blockDimZ; + * unsigned int sharedMemBytes; + * void **kernelParams; + * void **extra; + * } CUDA_KERNEL_NODE_PARAMS; + * \endcode + * + * When the graph is launched, the node will invoke kernel \p func on a (\p gridDimX x + * \p gridDimY x \p gridDimZ) grid of blocks. Each block contains + * (\p blockDimX x \p blockDimY x \p blockDimZ) threads. + * + * \p sharedMemBytes sets the amount of dynamic shared memory that will be + * available to each thread block. + * + * Kernel parameters to \p func can be specified in one of two ways: + * + * 1) Kernel parameters can be specified via \p kernelParams. If the kernel has N + * parameters, then \p kernelParams needs to be an array of N pointers. Each pointer, + * from \p kernelParams[0] to \p kernelParams[N-1], points to the region of memory from which the actual + * parameter will be copied. The number of kernel parameters and their offsets and sizes do not need + * to be specified as that information is retrieved directly from the kernel's image. 
+ * + * 2) Kernel parameters for non-cooperative kernels can also be packaged by the application into a single + * buffer that is passed in via \p extra. This places the burden on the application of knowing each + * kernel parameter's size and alignment/padding within the buffer. The \p extra parameter exists + * to allow this function to take additional less commonly used arguments. \p extra specifies + * a list of names of extra settings and their corresponding values. Each extra setting name is + * immediately followed by the corresponding value. The list must be terminated with either NULL or + * CU_LAUNCH_PARAM_END. + * + * - ::CU_LAUNCH_PARAM_END, which indicates the end of the \p extra + * array; + * - ::CU_LAUNCH_PARAM_BUFFER_POINTER, which specifies that the next + * value in \p extra will be a pointer to a buffer + * containing all the kernel parameters for launching kernel + * \p func; + * - ::CU_LAUNCH_PARAM_BUFFER_SIZE, which specifies that the next + * value in \p extra will be a pointer to a size_t + * containing the size of the buffer specified with + * ::CU_LAUNCH_PARAM_BUFFER_POINTER; + * + * The error ::CUDA_ERROR_INVALID_VALUE will be returned if kernel parameters are specified with both + * \p kernelParams and \p extra (i.e. both \p kernelParams and \p extra are non-NULL). + * ::CUDA_ERROR_INVALID_VALUE will be returned if \p extra is used for a cooperative kernel. + * + * The \p kernelParams or \p extra array, as well as the argument values it points to, + * are copied during this call. + * + * \note Kernels launched using graphs must not use texture and surface references. Reading or + * writing through any texture or surface reference is undefined behavior. + * This restriction does not apply to texture and surface objects. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the GPU execution node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuLaunchCooperativeKernel, + * ::cuGraphKernelNodeGetParams, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddKernelNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a kernel node's parameters + * + * Returns the parameters of kernel node \p hNode in \p nodeParams. + * The \p kernelParams or \p extra array returned in \p nodeParams, + * as well as the argument values it points to, are owned by the node. + * This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphKernelNodeSetParams to update the + * parameters of this node. + * + * The params will contain either \p kernelParams or \p extra, + * according to which of these was most recently set on the node. 
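+ *
+ * For instance, a node's grid size can be inspected and updated in place
+ * (illustrative sketch; error checking omitted):
+ * \code
+    CUDA_KERNEL_NODE_PARAMS np;
+    cuGraphKernelNodeGetParams(hNode, &np);
+    np.gridDimX *= 2;                      // e.g. widen the grid
+    cuGraphKernelNodeSetParams(hNode, &np);
+ * \endcode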
+ * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeSetParams + */ +CUresult CUDAAPI cuGraphKernelNodeGetParams(CUgraphNode hNode, CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets a kernel node's parameters + * + * Sets the parameters of kernel node \p hNode to \p nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeGetParams + */ +CUresult CUDAAPI cuGraphKernelNodeSetParams(CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a memcpy node and adds it to a graph + * + * Creates a new memcpy node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * When the graph is launched, the node will perform the memcpy described by \p copyParams. + * See ::cuMemcpy3D() for a description of the structure and its restrictions. + * + * Memcpy nodes have some additional restrictions with regards to managed memory, if the + * system contains at least one device which has a zero value for the device attribute + * ::CU_DEVICE_ATTRIBUTE_CONCURRENT_MANAGED_ACCESS. If one or more of the operands refer + * to managed memory, then using the memory type ::CU_MEMORYTYPE_UNIFIED is disallowed + * for those operand(s). The managed memory will be treated as residing on either the + * host or the device, depending on which memory type is specified. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param copyParams - Parameters for the memory copy + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemcpy3D, + * ::cuGraphMemcpyNodeGetParams, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemcpyNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_MEMCPY3D *copyParams, CUcontext ctx); + +/** + * \brief Returns a memcpy node's parameters + * + * Returns the parameters of memcpy node \p hNode in \p nodeParams. 
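+ *
+ * For illustration only (a sketch; \p memcpyNode is an assumed existing memcpy
+ * node and \p newDst an assumed replacement destination buffer):
+ *
+ * \code
+ * CUDA_MEMCPY3D copyParams;
+ * cuGraphMemcpyNodeGetParams(memcpyNode, &copyParams);
+ * copyParams.dstDevice = newDst;             // redirect the copy destination
+ * cuGraphMemcpyNodeSetParams(memcpyNode, &copyParams);
+ * \endcode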
+ * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemcpy3D, + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeSetParams + */ +CUresult CUDAAPI cuGraphMemcpyNodeGetParams(CUgraphNode hNode, CUDA_MEMCPY3D *nodeParams); + +/** + * \brief Sets a memcpy node's parameters + * + * Sets the parameters of memcpy node \p hNode to \p nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemcpy3D, + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeGetParams + */ +CUresult CUDAAPI cuGraphMemcpyNodeSetParams(CUgraphNode hNode, const CUDA_MEMCPY3D *nodeParams); + +/** + * \brief Creates a memset node and adds it to a graph + * + * Creates a new memset node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * The element size must be 1, 2, or 4 bytes. + * When the graph is launched, the node will perform the memset described by \p memsetParams. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param memsetParams - Parameters for the memory set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_CONTEXT + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemsetD2D32, + * ::cuGraphMemsetNodeGetParams, + * ::cuGraphMemsetNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode + */ +CUresult CUDAAPI cuGraphAddMemsetNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx); + +/** + * \brief Returns a memset node's parameters + * + * Returns the parameters of memset node \p hNode in \p nodeParams. + * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemsetD2D32, + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeSetParams + */ +CUresult CUDAAPI cuGraphMemsetNodeGetParams(CUgraphNode hNode, CUDA_MEMSET_NODE_PARAMS *nodeParams); + +/** + * \brief Sets a memset node's parameters + * + * Sets the parameters of memset node \p hNode to \p nodeParams. 
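+ *
+ * A non-normative sketch, assuming \p memsetNode is an existing memset node
+ * operating on 4-byte elements:
+ *
+ * \code
+ * CUDA_MEMSET_NODE_PARAMS params;
+ * cuGraphMemsetNodeGetParams(memsetNode, &params);
+ * params.value = 0xFFFFFFFF;                 // new fill pattern
+ * cuGraphMemsetNodeSetParams(memsetNode, &params);
+ * \endcode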
+ * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuMemsetD2D32, + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeGetParams + */ +CUresult CUDAAPI cuGraphMemsetNodeSetParams(CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a host execution node and adds it to a graph + * + * Creates a new CPU execution node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * When the graph is launched, the node will invoke the specified CPU function. + * Host nodes are not supported under MPS with pre-Volta GPUs. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the host node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchHostFunc, + * ::cuGraphHostNodeGetParams, + * ::cuGraphHostNodeSetParams, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddHostNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a host node's parameters + * + * Returns the parameters of host node \p hNode in \p nodeParams. + * + * \param hNode - Node to get the parameters for + * \param nodeParams - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchHostFunc, + * ::cuGraphAddHostNode, + * ::cuGraphHostNodeSetParams + */ +CUresult CUDAAPI cuGraphHostNodeGetParams(CUgraphNode hNode, CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Sets a host node's parameters + * + * Sets the parameters of host node \p hNode to \p nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchHostFunc, + * ::cuGraphAddHostNode, + * ::cuGraphHostNodeGetParams + */ +CUresult CUDAAPI cuGraphHostNodeSetParams(CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams); + +/** + * \brief Creates a child graph node and adds it to a graph + * + * Creates a new node which executes an embedded graph, and adds it to \p hGraph with + * \p numDependencies dependencies specified via \p dependencies. 
+ * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * If \p hGraph contains allocation or free nodes, this call will return an error. + * + * The node executes an embedded child graph. The child graph is cloned in this call. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param childGraph - The graph to clone into this node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphChildGraphNodeGetGraph, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + * ::cuGraphClone + */ +CUresult CUDAAPI cuGraphAddChildGraphNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUgraph childGraph); + +/** + * \brief Gets a handle to the embedded graph of a child graph node + * + * Gets a handle to the embedded graph in a child graph node. This call + * does not clone the graph. Changes to the graph will be reflected in + * the node, and the node retains ownership of the graph. + * + * Allocation and free nodes cannot be added to the returned graph. + * Attempting to do so will return an error. + * + * \param hNode - Node to get the embedded graph for + * \param phGraph - Location to store a handle to the graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphNodeFindInClone + */ +CUresult CUDAAPI cuGraphChildGraphNodeGetGraph(CUgraphNode hNode, CUgraph *phGraph); + +/** + * \brief Creates an empty node and adds it to a graph + * + * Creates a new node which performs no operation, and adds it to \p hGraph with + * \p numDependencies dependencies specified via \p dependencies. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * An empty node performs no operation during execution, but can be used for + * transitive ordering. For example, a phased execution graph with 2 groups of n + * nodes with a barrier between them can be represented using an empty node and + * 2*n dependency edges, rather than no empty node and n^2 dependency edges. 
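+ *
+ * A sketch of the barrier pattern described above (\p phase1Nodes and its
+ * count are assumptions; phase-2 nodes would then list \p barrier as their
+ * only dependency):
+ *
+ * \code
+ * CUgraphNode barrier;
+ * cuGraphAddEmptyNode(&barrier, hGraph, phase1Nodes, numPhase1Nodes);
+ * \endcode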
+ * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddEmptyNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies); + +/** + * \brief Creates an event record node and adds it to a graph + * + * Creates a new event record node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and event specified in \p event. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * Each launch of the graph will record \p event to capture execution of the + * node's dependencies. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param event - Event for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddEventRecordNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUevent event); + +/** + * \brief Returns the event associated with an event record node + * + * Returns the event of event record node \p hNode in \p event_out. + * + * \param hNode - Node to get the event for + * \param event_out - Pointer to return the event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventRecordNodeGetEvent(CUgraphNode hNode, CUevent *event_out); + +/** + * \brief Sets an event record node's event + * + * Sets the event of event record node \p hNode to \p event. 
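+ *
+ * For example (a minimal sketch; \p recordNode is an assumed existing event
+ * record node):
+ *
+ * \code
+ * CUevent newEvent;
+ * cuEventCreate(&newEvent, CU_EVENT_DEFAULT);
+ * cuGraphEventRecordNodeSetEvent(recordNode, newEvent);
+ * \endcode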
+ * + * \param hNode - Node to set the event for + * \param event - Event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventRecordNodeSetEvent(CUgraphNode hNode, CUevent event); + +/** + * \brief Creates an event wait node and adds it to a graph + * + * Creates a new event wait node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and event specified in \p event. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. + * A handle to the new node will be returned in \p phGraphNode. + * + * The graph node will wait for all work captured in \p event. See ::cuEventRecord() + * for details on what is captured by an event. \p event may be from a different context + * or device than the launch stream. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param event - Event for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddEventWaitNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUevent event); + +/** + * \brief Returns the event associated with an event wait node + * + * Returns the event of event wait node \p hNode in \p event_out. + * + * \param hNode - Node to get the event for + * \param event_out - Pointer to return the event + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventWaitNodeGetEvent(CUgraphNode hNode, CUevent *event_out); + +/** + * \brief Sets an event wait node's event + * + * Sets the event of event wait node \p hNode to \p event. 
+ * + * \param hNode - Node to set the event for + * \param event - Event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent + */ +CUresult CUDAAPI cuGraphEventWaitNodeSetEvent(CUgraphNode hNode, CUevent event); + +/** + * \brief Creates an external semaphore signal node and adds it to a graph + * + * Creates a new external semaphore signal node and adds it to \p hGraph with \p + * numDependencies dependencies specified via \p dependencies and arguments specified + * in \p nodeParams. It is possible for \p numDependencies to be 0, in which case the + * node will be placed at the root of the graph. \p dependencies may not have any + * duplicate entries. A handle to the new node will be returned in \p phGraphNode. + * + * Performs a signal operation on a set of externally allocated semaphore objects + * when the node is launched. The operation(s) will occur after all of the node's + * dependencies have completed. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExternalSemaphoresSignalNodeGetParams, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddExternalSemaphoresSignalNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Returns an external semaphore signal node's parameters + * + * Returns the parameters of an external semaphore signal node \p hNode in \p params_out. + * The \p extSemArray and \p paramsArray returned in \p params_out, + * are owned by the node. This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphExternalSemaphoresSignalNodeSetParams to update the + * parameters of this node. 
+ * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeGetParams(CUgraphNode hNode, CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *params_out); + +/** + * \brief Sets an external semaphore signal node's parameters + * + * Sets the parameters of an external semaphore signal node \p hNode to \p nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphExternalSemaphoresSignalNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresSignalNodeSetParams(CUgraphNode hNode, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Creates an external semaphore wait node and adds it to a graph + * + * Creates a new external semaphore wait node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. A handle + * to the new node will be returned in \p phGraphNode. + * + * Performs a wait operation on a set of externally allocated semaphore objects + * when the node is launched. The node's dependencies will not be launched until + * the wait operation has completed. 
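+ *
+ * An illustrative, non-normative sketch (\p extSem is assumed to be a
+ * semaphore previously imported with ::cuImportExternalSemaphore, and the
+ * fence payload is an assumption for this example):
+ *
+ * \code
+ * CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS waitParams = { 0 };
+ * waitParams.params.fence.value = 1;         // assumed fence value to wait on
+ *
+ * CUDA_EXT_SEM_WAIT_NODE_PARAMS nodeParams = { 0 };
+ * nodeParams.extSemArray = &extSem;
+ * nodeParams.paramsArray = &waitParams;
+ * nodeParams.numExtSems  = 1;
+ *
+ * CUgraphNode waitNode;
+ * cuGraphAddExternalSemaphoresWaitNode(&waitNode, hGraph, NULL, 0, &nodeParams);
+ * \endcode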
+ * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphExternalSemaphoresWaitNodeGetParams, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode, + */ +CUresult CUDAAPI cuGraphAddExternalSemaphoresWaitNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Returns an external semaphore wait node's parameters + * + * Returns the parameters of an external semaphore wait node \p hNode in \p params_out. + * The \p extSemArray and \p paramsArray returned in \p params_out, + * are owned by the node. This memory remains valid until the node is destroyed or its + * parameters are modified, and should not be modified + * directly. Use ::cuGraphExternalSemaphoresSignalNodeSetParams to update the + * parameters of this node. + * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuLaunchKernel, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeGetParams(CUgraphNode hNode, CUDA_EXT_SEM_WAIT_NODE_PARAMS *params_out); + +/** + * \brief Sets an external semaphore wait node's parameters + * + * Sets the parameters of an external semaphore wait node \p hNode to \p nodeParams. + * + * \param hNode - Node to set the parameters for + * \param nodeParams - Parameters to copy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphExternalSemaphoresWaitNodeSetParams, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync + */ +CUresult CUDAAPI cuGraphExternalSemaphoresWaitNodeSetParams(CUgraphNode hNode, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Creates an allocation node and adds it to a graph + * + * Creates a new allocation node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. 
+ * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. A handle + * to the new node will be returned in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param nodeParams - Parameters for the node + * + * When ::cuGraphAddMemAllocNode creates an allocation node, it returns the address of the allocation in + * \p nodeParams.dptr. The allocation's address remains fixed across instantiations and launches. + * + * If the allocation is freed in the same graph, by creating a free node using ::cuGraphAddMemFreeNode, + * the allocation can be accessed by nodes ordered after the allocation node but before the free node. + * These allocations cannot be freed outside the owning graph, and they can only be freed once in the + * owning graph. + * + * If the allocation is not freed in the same graph, then it can be accessed not only by nodes in the + * graph which are ordered after the allocation node, but also by stream operations ordered after the + * graph's execution but before the allocation is freed. + * + * Allocations which are not freed in the same graph can be freed by: + * - passing the allocation to ::cuMemFreeAsync or ::cuMemFree; + * - launching a graph with a free node for that allocation; or + * - specifying ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH during instantiation, which makes + * each launch behave as though it called ::cuMemFreeAsync for every unfreed allocation. + * + * It is not possible to free an allocation in both the owning graph and another graph. If the allocation + * is freed in the same graph, a free node cannot be added to another graph. If the allocation is freed + * in another graph, a free node can no longer be added to the owning graph. + * + * The following restrictions apply to graphs which contain allocation and/or memory free nodes: + * - Nodes and edges of the graph cannot be deleted. + * - The graph cannot be used in a child node. + * - Only one instantiation of the graph may exist at any point in time. + * - The graph cannot be cloned. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemFreeNode, + * ::cuGraphMemAllocNodeGetParams, + * ::cuDeviceGraphMemTrim, + * ::cuDeviceGetGraphMemAttribute, + * ::cuDeviceSetGraphMemAttribute, + * ::cuMemAllocAsync, + * ::cuMemFreeAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemAllocNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUDA_MEM_ALLOC_NODE_PARAMS *nodeParams); + +/** + * \brief Returns a memory alloc node's parameters + * + * Returns the parameters of a memory alloc node \p hNode in \p params_out. + * The \p poolProps and \p accessDescs returned in \p params_out, are owned by the + * node. 
This memory remains valid until the node is destroyed. The returned + * parameters must not be modified. + * + * \param hNode - Node to get the parameters for + * \param params_out - Pointer to return the parameters + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphMemFreeNodeGetParams + */ +CUresult CUDAAPI cuGraphMemAllocNodeGetParams(CUgraphNode hNode, CUDA_MEM_ALLOC_NODE_PARAMS *params_out); + +/** + * \brief Creates a memory free node and adds it to a graph + * + * Creates a new memory free node and adds it to \p hGraph with \p numDependencies + * dependencies specified via \p dependencies and arguments specified in \p nodeParams. + * It is possible for \p numDependencies to be 0, in which case the node will be placed + * at the root of the graph. \p dependencies may not have any duplicate entries. A handle + * to the new node will be returned in \p phGraphNode. + * + * \param phGraphNode - Returns newly created node + * \param hGraph - Graph to which to add the node + * \param dependencies - Dependencies of the node + * \param numDependencies - Number of dependencies + * \param dptr - Address of memory to free + * + * ::cuGraphAddMemFreeNode will return ::CUDA_ERROR_INVALID_VALUE if the user attempts to free: + * - an allocation twice in the same graph. + * - an address that was not returned by an allocation node. + * - an invalid address. + * + * The following restrictions apply to graphs which contain allocation and/or memory free nodes: + * - Nodes and edges of the graph cannot be deleted. + * - The graph cannot be used in a child node. + * - Only one instantiation of the graph may exist at any point in time. + * - The graph cannot be cloned. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphMemFreeNodeGetParams, + * ::cuDeviceGraphMemTrim, + * ::cuDeviceGetGraphMemAttribute, + * ::cuDeviceSetGraphMemAttribute, + * ::cuMemAllocAsync, + * ::cuMemFreeAsync, + * ::cuGraphCreate, + * ::cuGraphDestroyNode, + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddEventRecordNode, + * ::cuGraphAddEventWaitNode, + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphAddMemFreeNode(CUgraphNode *phGraphNode, CUgraph hGraph, const CUgraphNode *dependencies, size_t numDependencies, CUdeviceptr dptr); + +/** + * \brief Returns a memory free node's parameters + * + * Returns the address of a memory free node \p hNode in \p dptr_out. + * + * \param hNode - Node to get the parameters for + * \param dptr_out - Pointer to return the device address + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemFreeNode, + * ::cuGraphMemAllocNodeGetParams + */ +CUresult CUDAAPI cuGraphMemFreeNodeGetParams(CUgraphNode hNode, CUdeviceptr *dptr_out); + +/** + * \brief Free unused memory that was cached on the specified device for use with graphs back to the OS. 
+ * + * Blocks which are not in use by a graph that is either currently executing or scheduled to execute are + * freed back to the operating system. + * + * \param device - The device for which cached memory should be freed. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode, + * ::cuDeviceSetGraphMemAttribute, + * ::cuDeviceGetGraphMemAttribute + */ +CUresult CUDAAPI cuDeviceGraphMemTrim(CUdevice device); + +/** + * \brief Query asynchronous allocation attributes related to graphs + * + * Valid attributes are: + * + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_CURRENT: Amount of memory, in bytes, currently associated with graphs + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_HIGH: High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_CURRENT: Amount of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH: High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + * + * \param device - Specifies the scope of the query + * \param attr - attribute to get + * \param value - retrieved value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDeviceSetGraphMemAttribute, + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode + */ +CUresult CUDAAPI cuDeviceGetGraphMemAttribute(CUdevice device, CUgraphMem_attribute attr, void* value); + +/** + * \brief Set asynchronous allocation attributes related to graphs + * + * Valid attributes are: + * + * - ::CU_GRAPH_MEM_ATTR_USED_MEM_HIGH: High watermark of memory, in bytes, associated with graphs since the + * last time it was reset. High watermark can only be reset to zero. + * - ::CU_GRAPH_MEM_ATTR_RESERVED_MEM_HIGH: High watermark of memory, in bytes, currently allocated for use by + * the CUDA graphs asynchronous allocator. + * + * \param device - Specifies the scope of the query + * \param attr - attribute to get + * \param value - pointer to value to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_DEVICE + * + * \sa + * ::cuDeviceGetGraphMemAttribute, + * ::cuGraphAddMemAllocNode, + * ::cuGraphAddMemFreeNode + */ +CUresult CUDAAPI cuDeviceSetGraphMemAttribute(CUdevice device, CUgraphMem_attribute attr, void* value); + +/** + * \brief Clones a graph + * + * This function creates a copy of \p originalGraph and returns it in \p phGraphClone. + * All parameters are copied into the cloned graph. The original graph may be modified + * after this call without affecting the clone. + * + * Child graph nodes in the original graph are recursively copied into the clone. + * + * \param phGraphClone - Returns newly created cloned graph + * \param originalGraph - Graph to clone + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OUT_OF_MEMORY + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphNodeFindInClone + */ +CUresult CUDAAPI cuGraphClone(CUgraph *phGraphClone, CUgraph originalGraph); + +/** + * \brief Finds a cloned version of a node + * + * This function returns the node in \p hClonedGraph corresponding to \p hOriginalNode + * in the original graph. + * + * \p hClonedGraph must have been cloned from \p hOriginalGraph via ::cuGraphClone. 
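+ *
+ * As a non-normative sketch of the expected call sequence (error checking
+ * omitted; \p originalGraph and \p originalNode are assumptions):
+ *
+ * \code
+ * CUgraph clonedGraph;
+ * cuGraphClone(&clonedGraph, originalGraph);
+ *
+ * CUgraphNode clonedNode;
+ * cuGraphNodeFindInClone(&clonedNode, originalNode, clonedGraph);
+ * \endcode
+ *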
+ * \p hOriginalNode must have been in \p hOriginalGraph at the time of the call to + * ::cuGraphClone, and the corresponding cloned node in \p hClonedGraph must not have + * been removed. The cloned node is then returned via \p phClonedNode. + * + * \param phNode - Returns handle to the cloned node + * \param hOriginalNode - Handle to the original node + * \param hClonedGraph - Cloned graph to query + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphClone + */ +CUresult CUDAAPI cuGraphNodeFindInClone(CUgraphNode *phNode, CUgraphNode hOriginalNode, CUgraph hClonedGraph); + +/** + * \brief Returns a node's type + * + * Returns the node type of \p hNode in \p type. + * + * \param hNode - Node to query + * \param type - Pointer to return the node type + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphChildGraphNodeGetGraph, + * ::cuGraphKernelNodeGetParams, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphHostNodeGetParams, + * ::cuGraphHostNodeSetParams, + * ::cuGraphMemcpyNodeGetParams, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphMemsetNodeGetParams, + * ::cuGraphMemsetNodeSetParams + */ +CUresult CUDAAPI cuGraphNodeGetType(CUgraphNode hNode, CUgraphNodeType *type); + +/** + * \brief Returns a graph's nodes + * + * Returns a list of \p hGraph's nodes. \p nodes may be NULL, in which case this + * function will return the number of nodes in \p numNodes. Otherwise, + * \p numNodes entries will be filled in. If \p numNodes is higher than the actual + * number of nodes, the remaining entries in \p nodes will be set to NULL, and the + * number of nodes actually obtained will be returned in \p numNodes. + * + * \param hGraph - Graph to query + * \param nodes - Pointer to return the nodes + * \param numNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetType, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetNodes(CUgraph hGraph, CUgraphNode *nodes, size_t *numNodes); + +/** + * \brief Returns a graph's root nodes + * + * Returns a list of \p hGraph's root nodes. \p rootNodes may be NULL, in which case this + * function will return the number of root nodes in \p numRootNodes. Otherwise, + * \p numRootNodes entries will be filled in. If \p numRootNodes is higher than the actual + * number of root nodes, the remaining entries in \p rootNodes will be set to NULL, and the + * number of nodes actually obtained will be returned in \p numRootNodes. 
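+ *
+ * The usual two-call pattern, sketched for illustration (error checking and
+ * allocation-failure handling omitted):
+ *
+ * \code
+ * size_t numRootNodes = 0;
+ * cuGraphGetRootNodes(hGraph, NULL, &numRootNodes);       // query the count
+ * CUgraphNode *rootNodes =
+ *     (CUgraphNode *)malloc(numRootNodes * sizeof(CUgraphNode));
+ * cuGraphGetRootNodes(hGraph, rootNodes, &numRootNodes);  // fetch the nodes
+ * // ... use rootNodes ...
+ * free(rootNodes);
+ * \endcode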
+ * + * \param hGraph - Graph to query + * \param rootNodes - Pointer to return the root nodes + * \param numRootNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphCreate, + * ::cuGraphGetNodes, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetType, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetRootNodes(CUgraph hGraph, CUgraphNode *rootNodes, size_t *numRootNodes); + +/** + * \brief Returns a graph's dependency edges + * + * Returns a list of \p hGraph's dependency edges. Edges are returned via corresponding + * indices in \p from and \p to; that is, the node in \p to[i] has a dependency on the + * node in \p from[i]. \p from and \p to may both be NULL, in which + * case this function only returns the number of edges in \p numEdges. Otherwise, + * \p numEdges entries will be filled in. If \p numEdges is higher than the actual + * number of edges, the remaining entries in \p from and \p to will be set to NULL, and + * the number of edges actually returned will be written to \p numEdges. + * + * \param hGraph - Graph to get the edges from + * \param from - Location to return edge endpoints + * \param to - Location to return edge endpoints + * \param numEdges - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphGetEdges(CUgraph hGraph, CUgraphNode *from, CUgraphNode *to, size_t *numEdges); + +/** + * \brief Returns a node's dependencies + * + * Returns a list of \p node's dependencies. \p dependencies may be NULL, in which case this + * function will return the number of dependencies in \p numDependencies. Otherwise, + * \p numDependencies entries will be filled in. If \p numDependencies is higher than the actual + * number of dependencies, the remaining entries in \p dependencies will be set to NULL, and the + * number of nodes actually obtained will be returned in \p numDependencies. + * + * \param hNode - Node to query + * \param dependencies - Pointer to return the dependencies + * \param numDependencies - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetDependentNodes, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies + */ +CUresult CUDAAPI cuGraphNodeGetDependencies(CUgraphNode hNode, CUgraphNode *dependencies, size_t *numDependencies); + +/** + * \brief Returns a node's dependent nodes + * + * Returns a list of \p node's dependent nodes. \p dependentNodes may be NULL, in which + * case this function will return the number of dependent nodes in \p numDependentNodes. + * Otherwise, \p numDependentNodes entries will be filled in. 
If \p numDependentNodes is + * higher than the actual number of dependent nodes, the remaining entries in + * \p dependentNodes will be set to NULL, and the number of nodes actually obtained will + * be returned in \p numDependentNodes. + * + * \param hNode - Node to query + * \param dependentNodes - Pointer to return the dependent nodes + * \param numDependentNodes - See description + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetDependencies, + * ::cuGraphGetNodes, + * ::cuGraphGetRootNodes, + * ::cuGraphGetEdges, + * ::cuGraphAddDependencies, + * ::cuGraphRemoveDependencies + */ +CUresult CUDAAPI cuGraphNodeGetDependentNodes(CUgraphNode hNode, CUgraphNode *dependentNodes, size_t *numDependentNodes); + +/** + * \brief Adds dependency edges to a graph + * + * The number of dependencies to be added is defined by \p numDependencies + * Elements in \p from and \p to at corresponding indices define a dependency. + * Each node in \p from and \p to must belong to \p hGraph. + * + * If \p numDependencies is 0, elements in \p from and \p to will be ignored. + * Specifying an existing dependency will return an error. + * + * \param hGraph - Graph to which dependencies are added + * \param from - Array of nodes that provide the dependencies + * \param to - Array of dependent nodes + * \param numDependencies - Number of dependencies to be added + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphRemoveDependencies, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphAddDependencies(CUgraph hGraph, const CUgraphNode *from, const CUgraphNode *to, size_t numDependencies); + +/** + * \brief Removes dependency edges from a graph + * + * The number of \p dependencies to be removed is defined by \p numDependencies. + * Elements in \p from and \p to at corresponding indices define a dependency. + * Each node in \p from and \p to must belong to \p hGraph. + * + * If \p numDependencies is 0, elements in \p from and \p to will be ignored. + * Specifying a non-existing dependency will return an error. + * + * Dependencies cannot be removed from graphs which contain allocation or free nodes. + * Any attempt to do so will return an error. + * + * \param hGraph - Graph from which to remove dependencies + * \param from - Array of nodes that provide the dependencies + * \param to - Array of dependent nodes + * \param numDependencies - Number of dependencies to be removed + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddDependencies, + * ::cuGraphGetEdges, + * ::cuGraphNodeGetDependencies, + * ::cuGraphNodeGetDependentNodes + */ +CUresult CUDAAPI cuGraphRemoveDependencies(CUgraph hGraph, const CUgraphNode *from, const CUgraphNode *to, size_t numDependencies); + +/** + * \brief Remove a node from the graph + * + * Removes \p hNode from its graph. This operation also severs any dependencies of other nodes + * on \p hNode and vice versa. + * + * Nodes which belong to a graph which contains allocation or free nodes cannot be destroyed. + * Any attempt to do so will return an error. 
+ * + * \param hNode - Node to remove + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddChildGraphNode, + * ::cuGraphAddEmptyNode, + * ::cuGraphAddKernelNode, + * ::cuGraphAddHostNode, + * ::cuGraphAddMemcpyNode, + * ::cuGraphAddMemsetNode + */ +CUresult CUDAAPI cuGraphDestroyNode(CUgraphNode hNode); + +/** + * \brief Creates an executable graph from a graph + * + * Instantiates \p hGraph as an executable graph. The graph is validated for any + * structural constraints or intra-node constraints which were not previously + * validated. If instantiation is successful, a handle to the instantiated graph + * is returned in \p phGraphExec. + * + * If there are any errors, diagnostic information may be returned in \p errorNode and + * \p logBuffer. This is the primary way to inspect instantiation errors. The output + * will be null terminated unless the diagnostics overflow + * the buffer. In this case, they will be truncated, and the last byte can be + * inspected to determine if truncation occurred. + * + * \param phGraphExec - Returns instantiated graph + * \param hGraph - Graph to instantiate + * \param phErrorNode - In case of an instantiation error, this may be modified to + * indicate a node contributing to the error + * \param logBuffer - A character buffer to store diagnostic messages + * \param bufferSize - Size of the log buffer in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiateWithFlags, + * ::cuGraphCreate, + * ::cuGraphUpload, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); + +/** + * \brief Creates an executable graph from a graph + * + * Instantiates \p hGraph as an executable graph. The graph is validated for any + * structural constraints or intra-node constraints which were not previously + * validated. If instantiation is successful, a handle to the instantiated graph + * is returned in \p phGraphExec. + * + * The \p flags parameter controls the behavior of instantiation and subsequent + * graph launches. Valid flags are: + * + * - ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, which configures a + * graph containing memory allocation nodes to automatically free any + * unfreed memory allocations before the graph is relaunched. + * + * If \p hGraph contains any allocation or free nodes, there can be at most one + * executable graph in existence for that graph at a time. + * + * An attempt to instantiate a second executable graph before destroying the first + * with ::cuGraphExecDestroy will result in an error. + * + * \param phGraphExec - Returns instantiated graph + * \param hGraph - Graph to instantiate + * \param flags - Flags to control instantiation. See ::CUgraphInstantiate_flags. 
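+ *
+ * A minimal end-to-end sketch (assumes a fully constructed graph \p hGraph and
+ * an existing stream \p hStream; error handling omitted):
+ *
+ * \code
+ * CUgraphExec graphExec;
+ * cuGraphInstantiateWithFlags(&graphExec, hGraph,
+ *                             CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH);
+ * cuGraphLaunch(graphExec, hStream);
+ * cuStreamSynchronize(hStream);
+ * cuGraphExecDestroy(graphExec);
+ * \endcode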
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphCreate, + * ::cuGraphUpload, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphInstantiateWithFlags(CUgraphExec *phGraphExec, CUgraph hGraph, unsigned long long flags); + +/** + * \brief Sets the parameters for a kernel node in the given graphExec + * + * Sets the parameters of a kernel node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. All \p nodeParams + * fields may change, but the following restrictions apply to \p func updates: + * + * - The owning context of the function cannot change. + * - A node whose function originally did not use CUDA dynamic parallelism cannot be updated + * to a function which uses CDP + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - kernel node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddKernelNode, + * ::cuGraphKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecKernelNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_KERNEL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for a memcpy node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode had + * contained \p copyParams at instantiation. hNode must remain in the graph which was + * used to instantiate \p hGraphExec. Changed edges to and from hNode are ignored. + * + * The source and destination memory in \p copyParams must be allocated from the same + * contexts as the original source and destination memory. Both the instantiation-time + * memory operands and the memory operands in \p copyParams must be 1-dimensional. + * Zero-length operations are not supported. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. hNode is also + * not modified by this call. + * + * Returns CUDA_ERROR_INVALID_VALUE if the memory operands' mappings changed or + * either the original or new memory operands are multidimensional. 
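+ *
+ * A sketch of redirecting a 1-D device-to-device copy in an instantiated graph
+ * (\p memcpyNode, \p newSrc, \p newDst, \p nbytes and \p ctx are assumptions
+ * made for this example):
+ *
+ * \code
+ * CUDA_MEMCPY3D copyParams = { 0 };
+ * copyParams.srcMemoryType = CU_MEMORYTYPE_DEVICE;
+ * copyParams.srcDevice     = newSrc;
+ * copyParams.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+ * copyParams.dstDevice     = newDst;
+ * copyParams.WidthInBytes  = nbytes;
+ * copyParams.Height        = 1;
+ * copyParams.Depth         = 1;
+ * cuGraphExecMemcpyNodeSetParams(hGraphExec, memcpyNode, &copyParams, ctx);
+ * \endcode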
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Memcpy node from the graph which was used to instantiate graphExec + * \param copyParams - The updated parameters to set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemcpyNode, + * ::cuGraphMemcpyNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecMemcpyNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_MEMCPY3D *copyParams, CUcontext ctx); + +/** + * \brief Sets the parameters for a memset node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode had + * contained \p memsetParams at instantiation. hNode must remain in the graph which was + * used to instantiate \p hGraphExec. Changed edges to and from hNode are ignored. + * + * The destination memory in \p memsetParams must be allocated from the same + * contexts as the original destination memory. Both the instantiation-time + * memory operand and the memory operand in \p memsetParams must be 1-dimensional. + * Zero-length operations are not supported. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. hNode is also + * not modified by this call. + * + * Returns CUDA_ERROR_INVALID_VALUE if the memory operand's mappings changed or + * either the original or new memory operand are multidimensional. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Memset node from the graph which was used to instantiate graphExec + * \param memsetParams - The updated parameters to set + * \param ctx - Context on which to run the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddMemsetNode, + * ::cuGraphMemsetNodeSetParams, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecMemsetNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_MEMSET_NODE_PARAMS *memsetParams, CUcontext ctx); + +/** + * \brief Sets the parameters for a host node in the given graphExec. + * + * Updates the work represented by \p hNode in \p hGraphExec as though \p hNode had + * contained \p nodeParams at instantiation. hNode must remain in the graph which was + * used to instantiate \p hGraphExec. Changed edges to and from hNode are ignored. + * + * The modifications only affect future launches of \p hGraphExec. Already enqueued + * or running launches of \p hGraphExec are not affected by this call. hNode is also + * not modified by this call. 
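+ *
+ * For illustration (the callback \p myHostFn, the \p hostNode handle and the
+ * user data pointer are assumptions made for this example):
+ *
+ * \code
+ * // Assumed callback: void CUDA_CB myHostFn(void *userData);
+ * CUDA_HOST_NODE_PARAMS params;
+ * params.fn       = myHostFn;
+ * params.userData = &appState;               // assumed application state
+ * cuGraphExecHostNodeSetParams(hGraphExec, hostNode, &params);
+ * \endcode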
+ *
+ * \param hGraphExec - The executable graph in which to set the specified node
+ * \param hNode - Host node from the graph which was used to instantiate graphExec
+ * \param nodeParams - The updated parameters to set
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphAddHostNode,
+ * ::cuGraphHostNodeSetParams,
+ * ::cuGraphExecKernelNodeSetParams,
+ * ::cuGraphExecMemcpyNodeSetParams,
+ * ::cuGraphExecMemsetNodeSetParams,
+ * ::cuGraphExecChildGraphNodeSetParams,
+ * ::cuGraphExecEventRecordNodeSetEvent,
+ * ::cuGraphExecEventWaitNodeSetEvent,
+ * ::cuGraphExecExternalSemaphoresSignalNodeSetParams,
+ * ::cuGraphExecExternalSemaphoresWaitNodeSetParams,
+ * ::cuGraphExecUpdate,
+ * ::cuGraphInstantiate
+ */
+CUresult CUDAAPI cuGraphExecHostNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_HOST_NODE_PARAMS *nodeParams);
+
+/**
+ * \brief Updates node parameters in the child graph node in the given graphExec.
+ *
+ * Updates the work represented by \p hNode in \p hGraphExec as though the nodes contained
+ * in \p hNode's graph had the parameters contained in \p childGraph's nodes at instantiation.
+ * \p hNode must remain in the graph which was used to instantiate \p hGraphExec.
+ * Changed edges to and from \p hNode are ignored.
+ *
+ * The modifications only affect future launches of \p hGraphExec. Already enqueued
+ * or running launches of \p hGraphExec are not affected by this call. \p hNode is also
+ * not modified by this call.
+ *
+ * The topology of \p childGraph, as well as the node insertion order, must match that
+ * of the graph contained in \p hNode. See ::cuGraphExecUpdate() for a list of restrictions
+ * on what can be updated in an instantiated graph. The update is recursive, so child graph
+ * nodes contained within the top level child graph will also be updated.
+ *
+ * \param hGraphExec - The executable graph in which to set the specified node
+ * \param hNode - Child graph node from the graph which was used to instantiate graphExec
+ * \param childGraph - The graph supplying the updated parameters
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphAddChildGraphNode,
+ * ::cuGraphChildGraphNodeGetGraph,
+ * ::cuGraphExecKernelNodeSetParams,
+ * ::cuGraphExecMemcpyNodeSetParams,
+ * ::cuGraphExecMemsetNodeSetParams,
+ * ::cuGraphExecHostNodeSetParams,
+ * ::cuGraphExecEventRecordNodeSetEvent,
+ * ::cuGraphExecEventWaitNodeSetEvent,
+ * ::cuGraphExecExternalSemaphoresSignalNodeSetParams,
+ * ::cuGraphExecExternalSemaphoresWaitNodeSetParams,
+ * ::cuGraphExecUpdate,
+ * ::cuGraphInstantiate
+ */
+CUresult CUDAAPI cuGraphExecChildGraphNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, CUgraph childGraph);
+
+/**
+ * \brief Sets the event for an event record node in the given graphExec
+ *
+ * Sets the event of an event record node in an executable graph \p hGraphExec.
+ * The node is identified by the corresponding node \p hNode in the
+ * non-executable graph, from which the executable graph was instantiated.
+ *
+ * The modifications only affect future launches of \p hGraphExec. Already
+ * enqueued or running launches of \p hGraphExec are not affected by this call.
+ * \p hNode is also not modified by this call.
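+ *
+ * As a hedged illustration, the event recorded by an instantiated node could be
+ * swapped roughly as follows; \p hGraphExec and \p recordNode are placeholder
+ * handles assumed to have been created elsewhere:
+ *
+ * \code
+ * CUevent newEvent;
+ * cuEventCreate(&newEvent, CU_EVENT_DEFAULT);
+ * // Future launches of hGraphExec record newEvent instead of the original event.
+ * cuGraphExecEventRecordNodeSetEvent(hGraphExec, recordNode, newEvent);
+ * \endcode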
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - event record node from the graph from which graphExec was instantiated + * \param event - Updated event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventRecordNode, + * ::cuGraphEventRecordNodeGetEvent, + * ::cuGraphEventWaitNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecEventRecordNodeSetEvent(CUgraphExec hGraphExec, CUgraphNode hNode, CUevent event); + +/** + * \brief Sets the event for an event wait node in the given graphExec + * + * Sets the event of an event wait node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - event wait node from the graph from which graphExec was instantiated + * \param event - Updated event to use + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddEventWaitNode, + * ::cuGraphEventWaitNodeGetEvent, + * ::cuGraphEventRecordNodeSetEvent, + * ::cuEventRecordWithFlags, + * ::cuStreamWaitEvent, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecEventWaitNodeSetEvent(CUgraphExec hGraphExec, CUgraphNode hNode, CUevent event); + +/** + * \brief Sets the parameters for an external semaphore signal node in the given graphExec + * + * Sets the parameters of an external semaphore signal node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * Changing \p nodeParams->numExtSems is not supported. 
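+ *
+ * A non-normative sketch: assuming \p sems and \p sigParams are arrays of the same
+ * length \p numSems that the node was originally created with, and \p hGraphExec and
+ * \p signalNode are placeholder handles, the parameters could be refreshed as:
+ *
+ * \code
+ * CUDA_EXT_SEM_SIGNAL_NODE_PARAMS signalParams = {0};
+ * signalParams.extSemArray = sems;        // same semaphores (or new ones)
+ * signalParams.paramsArray = sigParams;   // updated signal values
+ * signalParams.numExtSems  = numSems;     // must match the instantiation-time count
+ * cuGraphExecExternalSemaphoresSignalNodeSetParams(hGraphExec, signalNode, &signalParams);
+ * \endcode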
+ * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - semaphore signal node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresSignalNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresWaitNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecExternalSemaphoresSignalNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_EXT_SEM_SIGNAL_NODE_PARAMS *nodeParams); + +/** + * \brief Sets the parameters for an external semaphore wait node in the given graphExec + * + * Sets the parameters of an external semaphore wait node in an executable graph \p hGraphExec. + * The node is identified by the corresponding node \p hNode in the + * non-executable graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. + * + * Changing \p nodeParams->numExtSems is not supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - semaphore wait node from the graph from which graphExec was instantiated + * \param nodeParams - Updated Parameters to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphAddExternalSemaphoresWaitNode, + * ::cuImportExternalSemaphore, + * ::cuSignalExternalSemaphoresAsync, + * ::cuWaitExternalSemaphoresAsync, + * ::cuGraphExecKernelNodeSetParams, + * ::cuGraphExecMemcpyNodeSetParams, + * ::cuGraphExecMemsetNodeSetParams, + * ::cuGraphExecHostNodeSetParams, + * ::cuGraphExecChildGraphNodeSetParams, + * ::cuGraphExecEventRecordNodeSetEvent, + * ::cuGraphExecEventWaitNodeSetEvent, + * ::cuGraphExecExternalSemaphoresSignalNodeSetParams, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + */ +CUresult CUDAAPI cuGraphExecExternalSemaphoresWaitNodeSetParams(CUgraphExec hGraphExec, CUgraphNode hNode, const CUDA_EXT_SEM_WAIT_NODE_PARAMS *nodeParams); + +/** + * \brief Enables or disables the specified node in the given graphExec + * + * Sets \p hNode to be either enabled or disabled. Disabled nodes are functionally equivalent + * to empty nodes until they are reenabled. Existing node parameters are not affected by + * disabling/enabling the node. + * + * The node is identified by the corresponding node \p hNode in the non-executable + * graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * The modifications only affect future launches of \p hGraphExec. Already + * enqueued or running launches of \p hGraphExec are not affected by this call. + * \p hNode is also not modified by this call. 
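+ *
+ * As a brief, non-normative sketch (placeholder handles \p hGraphExec and
+ * \p kernelNode assumed), a node can be toggled and the change verified:
+ *
+ * \code
+ * // Skip the kernel on upcoming launches, then verify and later re-enable it.
+ * cuGraphNodeSetEnabled(hGraphExec, kernelNode, 0);
+ * unsigned int enabled;
+ * cuGraphNodeGetEnabled(hGraphExec, kernelNode, &enabled);   // enabled == 0 here
+ * cuGraphNodeSetEnabled(hGraphExec, kernelNode, 1);          // re-enable
+ * \endcode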
+ * + * \note Currently only kernel nodes are supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Node from the graph from which graphExec was instantiated + * \param isEnabled - Node is enabled if != 0, otherwise the node is disabled + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeGetEnabled, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + * ::cuGraphLaunch + */ +CUresult CUDAAPI cuGraphNodeSetEnabled(CUgraphExec hGraphExec, CUgraphNode hNode, unsigned int isEnabled); + +/** + * \brief Query whether a node in the given graphExec is enabled + * + * Sets isEnabled to 1 if \p hNode is enabled, or 0 if \p hNode is disabled. + * + * The node is identified by the corresponding node \p hNode in the non-executable + * graph, from which the executable graph was instantiated. + * + * \p hNode must not have been removed from the original graph. + * + * \note Currently only kernel nodes are supported. + * + * \param hGraphExec - The executable graph in which to set the specified node + * \param hNode - Node from the graph from which graphExec was instantiated + * \param isEnabled - Location to return the enabled status of the node + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphNodeSetEnabled, + * ::cuGraphExecUpdate, + * ::cuGraphInstantiate + * ::cuGraphLaunch + */ +CUresult CUDAAPI cuGraphNodeGetEnabled(CUgraphExec hGraphExec, CUgraphNode hNode, unsigned int *isEnabled); + +/** + * \brief Uploads an executable graph in a stream + * + * Uploads \p hGraphExec to the device in \p hStream without executing it. Uploads of + * the same \p hGraphExec will be serialized. Each upload is ordered behind both any + * previous work in \p hStream and any previous launches of \p hGraphExec. + * Uses memory cached by \p stream to back the allocations owned by \p hGraphExec. + * + * \param hGraphExec - Executable graph to upload + * \param hStream - Stream in which to upload the graph + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_VALUE + * \note_graph_thread_safety + * \notefnerr + * + * \sa + * ::cuGraphInstantiate, + * ::cuGraphLaunch, + * ::cuGraphExecDestroy + */ +CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraphExec, CUstream hStream); + +/** + * \brief Launches an executable graph in a stream + * + * Executes \p hGraphExec in \p hStream. Only one instance of \p hGraphExec may be executing + * at a time. Each launch is ordered behind both any previous work in \p hStream + * and any previous launches of \p hGraphExec. To execute a graph concurrently, it must be + * instantiated multiple times into multiple executable graphs. + * + * If any allocations created by \p hGraphExec remain unfreed (from a previous launch) and + * \p hGraphExec was not instantiated with ::CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH, + * the launch will fail with ::CUDA_ERROR_INVALID_VALUE. 
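+ *
+ * A minimal usage sketch (illustrative only; \p hGraphExec and \p hStream are
+ * assumed to have been created and instantiated elsewhere):
+ *
+ * \code
+ * cuGraphUpload(hGraphExec, hStream);   // optional: prepay the upload cost
+ * cuGraphLaunch(hGraphExec, hStream);   // execute the graph
+ * cuStreamSynchronize(hStream);         // wait for completion
+ * \endcode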
+ *
+ * \param hGraphExec - Executable graph to launch
+ * \param hStream - Stream in which to launch the graph
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphInstantiate,
+ * ::cuGraphUpload,
+ * ::cuGraphExecDestroy
+ */
+CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraphExec, CUstream hStream);
+
+/**
+ * \brief Destroys an executable graph
+ *
+ * Destroys the executable graph specified by \p hGraphExec, as well
+ * as all of its executable nodes. If the executable graph is
+ * in-flight, it will not be terminated, but rather freed
+ * asynchronously on completion.
+ *
+ * \param hGraphExec - Executable graph to destroy
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphInstantiate,
+ * ::cuGraphUpload,
+ * ::cuGraphLaunch
+ */
+CUresult CUDAAPI cuGraphExecDestroy(CUgraphExec hGraphExec);
+
+/**
+ * \brief Destroys a graph
+ *
+ * Destroys the graph specified by \p hGraph, as well as all of its nodes.
+ *
+ * \param hGraph - Graph to destroy
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphCreate
+ */
+CUresult CUDAAPI cuGraphDestroy(CUgraph hGraph);
+
+/**
+ * \brief Check whether an executable graph can be updated with a graph and perform the update if possible
+ *
+ * Updates the node parameters in the instantiated graph specified by \p hGraphExec with the
+ * node parameters in a topologically identical graph specified by \p hGraph.
+ *
+ * Limitations:
+ *
+ * - Kernel nodes:
+ * - The owning context of the function cannot change.
+ * - A node whose function originally did not use CUDA dynamic parallelism cannot be updated
+ * to a function which uses CDP.
+ * - A cooperative node cannot be updated to a non-cooperative node, and vice-versa.
+ * - Memset and memcpy nodes:
+ * - The CUDA device(s) to which the operand(s) was allocated/mapped cannot change.
+ * - The source/destination memory must be allocated from the same contexts as the original
+ * source/destination memory.
+ * - Only 1D memsets can be changed.
+ * - Additional memcpy node restrictions:
+ * - Changing either the source or destination memory type (i.e. CU_MEMORYTYPE_DEVICE,
+ * CU_MEMORYTYPE_ARRAY, etc.) is not supported.
+ * - External semaphore wait nodes and record nodes:
+ * - Changing the number of semaphores is not supported.
+ *
+ * Note: The API may add further restrictions in future releases. The return code should always be checked.
+ *
+ * cuGraphExecUpdate sets \p updateResult_out to CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED under
+ * the following conditions:
+ *
+ * - The count of nodes directly in \p hGraphExec and \p hGraph differ, in which case \p hErrorNode_out
+ * is NULL.
+ * - A node is deleted in \p hGraph but not its pair from \p hGraphExec, in which case \p hErrorNode_out
+ * is NULL.
+ * - A node is deleted in \p hGraphExec but not its pair from \p hGraph, in which case \p hErrorNode_out is
+ * the pairless node from \p hGraph.
+ * - The dependent nodes of a pair differ, in which case \p hErrorNode_out is the node from \p hGraph.
+ *
+ * cuGraphExecUpdate sets \p updateResult_out to:
+ * - CU_GRAPH_EXEC_UPDATE_ERROR if passed an invalid value.
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED if the graph topology changed
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED if the type of a node changed, in which case
+ * \p hErrorNode_out is set to the node from \p hGraph.
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_UNSUPPORTED_FUNCTION_CHANGE if the function changed in an unsupported
+ * way (see note above), in which case \p hErrorNode_out is set to the node from \p hGraph
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED if any parameters to a node changed in a way
+ * that is not supported, in which case \p hErrorNode_out is set to the node from \p hGraph.
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_ATTRIBUTES_CHANGED if any attributes of a node changed in a way
+ * that is not supported, in which case \p hErrorNode_out is set to the node from \p hGraph.
+ * - CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED if something about a node is unsupported, like
+ * the node's type or configuration, in which case \p hErrorNode_out is set to the node from \p hGraph
+ *
+ * If \p updateResult_out isn't set in one of the situations described above, the update check passes
+ * and cuGraphExecUpdate updates \p hGraphExec to match the contents of \p hGraph. If an error happens
+ * during the update, \p updateResult_out will be set to CU_GRAPH_EXEC_UPDATE_ERROR; otherwise,
+ * \p updateResult_out is set to CU_GRAPH_EXEC_UPDATE_SUCCESS.
+ *
+ * cuGraphExecUpdate returns CUDA_SUCCESS when the update was performed successfully. It returns
+ * CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE if the graph update was not performed because it included
+ * changes which violated constraints specific to instantiated graph update.
+ *
+ * \param hGraphExec The instantiated graph to be updated
+ * \param hGraph The graph containing the updated parameters
+ * \param hErrorNode_out The node which caused the permissibility check to forbid the update, if any
+ * \param updateResult_out Whether the graph update was permitted. If it was forbidden, the reason why
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE,
+ * \note_graph_thread_safety
+ * \notefnerr
+ *
+ * \sa
+ * ::cuGraphInstantiate,
+ */
+CUresult CUDAAPI cuGraphExecUpdate(CUgraphExec hGraphExec, CUgraph hGraph, CUgraphNode *hErrorNode_out, CUgraphExecUpdateResult *updateResult_out);
+
+/**
+ * \brief Copies attributes from source node to destination node.
+ *
+ * Copies attributes from source node \p src to destination node \p dst.
+ * Both nodes must have the same context.
+ *
+ * \param[out] dst Destination node
+ * \param[in] src Source node
+ * For a list of attributes, see ::CUkernelNodeAttrID
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE
+ * \notefnerr
+ *
+ * \sa
+ * ::CUaccessPolicyWindow
+ */
+CUresult CUDAAPI cuGraphKernelNodeCopyAttributes(CUgraphNode dst, CUgraphNode src);
+
+/**
+ * \brief Queries node attribute.
+ *
+ * Queries attribute \p attr from node \p hNode and stores it in corresponding
+ * member of \p value_out.
+ *
+ * \param[in] hNode
+ * \param[in] attr
+ * \param[out] value_out
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_INVALID_HANDLE
+ * \notefnerr
+ *
+ * \sa
+ * ::CUaccessPolicyWindow
+ */
+CUresult CUDAAPI cuGraphKernelNodeGetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr,
+ CUkernelNodeAttrValue *value_out);
+
+/**
+ * \brief Sets node attribute.
+ * + * Sets attribute \p attr on node \p hNode from corresponding attribute of + * \p value. + * + * \param[out] hNode + * \param[in] attr + * \param[out] value + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE + * \notefnerr + * + * \sa + * ::CUaccessPolicyWindow + */ +CUresult CUDAAPI cuGraphKernelNodeSetAttribute(CUgraphNode hNode, CUkernelNodeAttrID attr, + const CUkernelNodeAttrValue *value); + +/** + * \brief Write a DOT file describing graph structure + * + * Using the provided \p hGraph, write to \p path a DOT formatted description of the graph. + * By default this includes the graph topology, node types, node id, kernel names and memcpy direction. + * \p flags can be specified to write more detailed information about each node type such as + * parameter values, kernel attributes, node and function handles. + * + * \param hGraph - The graph to create a DOT file from + * \param path - The path to write the DOT file to + * \param flags - Flags from CUgraphDebugDot_flags for specifying which additional node information to write + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_OPERATING_SYSTEM + */ +CUresult CUDAAPI cuGraphDebugDotPrint(CUgraph hGraph, const char *path, unsigned int flags); + +/** + * \brief Create a user object + * + * Create a user object with the specified destructor callback and initial reference count. The + * initial references are owned by the caller. + * + * Destructor callbacks cannot make CUDA API calls and should avoid blocking behavior, as they + * are executed by a shared internal thread. Another thread may be signaled to perform such + * actions, if it does not block forward progress of tasks scheduled through CUDA. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param object_out - Location to return the user object handle + * \param ptr - The pointer to pass to the destroy function + * \param destroy - Callback to free the user object when it is no longer in use + * \param initialRefcount - The initial refcount to create the object with, typically 1. The + * initial references are owned by the calling thread. + * \param flags - Currently it is required to pass ::CU_USER_OBJECT_NO_DESTRUCTOR_SYNC, + * which is the only defined flag. This indicates that the destroy + * callback cannot be waited on by any CUDA API. Users requiring + * synchronization of the callback should signal its completion + * manually. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectCreate(CUuserObject *object_out, void *ptr, CUhostFn destroy, + unsigned int initialRefcount, unsigned int flags); + +/** + * \brief Retain a reference to a user object + * + * Retains new references to a user object. The new references are owned by the caller. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param object - The object to retain + * \param count - The number of references to retain, typically 1. Must be nonzero + * and not larger than INT_MAX. 
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectRetain(CUuserObject object, unsigned int count); + +/** + * \brief Release a reference to a user object + * + * Releases user object references owned by the caller. The object's destructor is invoked if + * the reference count reaches zero. + * + * It is undefined behavior to release references not owned by the caller, or to use a user + * object handle after all references are released. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param object - The object to release + * \param count - The number of references to release, typically 1. Must be nonzero + * and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuGraphRetainUserObject, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuUserObjectRelease(CUuserObject object, unsigned int count); + +/** + * \brief Retain a reference to a user object from a graph + * + * Creates or moves user object references that will be owned by a CUDA graph. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param graph - The graph to associate the reference with + * \param object - The user object to retain a reference for + * \param count - The number of references to add to the graph, typically 1. Must be + * nonzero and not larger than INT_MAX. + * \param flags - The optional flag ::CU_GRAPH_USER_OBJECT_MOVE transfers references + * from the calling thread, rather than create new references. Pass 0 + * to create new references. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphReleaseUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphRetainUserObject(CUgraph graph, CUuserObject object, unsigned int count, unsigned int flags); + +/** + * \brief Release a user object reference from a graph + * + * Releases user object references owned by a graph. + * + * See CUDA User Objects in the CUDA C++ Programming Guide for more information on user objects. + * + * \param graph - The graph that will release the reference + * \param object - The user object to release a reference for + * \param count - The number of references to release, typically 1. Must be nonzero + * and not larger than INT_MAX. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuUserObjectCreate, + * ::cuUserObjectRetain, + * ::cuUserObjectRelease, + * ::cuGraphRetainUserObject, + * ::cuGraphCreate + */ +CUresult CUDAAPI cuGraphReleaseUserObject(CUgraph graph, CUuserObject object, unsigned int count); + +/** @} */ /* END CUDA_GRAPH */ + +/** + * \defgroup CUDA_OCCUPANCY Occupancy + * + * ___MANBRIEF___ occupancy calculation functions of the low-level CUDA driver + * API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the occupancy calculation functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns occupancy of a function + * + * Returns in \p *numBlocks the number of the maximum active blocks per + * streaming multiprocessor. 
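+ *
+ * For example, an illustrative sketch in which \p func is assumed to be a kernel
+ * handle obtained elsewhere (e.g. via module loading):
+ *
+ * \code
+ * int numBlocks = 0;
+ * // Active blocks per SM for 256-thread blocks using no dynamic shared memory.
+ * cuOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocks, func, 256, 0);
+ * \endcode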
+ *
+ * \param numBlocks - Returned occupancy
+ * \param func - Kernel for which occupancy is calculated
+ * \param blockSize - Block size the kernel is intended to be launched with
+ * \param dynamicSMemSize - Per-block dynamic shared memory usage intended, in bytes
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_UNKNOWN
+ * \notefnerr
+ *
+ * \sa
+ * ::cudaOccupancyMaxActiveBlocksPerMultiprocessor
+ */
+CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessor(int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize);
+
+/**
+ * \brief Returns occupancy of a function
+ *
+ * Returns in \p *numBlocks the number of the maximum active blocks per
+ * streaming multiprocessor.
+ *
+ * The \p Flags parameter controls how special cases are handled. The
+ * valid flags are:
+ *
+ * - ::CU_OCCUPANCY_DEFAULT, which maintains the default behavior as
+ * ::cuOccupancyMaxActiveBlocksPerMultiprocessor;
+ *
+ * - ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE, which suppresses the
+ * default behavior on platforms where global caching affects
+ * occupancy. On such platforms, if caching is enabled, but
+ * per-block SM resource usage would result in zero occupancy, the
+ * occupancy calculator will calculate the occupancy as if caching
+ * is disabled. Setting ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE causes
+ * the occupancy calculator to return 0 in such cases. More information
+ * can be found about this feature in the "Unified L1/Texture Cache"
+ * section of the Maxwell tuning guide.
+ *
+ * \param numBlocks - Returned occupancy
+ * \param func - Kernel for which occupancy is calculated
+ * \param blockSize - Block size the kernel is intended to be launched with
+ * \param dynamicSMemSize - Per-block dynamic shared memory usage intended, in bytes
+ * \param flags - Requested behavior for the occupancy calculator
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_UNKNOWN
+ * \notefnerr
+ *
+ * \sa
+ * ::cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFlags
+ */
+CUresult CUDAAPI cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(int *numBlocks, CUfunction func, int blockSize, size_t dynamicSMemSize, unsigned int flags);
+
+/**
+ * \brief Suggest a launch configuration with reasonable occupancy
+ *
+ * Returns in \p *blockSize a reasonable block size that can achieve
+ * the maximum occupancy (or, the maximum number of active warps with
+ * the fewest blocks per multiprocessor), and in \p *minGridSize the
+ * minimum grid size to achieve the maximum occupancy.
+ *
+ * If \p blockSizeLimit is 0, the configurator will use the maximum
+ * block size permitted by the device / function instead.
+ *
+ * If per-block dynamic shared memory allocation is not needed, the
+ * user should leave both \p blockSizeToDynamicSMemSize and \p
+ * dynamicSMemSize as 0.
+ *
+ * If per-block dynamic shared memory allocation is needed, then if
+ * the dynamic shared memory size is constant regardless of block
+ * size, the size should be passed through \p dynamicSMemSize, and \p
+ * blockSizeToDynamicSMemSize should be NULL.
+ *
+ * Otherwise, if the per-block dynamic shared memory size varies with
+ * different block sizes, the user needs to provide a unary function
+ * through \p blockSizeToDynamicSMemSize that computes the dynamic
+ * shared memory needed by \p func for any given block size. \p
+ * dynamicSMemSize is ignored. An example signature is:
+ *
+ * \code
+ * // Take block size, returns dynamic shared memory needed
+ * size_t blockToSmem(int blockSize);
+ * \endcode
+ *
+ * \param minGridSize - Returned minimum grid size needed to achieve the maximum occupancy
+ * \param blockSize - Returned maximum block size that can achieve the maximum occupancy
+ * \param func - Kernel for which launch configuration is calculated
+ * \param blockSizeToDynamicSMemSize - A function that calculates how much per-block dynamic shared memory \p func uses based on the block size
+ * \param dynamicSMemSize - Dynamic shared memory usage intended, in bytes
+ * \param blockSizeLimit - The maximum block size \p func is designed to handle
+ *
+ * \return
+ * ::CUDA_SUCCESS,
+ * ::CUDA_ERROR_DEINITIALIZED,
+ * ::CUDA_ERROR_NOT_INITIALIZED,
+ * ::CUDA_ERROR_INVALID_CONTEXT,
+ * ::CUDA_ERROR_INVALID_VALUE,
+ * ::CUDA_ERROR_UNKNOWN
+ * \notefnerr
+ *
+ * \sa
+ * ::cudaOccupancyMaxPotentialBlockSize
+ */
+CUresult CUDAAPI cuOccupancyMaxPotentialBlockSize(int *minGridSize, int *blockSize, CUfunction func, CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, int blockSizeLimit);
+
+/**
+ * \brief Suggest a launch configuration with reasonable occupancy
+ *
+ * An extended version of ::cuOccupancyMaxPotentialBlockSize. In
+ * addition to arguments passed to ::cuOccupancyMaxPotentialBlockSize,
+ * ::cuOccupancyMaxPotentialBlockSizeWithFlags also takes a \p Flags
+ * parameter.
+ *
+ * The \p Flags parameter controls how special cases are handled. The
+ * valid flags are:
+ *
+ * - ::CU_OCCUPANCY_DEFAULT, which maintains the default behavior as
+ * ::cuOccupancyMaxPotentialBlockSize;
+ *
+ * - ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE, which suppresses the
+ * default behavior on platforms where global caching affects
+ * occupancy. On such platforms, the launch configuration that
+ * produces maximal occupancy might not support global
+ * caching. Setting ::CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE
+ * guarantees that the produced launch configuration is global
+ * caching compatible at a potential cost of occupancy. More information
+ * can be found about this feature in the "Unified L1/Texture Cache"
+ * section of the Maxwell tuning guide.
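+ *
+ * A hedged sketch of typical usage; \p func is a placeholder kernel handle, and both
+ * the dynamic shared memory arguments and the block size limit are left at their
+ * "unused" values of NULL/0 here:
+ *
+ * \code
+ * int minGridSize = 0, blockSize = 0;
+ * cuOccupancyMaxPotentialBlockSizeWithFlags(&minGridSize, &blockSize, func,
+ * NULL, 0, 0, CU_OCCUPANCY_DISABLE_CACHING_OVERRIDE);
+ * \endcode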
+ * + * \param minGridSize - Returned minimum grid size needed to achieve the maximum occupancy + * \param blockSize - Returned maximum block size that can achieve the maximum occupancy + * \param func - Kernel for which launch configuration is calculated + * \param blockSizeToDynamicSMemSize - A function that calculates how much per-block dynamic shared memory \p func uses based on the block size + * \param dynamicSMemSize - Dynamic shared memory usage intended, in bytes + * \param blockSizeLimit - The maximum block size \p func is designed to handle + * \param flags - Options + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cudaOccupancyMaxPotentialBlockSizeWithFlags + */ +CUresult CUDAAPI cuOccupancyMaxPotentialBlockSizeWithFlags(int *minGridSize, int *blockSize, CUfunction func, CUoccupancyB2DSize blockSizeToDynamicSMemSize, size_t dynamicSMemSize, int blockSizeLimit, unsigned int flags); + +/** + * \brief Returns dynamic shared memory available per block when launching \p numBlocks blocks on SM + * + * Returns in \p *dynamicSmemSize the maximum size of dynamic shared memory to allow \p numBlocks blocks per SM. + * + * \param dynamicSmemSize - Returned maximum dynamic shared memory + * \param func - Kernel function for which occupancy is calculated + * \param numBlocks - Number of blocks to fit on SM + * \param blockSize - Size of the blocks + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + */ +CUresult CUDAAPI cuOccupancyAvailableDynamicSMemPerBlock(size_t *dynamicSmemSize, CUfunction func, int numBlocks, int blockSize); + +/** @} */ /* END CUDA_OCCUPANCY */ + +/** + * \defgroup CUDA_TEXREF_DEPRECATED Texture Reference Management [DEPRECATED] + * + * ___MANBRIEF___ deprecated texture reference management functions of the + * low-level CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the deprecated texture reference management + * functions of the low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Binds an array as a texture reference + * + * \deprecated + * + * Binds the CUDA array \p hArray to the texture reference \p hTexRef. Any + * previous address or CUDA array state associated with the texture reference + * is superseded by this function. \p Flags must be set to + * ::CU_TRSA_OVERRIDE_FORMAT. Any CUDA array previously bound to \p hTexRef is + * unbound. 
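+ *
+ * As a rough, non-normative sketch of the deprecated flow (\p module, \p hArray and
+ * the texture name "tex" are placeholders assumed to exist in the loaded module):
+ *
+ * \code
+ * CUtexref texRef;
+ * cuModuleGetTexRef(&texRef, module, "tex");
+ * cuTexRefSetArray(texRef, hArray, CU_TRSA_OVERRIDE_FORMAT);
+ * \endcode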
+ * + * \param hTexRef - Texture reference to bind + * \param hArray - Array to bind + * \param Flags - Options (must be ::CU_TRSA_OVERRIDE_FORMAT) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetArray(CUtexref hTexRef, CUarray hArray, unsigned int Flags); + +/** + * \brief Binds a mipmapped array to a texture reference + * + * \deprecated + * + * Binds the CUDA mipmapped array \p hMipmappedArray to the texture reference \p hTexRef. + * Any previous address or CUDA array state associated with the texture reference + * is superseded by this function. \p Flags must be set to ::CU_TRSA_OVERRIDE_FORMAT. + * Any CUDA array previously bound to \p hTexRef is unbound. + * + * \param hTexRef - Texture reference to bind + * \param hMipmappedArray - Mipmapped array to bind + * \param Flags - Options (must be ::CU_TRSA_OVERRIDE_FORMAT) + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmappedArray(CUtexref hTexRef, CUmipmappedArray hMipmappedArray, unsigned int Flags); + +/** + * \brief Binds an address as a texture reference + * + * \deprecated + * + * Binds a linear address range to the texture reference \p hTexRef. Any + * previous address or CUDA array state associated with the texture reference + * is superseded by this function. Any memory previously bound to \p hTexRef + * is unbound. + * + * Since the hardware enforces an alignment requirement on texture base + * addresses, ::cuTexRefSetAddress() passes back a byte offset in + * \p *ByteOffset that must be applied to texture fetches in order to read from + * the desired memory. This offset must be divided by the texel size and + * passed to kernels that read from the texture so they can be applied to the + * ::tex1Dfetch() function. + * + * If the device memory pointer was returned from ::cuMemAlloc(), the offset + * is guaranteed to be 0 and NULL may be passed as the \p ByteOffset parameter. + * + * The total number of elements (or texels) in the linear address range + * cannot exceed ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH. + * The number of elements is computed as (\p bytes / bytesPerElement), + * where bytesPerElement is determined from the data format and number of + * components set using ::cuTexRefSetFormat(). 
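+ *
+ * A minimal illustrative sketch (\p texRef and \p N are placeholders; error checking
+ * is omitted):
+ *
+ * \code
+ * CUdeviceptr dptr;
+ * size_t byteOffset = 0;
+ * cuMemAlloc(&dptr, N * sizeof(float));
+ * // For cuMemAlloc'd memory the offset is guaranteed to be 0, so NULL would also be legal.
+ * cuTexRefSetAddress(&byteOffset, texRef, dptr, N * sizeof(float));
+ * \endcode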
+ * + * \param ByteOffset - Returned byte offset + * \param hTexRef - Texture reference to bind + * \param dptr - Device pointer to bind + * \param bytes - Size of memory to bind in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTexture + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress(size_t *ByteOffset, CUtexref hTexRef, CUdeviceptr dptr, size_t bytes); + +/** + * \brief Binds an address as a 2D texture reference + * + * \deprecated + * + * Binds a linear address range to the texture reference \p hTexRef. Any + * previous address or CUDA array state associated with the texture reference + * is superseded by this function. Any memory previously bound to \p hTexRef + * is unbound. + * + * Using a ::tex2D() function inside a kernel requires a call to either + * ::cuTexRefSetArray() to bind the corresponding texture reference to an + * array, or ::cuTexRefSetAddress2D() to bind the texture reference to linear + * memory. + * + * Function calls to ::cuTexRefSetFormat() cannot follow calls to + * ::cuTexRefSetAddress2D() for the same texture reference. + * + * It is required that \p dptr be aligned to the appropriate hardware-specific + * texture alignment. You can query this value using the device attribute + * ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT. If an unaligned \p dptr is + * supplied, ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \p Pitch has to be aligned to the hardware-specific texture pitch alignment. + * This value can be queried using the device attribute + * ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT. If an unaligned \p Pitch is + * supplied, ::CUDA_ERROR_INVALID_VALUE is returned. + * + * Width and Height, which are specified in elements (or texels), cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH and + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT respectively. + * \p Pitch, which is specified in bytes, cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH. + * + * \param hTexRef - Texture reference to bind + * \param desc - Descriptor of CUDA array + * \param dptr - Device pointer to bind + * \param Pitch - Line pitch in bytes + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTexture2D + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); + +/** + * \brief Sets the format for a texture reference + * + * \deprecated + * + * Specifies the format of the data to be read by the texture reference + * \p hTexRef. 
\p fmt and \p NumPackedComponents are exactly analogous to the + * ::Format and ::NumChannels members of the ::CUDA_ARRAY_DESCRIPTOR structure: + * They specify the format of each component and the number of components per + * array element. + * + * \param hTexRef - Texture reference + * \param fmt - Format to set + * \param NumPackedComponents - Number of components per array element + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaCreateChannelDesc, + * ::cudaBindTexture, + * ::cudaBindTexture2D, + * ::cudaBindTextureToArray, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFormat(CUtexref hTexRef, CUarray_format fmt, int NumPackedComponents); + +/** + * \brief Sets the addressing mode for a texture reference + * + * \deprecated + * + * Specifies the addressing mode \p am for the given dimension \p dim of the + * texture reference \p hTexRef. If \p dim is zero, the addressing mode is + * applied to the first parameter of the functions used to fetch from the + * texture; if \p dim is 1, the second, and so on. ::CUaddress_mode is defined + * as: + * \code + typedef enum CUaddress_mode_enum { + CU_TR_ADDRESS_MODE_WRAP = 0, + CU_TR_ADDRESS_MODE_CLAMP = 1, + CU_TR_ADDRESS_MODE_MIRROR = 2, + CU_TR_ADDRESS_MODE_BORDER = 3 + } CUaddress_mode; + * \endcode + * + * Note that this call has no effect if \p hTexRef is bound to linear memory. + * Also, if the flag, ::CU_TRSF_NORMALIZED_COORDINATES, is not set, the only + * supported address mode is ::CU_TR_ADDRESS_MODE_CLAMP. + * + * \param hTexRef - Texture reference + * \param dim - Dimension + * \param am - Addressing mode to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTexture, + * ::cudaBindTexture2D, + * ::cudaBindTextureToArray, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetAddressMode(CUtexref hTexRef, int dim, CUaddress_mode am); + +/** + * \brief Sets the filtering mode for a texture reference + * + * \deprecated + * + * Specifies the filtering mode \p fm to be used when reading memory through + * the texture reference \p hTexRef. ::CUfilter_mode_enum is defined as: + * + * \code + typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, + CU_TR_FILTER_MODE_LINEAR = 1 + } CUfilter_mode; + * \endcode + * + * Note that this call has no effect if \p hTexRef is bound to linear memory. 
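+ *
+ * For illustration only (a sketch; \p texRef is a placeholder reference bound to an
+ * array or mipmapped array elsewhere):
+ *
+ * \code
+ * cuTexRefSetFilterMode(texRef, CU_TR_FILTER_MODE_LINEAR);
+ * cuTexRefSetAddressMode(texRef, 0, CU_TR_ADDRESS_MODE_CLAMP);
+ * \endcode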
+ * + * \param hTexRef - Texture reference + * \param fm - Filtering mode to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFilterMode(CUtexref hTexRef, CUfilter_mode fm); + +/** + * \brief Sets the mipmap filtering mode for a texture reference + * + * \deprecated + * + * Specifies the mipmap filtering mode \p fm to be used when reading memory through + * the texture reference \p hTexRef. ::CUfilter_mode_enum is defined as: + * + * \code + typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, + CU_TR_FILTER_MODE_LINEAR = 1 + } CUfilter_mode; + * \endcode + * + * Note that this call has no effect if \p hTexRef is not bound to a mipmapped array. + * + * \param hTexRef - Texture reference + * \param fm - Filtering mode to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapFilterMode(CUtexref hTexRef, CUfilter_mode fm); + +/** + * \brief Sets the mipmap level bias for a texture reference + * + * \deprecated + * + * Specifies the mipmap level bias \p bias to be added to the specified mipmap level when + * reading memory through the texture reference \p hTexRef. + * + * Note that this call has no effect if \p hTexRef is not bound to a mipmapped array. + * + * \param hTexRef - Texture reference + * \param bias - Mipmap level bias + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelBias(CUtexref hTexRef, float bias); + +/** + * \brief Sets the mipmap min/max mipmap level clamps for a texture reference + * + * \deprecated + * + * Specifies the min/max mipmap level clamps, \p minMipmapLevelClamp and \p maxMipmapLevelClamp + * respectively, to be used when reading memory through the texture reference + * \p hTexRef. + * + * Note that this call has no effect if \p hTexRef is not bound to a mipmapped array. 
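+ *
+ * A brief non-normative sketch (\p texRef is assumed to be bound to a mipmapped array):
+ *
+ * \code
+ * cuTexRefSetMipmapFilterMode(texRef, CU_TR_FILTER_MODE_LINEAR);
+ * cuTexRefSetMipmapLevelClamp(texRef, 0.0f, 8.0f);   // restrict sampling to levels 0..8
+ * \endcode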
+ * + * \param hTexRef - Texture reference + * \param minMipmapLevelClamp - Mipmap min level clamp + * \param maxMipmapLevelClamp - Mipmap max level clamp + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMipmapLevelClamp(CUtexref hTexRef, float minMipmapLevelClamp, float maxMipmapLevelClamp); + +/** + * \brief Sets the maximum anisotropy for a texture reference + * + * \deprecated + * + * Specifies the maximum anisotropy \p maxAniso to be used when reading memory through + * the texture reference \p hTexRef. + * + * Note that this call has no effect if \p hTexRef is bound to linear memory. + * + * \param hTexRef - Texture reference + * \param maxAniso - Maximum anisotropy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTextureToArray, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetMaxAnisotropy(CUtexref hTexRef, unsigned int maxAniso); + +/** + * \brief Sets the border color for a texture reference + * + * \deprecated + * + * Specifies the value of the RGBA color via the \p pBorderColor to the texture reference + * \p hTexRef. The color value supports only float type and holds color components in + * the following sequence: + * pBorderColor[0] holds 'R' component + * pBorderColor[1] holds 'G' component + * pBorderColor[2] holds 'B' component + * pBorderColor[3] holds 'A' component + * + * Note that the color values can be set only when the Address mode is set to + * CU_TR_ADDRESS_MODE_BORDER using ::cuTexRefSetAddressMode. + * Applications using integer border color values have to "reinterpret_cast" their values to float. + * + * \param hTexRef - Texture reference + * \param pBorderColor - RGBA color + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddressMode, + * ::cuTexRefGetAddressMode, ::cuTexRefGetBorderColor, + * ::cudaBindTexture, + * ::cudaBindTexture2D, + * ::cudaBindTextureToArray, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetBorderColor(CUtexref hTexRef, float *pBorderColor); + +/** + * \brief Sets the flags for a texture reference + * + * \deprecated + * + * Specifies optional flags via \p Flags to specify the behavior of data + * returned through the texture reference \p hTexRef. The valid flags are: + * + * - ::CU_TRSF_READ_AS_INTEGER, which suppresses the default behavior of + * having the texture promote integer data to floating point data in the + * range [0, 1]. 
Note that texture with 32-bit integer format + * would not be promoted, regardless of whether or not this + * flag is specified; + * - ::CU_TRSF_NORMALIZED_COORDINATES, which suppresses the + * default behavior of having the texture coordinates range + * from [0, Dim) where Dim is the width or height of the CUDA + * array. Instead, the texture coordinates [0, 1.0) reference + * the entire breadth of the array dimension; + * - ::CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, which disables any trilinear + * filtering optimizations. Trilinear optimizations improve texture filtering + * performance by allowing bilinear filtering on textures in scenarios where + * it can closely approximate the expected results. + * + * \param hTexRef - Texture reference + * \param Flags - Optional flags to set + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat, + * ::cudaBindTexture, + * ::cudaBindTexture2D, + * ::cudaBindTextureToArray, + * ::cudaBindTextureToMipmappedArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefSetFlags(CUtexref hTexRef, unsigned int Flags); + +/** + * \brief Gets the address associated with a texture reference + * + * \deprecated + * + * Returns in \p *pdptr the base address bound to the texture reference + * \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference + * is not bound to any device memory range. + * + * \param pdptr - Returned device address + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr *pdptr, CUtexref hTexRef); + +/** + * \brief Gets the array bound to a texture reference + * + * \deprecated + * + * Returns in \p *phArray the CUDA array bound to the texture reference + * \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference + * is not bound to any CUDA array. 
+ * + * \param phArray - Returned array + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetArray(CUarray *phArray, CUtexref hTexRef); + +/** + * \brief Gets the mipmapped array bound to a texture reference + * + * \deprecated + * + * Returns in \p *phMipmappedArray the CUDA mipmapped array bound to the texture + * reference \p hTexRef, or returns ::CUDA_ERROR_INVALID_VALUE if the texture reference + * is not bound to any CUDA mipmapped array. + * + * \param phMipmappedArray - Returned mipmapped array + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmappedArray(CUmipmappedArray *phMipmappedArray, CUtexref hTexRef); + +/** + * \brief Gets the addressing mode used by a texture reference + * + * \deprecated + * + * Returns in \p *pam the addressing mode corresponding to the + * dimension \p dim of the texture reference \p hTexRef. Currently, the only + * valid value for \p dim are 0 and 1. + * + * \param pam - Returned addressing mode + * \param hTexRef - Texture reference + * \param dim - Dimension + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetAddressMode(CUaddress_mode *pam, CUtexref hTexRef, int dim); + +/** + * \brief Gets the filter-mode used by a texture reference + * + * \deprecated + * + * Returns in \p *pfm the filtering mode of the texture reference + * \p hTexRef. 
+ * + * \param pfm - Returned filtering mode + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); + +/** + * \brief Gets the format used by a texture reference + * + * \deprecated + * + * Returns in \p *pFormat and \p *pNumChannels the format and number + * of components of the CUDA array bound to the texture reference \p hTexRef. + * If \p pFormat or \p pNumChannels is NULL, it will be ignored. + * + * \param pFormat - Returned format + * \param pNumChannels - Returned number of components + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFormat(CUarray_format *pFormat, int *pNumChannels, CUtexref hTexRef); + +/** + * \brief Gets the mipmap filtering mode for a texture reference + * + * \deprecated + * + * Returns the mipmap filtering mode in \p pfm that's used when reading memory through + * the texture reference \p hTexRef. + * + * \param pfm - Returned mipmap filtering mode + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapFilterMode(CUfilter_mode *pfm, CUtexref hTexRef); + +/** + * \brief Gets the mipmap level bias for a texture reference + * + * \deprecated + * + * Returns the mipmap level bias in \p pBias that's added to the specified mipmap + * level when reading memory through the texture reference \p hTexRef. 
+ * + * \param pbias - Returned mipmap level bias + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapLevelBias(float *pbias, CUtexref hTexRef); + +/** + * \brief Gets the min/max mipmap level clamps for a texture reference + * + * \deprecated + * + * Returns the min/max mipmap level clamps in \p pminMipmapLevelClamp and \p pmaxMipmapLevelClamp + * that's used when reading memory through the texture reference \p hTexRef. + * + * \param pminMipmapLevelClamp - Returned mipmap min level clamp + * \param pmaxMipmapLevelClamp - Returned mipmap max level clamp + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMipmapLevelClamp(float *pminMipmapLevelClamp, float *pmaxMipmapLevelClamp, CUtexref hTexRef); + +/** + * \brief Gets the maximum anisotropy for a texture reference + * + * \deprecated + * + * Returns the maximum anisotropy in \p pmaxAniso that's used when reading memory through + * the texture reference \p hTexRef. + * + * \param pmaxAniso - Returned maximum anisotropy + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFlags, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetMaxAnisotropy(int *pmaxAniso, CUtexref hTexRef); + +/** + * \brief Gets the border color used by a texture reference + * + * \deprecated + * + * Returns in \p pBorderColor, values of the RGBA color used by + * the texture reference \p hTexRef. 
+ * The color value is of type float and holds color components in + * the following sequence: + * pBorderColor[0] holds 'R' component + * pBorderColor[1] holds 'G' component + * pBorderColor[2] holds 'B' component + * pBorderColor[3] holds 'A' component + * + * \param hTexRef - Texture reference + * \param pBorderColor - Returned Type and Value of RGBA color + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddressMode, + * ::cuTexRefSetAddressMode, ::cuTexRefSetBorderColor + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetBorderColor(float *pBorderColor, CUtexref hTexRef); + +/** + * \brief Gets the flags used by a texture reference + * + * \deprecated + * + * Returns in \p *pFlags the flags of the texture reference \p hTexRef. + * + * \param pFlags - Returned flags + * \param hTexRef - Texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefSetAddress, + * ::cuTexRefSetAddress2D, ::cuTexRefSetAddressMode, ::cuTexRefSetArray, + * ::cuTexRefSetFilterMode, ::cuTexRefSetFlags, ::cuTexRefSetFormat, + * ::cuTexRefGetAddress, ::cuTexRefGetAddressMode, ::cuTexRefGetArray, + * ::cuTexRefGetFilterMode, ::cuTexRefGetFormat + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefGetFlags(unsigned int *pFlags, CUtexref hTexRef); + +/** + * \brief Creates a texture reference + * + * \deprecated + * + * Creates a texture reference and returns its handle in \p *pTexRef. Once + * created, the application must call ::cuTexRefSetArray() or + * ::cuTexRefSetAddress() to associate the reference with allocated memory. + * Other texture reference functions are used to specify the format and + * interpretation (addressing, filtering, etc.) to be used when the memory is + * read through this texture reference. + * + * \param pTexRef - Returned texture reference + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefDestroy + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefCreate(CUtexref *pTexRef); + +/** + * \brief Destroys a texture reference + * + * \deprecated + * + * Destroys the texture reference specified by \p hTexRef. + * + * \param hTexRef - Texture reference to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuTexRefCreate + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuTexRefDestroy(CUtexref hTexRef); + +/** @} */ /* END CUDA_TEXREF_DEPRECATED */ + + +/** + * \defgroup CUDA_SURFREF_DEPRECATED Surface Reference Management [DEPRECATED] + * + * ___MANBRIEF___ surface reference management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the surface reference management functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Sets the CUDA array for a surface reference. + * + * \deprecated + * + * Sets the CUDA array \p hArray to be read and written by the surface reference + * \p hSurfRef. Any previous CUDA array state associated with the surface + * reference is superseded by this function. \p Flags must be set to 0. 
+ * The ::CUDA_ARRAY3D_SURFACE_LDST flag must have been set for the CUDA array. + * Any CUDA array previously bound to \p hSurfRef is unbound. + + * \param hSurfRef - Surface reference handle + * \param hArray - CUDA array handle + * \param Flags - set to 0 + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuModuleGetSurfRef, + * ::cuSurfRefGetArray, + * ::cudaBindSurfaceToArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefSetArray(CUsurfref hSurfRef, CUarray hArray, unsigned int Flags); + +/** + * \brief Passes back the CUDA array bound to a surface reference. + * + * \deprecated + * + * Returns in \p *phArray the CUDA array bound to the surface reference + * \p hSurfRef, or returns ::CUDA_ERROR_INVALID_VALUE if the surface reference + * is not bound to any CUDA array. + + * \param phArray - Surface reference handle + * \param hSurfRef - Surface reference handle + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa ::cuModuleGetSurfRef, ::cuSurfRefSetArray + */ +__CUDA_DEPRECATED CUresult CUDAAPI cuSurfRefGetArray(CUarray *phArray, CUsurfref hSurfRef); + +/** @} */ /* END CUDA_SURFREF_DEPRECATED */ + +/** + * \defgroup CUDA_TEXOBJECT Texture Object Management + * + * ___MANBRIEF___ texture object management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the texture object management functions of the + * low-level CUDA driver application programming interface. The texture + * object API is only supported on devices of compute capability 3.0 or higher. + * + * @{ + */ + +/** + * \brief Creates a texture object + * + * Creates a texture object and returns it in \p pTexObject. \p pResDesc describes + * the data to texture from. \p pTexDesc describes how the data should be sampled. + * \p pResViewDesc is an optional argument that specifies an alternate format for + * the data described by \p pResDesc, and also describes the subresource region + * to restrict access to when texturing. \p pResViewDesc can only be specified if + * the type of resource is a CUDA array or a CUDA mipmapped array. + * + * Texture objects are only supported on devices of compute capability 3.0 or higher. + * Additionally, a texture object is an opaque value, and, as such, should only be + * accessed through CUDA API calls. + * + * The ::CUDA_RESOURCE_DESC structure is defined as: + * \code + typedef struct CUDA_RESOURCE_DESC_st + { + CUresourcetype resType; + + union { + struct { + CUarray hArray; + } array; + struct { + CUmipmappedArray hMipmappedArray; + } mipmap; + struct { + CUdeviceptr devPtr; + CUarray_format format; + unsigned int numChannels; + size_t sizeInBytes; + } linear; + struct { + CUdeviceptr devPtr; + CUarray_format format; + unsigned int numChannels; + size_t width; + size_t height; + size_t pitchInBytes; + } pitch2D; + } res; + + unsigned int flags; + } CUDA_RESOURCE_DESC; + + * \endcode + * where: + * - ::CUDA_RESOURCE_DESC::resType specifies the type of resource to texture from. 
+ * CUresourceType is defined as: + * \code + typedef enum CUresourcetype_enum { + CU_RESOURCE_TYPE_ARRAY = 0x00, + CU_RESOURCE_TYPE_MIPMAPPED_ARRAY = 0x01, + CU_RESOURCE_TYPE_LINEAR = 0x02, + CU_RESOURCE_TYPE_PITCH2D = 0x03 + } CUresourcetype; + * \endcode + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_ARRAY, ::CUDA_RESOURCE_DESC::res::array::hArray + * must be set to a valid CUDA array handle. + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_MIPMAPPED_ARRAY, ::CUDA_RESOURCE_DESC::res::mipmap::hMipmappedArray + * must be set to a valid CUDA mipmapped array handle. + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_LINEAR, ::CUDA_RESOURCE_DESC::res::linear::devPtr + * must be set to a valid device pointer, that is aligned to ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT. + * ::CUDA_RESOURCE_DESC::res::linear::format and ::CUDA_RESOURCE_DESC::res::linear::numChannels + * describe the format of each component and the number of components per array element. ::CUDA_RESOURCE_DESC::res::linear::sizeInBytes + * specifies the size of the array in bytes. The total number of elements in the linear address range cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE1D_LINEAR_WIDTH. The number of elements is computed as (sizeInBytes / (sizeof(format) * numChannels)). + * + * \par + * If ::CUDA_RESOURCE_DESC::resType is set to ::CU_RESOURCE_TYPE_PITCH2D, ::CUDA_RESOURCE_DESC::res::pitch2D::devPtr + * must be set to a valid device pointer, that is aligned to ::CU_DEVICE_ATTRIBUTE_TEXTURE_ALIGNMENT. + * ::CUDA_RESOURCE_DESC::res::pitch2D::format and ::CUDA_RESOURCE_DESC::res::pitch2D::numChannels + * describe the format of each component and the number of components per array element. ::CUDA_RESOURCE_DESC::res::pitch2D::width + * and ::CUDA_RESOURCE_DESC::res::pitch2D::height specify the width and height of the array in elements, and cannot exceed + * ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_WIDTH and ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_HEIGHT respectively. + * ::CUDA_RESOURCE_DESC::res::pitch2D::pitchInBytes specifies the pitch between two rows in bytes and has to be aligned to + * ::CU_DEVICE_ATTRIBUTE_TEXTURE_PITCH_ALIGNMENT. Pitch cannot exceed ::CU_DEVICE_ATTRIBUTE_MAXIMUM_TEXTURE2D_LINEAR_PITCH. + * + * - ::flags must be set to zero. + * + * + * The ::CUDA_TEXTURE_DESC struct is defined as + * \code + typedef struct CUDA_TEXTURE_DESC_st { + CUaddress_mode addressMode[3]; + CUfilter_mode filterMode; + unsigned int flags; + unsigned int maxAnisotropy; + CUfilter_mode mipmapFilterMode; + float mipmapLevelBias; + float minMipmapLevelClamp; + float maxMipmapLevelClamp; + } CUDA_TEXTURE_DESC; + * \endcode + * where + * - ::CUDA_TEXTURE_DESC::addressMode specifies the addressing mode for each dimension of the texture data. ::CUaddress_mode is defined as: + * \code + typedef enum CUaddress_mode_enum { + CU_TR_ADDRESS_MODE_WRAP = 0, + CU_TR_ADDRESS_MODE_CLAMP = 1, + CU_TR_ADDRESS_MODE_MIRROR = 2, + CU_TR_ADDRESS_MODE_BORDER = 3 + } CUaddress_mode; + * \endcode + * This is ignored if ::CUDA_RESOURCE_DESC::resType is ::CU_RESOURCE_TYPE_LINEAR. Also, if the flag, ::CU_TRSF_NORMALIZED_COORDINATES + * is not set, the only supported address mode is ::CU_TR_ADDRESS_MODE_CLAMP. + * + * - ::CUDA_TEXTURE_DESC::filterMode specifies the filtering mode to be used when fetching from the texture. 
CUfilter_mode is defined as: + * \code + typedef enum CUfilter_mode_enum { + CU_TR_FILTER_MODE_POINT = 0, + CU_TR_FILTER_MODE_LINEAR = 1 + } CUfilter_mode; + * \endcode + * This is ignored if ::CUDA_RESOURCE_DESC::resType is ::CU_RESOURCE_TYPE_LINEAR. + * + * - ::CUDA_TEXTURE_DESC::flags can be any combination of the following: + * - ::CU_TRSF_READ_AS_INTEGER, which suppresses the default behavior of + * having the texture promote integer data to floating point data in the + * range [0, 1]. Note that texture with 32-bit integer format would not be + * promoted, regardless of whether or not this flag is specified. + * - ::CU_TRSF_NORMALIZED_COORDINATES, which suppresses the default behavior + * of having the texture coordinates range from [0, Dim) where Dim is the + * width or height of the CUDA array. Instead, the texture coordinates + * [0, 1.0) reference the entire breadth of the array dimension; Note that + * for CUDA mipmapped arrays, this flag has to be set. + * - ::CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, which disables any trilinear + * filtering optimizations. Trilinear optimizations improve texture filtering + * performance by allowing bilinear filtering on textures in scenarios where + * it can closely approximate the expected results. + * - ::CU_TRSF_SEAMLESS_CUBEMAP, which enables seamless cube map filtering. + * This flag can only be specified if the underlying resource is a CUDA array + * or a CUDA mipmapped array that was created with the flag ::CUDA_ARRAY3D_CUBEMAP. + * When seamless cube map filtering is enabled, texture address modes specified + * by ::CUDA_TEXTURE_DESC::addressMode are ignored. Instead, if the ::CUDA_TEXTURE_DESC::filterMode + * is set to ::CU_TR_FILTER_MODE_POINT the address mode ::CU_TR_ADDRESS_MODE_CLAMP + * will be applied for all dimensions. If the ::CUDA_TEXTURE_DESC::filterMode is + * set to ::CU_TR_FILTER_MODE_LINEAR seamless cube map filtering will be performed + * when sampling along the cube face borders. + * + * - ::CUDA_TEXTURE_DESC::maxAnisotropy specifies the maximum anisotropy ratio to be used when doing anisotropic filtering. This value will be + * clamped to the range [1,16]. + * + * - ::CUDA_TEXTURE_DESC::mipmapFilterMode specifies the filter mode when the calculated mipmap level lies between two defined mipmap levels. + * + * - ::CUDA_TEXTURE_DESC::mipmapLevelBias specifies the offset to be applied to the calculated mipmap level. + * + * - ::CUDA_TEXTURE_DESC::minMipmapLevelClamp specifies the lower end of the mipmap level range to clamp access to. + * + * - ::CUDA_TEXTURE_DESC::maxMipmapLevelClamp specifies the upper end of the mipmap level range to clamp access to. + * + * + * The ::CUDA_RESOURCE_VIEW_DESC struct is defined as + * \code + typedef struct CUDA_RESOURCE_VIEW_DESC_st + { + CUresourceViewFormat format; + size_t width; + size_t height; + size_t depth; + unsigned int firstMipmapLevel; + unsigned int lastMipmapLevel; + unsigned int firstLayer; + unsigned int lastLayer; + } CUDA_RESOURCE_VIEW_DESC; + * \endcode + * where: + * - ::CUDA_RESOURCE_VIEW_DESC::format specifies how the data contained in the CUDA array or CUDA mipmapped array should + * be interpreted. Note that this can incur a change in size of the texture data. If the resource view format is a block + * compressed format, then the underlying CUDA array or CUDA mipmapped array has to have a base of format ::CU_AD_FORMAT_UNSIGNED_INT32. + * with 2 or 4 channels, depending on the block compressed format. 
For ex., BC1 and BC4 require the underlying CUDA array to have + * a format of ::CU_AD_FORMAT_UNSIGNED_INT32 with 2 channels. The other BC formats require the underlying resource to have the same base + * format but with 4 channels. + * + * - ::CUDA_RESOURCE_VIEW_DESC::width specifies the new width of the texture data. If the resource view format is a block + * compressed format, this value has to be 4 times the original width of the resource. For non block compressed formats, + * this value has to be equal to that of the original resource. + * + * - ::CUDA_RESOURCE_VIEW_DESC::height specifies the new height of the texture data. If the resource view format is a block + * compressed format, this value has to be 4 times the original height of the resource. For non block compressed formats, + * this value has to be equal to that of the original resource. + * + * - ::CUDA_RESOURCE_VIEW_DESC::depth specifies the new depth of the texture data. This value has to be equal to that of the + * original resource. + * + * - ::CUDA_RESOURCE_VIEW_DESC::firstMipmapLevel specifies the most detailed mipmap level. This will be the new mipmap level zero. + * For non-mipmapped resources, this value has to be zero.::CUDA_TEXTURE_DESC::minMipmapLevelClamp and ::CUDA_TEXTURE_DESC::maxMipmapLevelClamp + * will be relative to this value. For ex., if the firstMipmapLevel is set to 2, and a minMipmapLevelClamp of 1.2 is specified, + * then the actual minimum mipmap level clamp will be 3.2. + * + * - ::CUDA_RESOURCE_VIEW_DESC::lastMipmapLevel specifies the least detailed mipmap level. For non-mipmapped resources, this value + * has to be zero. + * + * - ::CUDA_RESOURCE_VIEW_DESC::firstLayer specifies the first layer index for layered textures. This will be the new layer zero. + * For non-layered resources, this value has to be zero. + * + * - ::CUDA_RESOURCE_VIEW_DESC::lastLayer specifies the last layer index for layered textures. For non-layered resources, + * this value has to be zero. + * + * + * \param pTexObject - Texture object to create + * \param pResDesc - Resource descriptor + * \param pTexDesc - Texture descriptor + * \param pResViewDesc - Resource view descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectDestroy, + * ::cudaCreateTextureObject + */ +CUresult CUDAAPI cuTexObjectCreate(CUtexObject *pTexObject, const CUDA_RESOURCE_DESC *pResDesc, const CUDA_TEXTURE_DESC *pTexDesc, const CUDA_RESOURCE_VIEW_DESC *pResViewDesc); + +/** + * \brief Destroys a texture object + * + * Destroys the texture object specified by \p texObject. + * + * \param texObject - Texture object to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaDestroyTextureObject + */ +CUresult CUDAAPI cuTexObjectDestroy(CUtexObject texObject); + +/** + * \brief Returns a texture object's resource descriptor + * + * Returns the resource descriptor for the texture object specified by \p texObject. 
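 + *
 + * As an illustrative sketch (not part of the original header), creating a
 + * texture object over linear device memory as described above and then
 + * reading back its resource descriptor might look as follows; the buffer
 + * size and variable names are assumptions, <string.h> and an initialized
 + * CUDA context are assumed, and error checking is omitted:
 + * \code
 +    CUdeviceptr dptr;
 +    cuMemAlloc(&dptr, 1024 * sizeof(float));           // assumed 1024-element buffer
 +
 +    CUDA_RESOURCE_DESC resDesc;
 +    memset(&resDesc, 0, sizeof(resDesc));
 +    resDesc.resType                = CU_RESOURCE_TYPE_LINEAR;
 +    resDesc.res.linear.devPtr      = dptr;
 +    resDesc.res.linear.format      = CU_AD_FORMAT_FLOAT;
 +    resDesc.res.linear.numChannels = 1;
 +    resDesc.res.linear.sizeInBytes = 1024 * sizeof(float);
 +
 +    CUDA_TEXTURE_DESC texDesc;
 +    memset(&texDesc, 0, sizeof(texDesc));               // addressMode and filterMode are
 +                                                        // ignored for CU_RESOURCE_TYPE_LINEAR
 +    CUtexObject texObj;
 +    cuTexObjectCreate(&texObj, &resDesc, &texDesc, NULL);
 +
 +    CUDA_RESOURCE_DESC queriedDesc;
 +    cuTexObjectGetResourceDesc(&queriedDesc, texObj);   // queriedDesc.resType == CU_RESOURCE_TYPE_LINEAR
 +
 +    cuTexObjectDestroy(texObj);
 +    cuMemFree(dptr);
 + * \endcode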
+ * + * \param pResDesc - Resource descriptor + * \param texObject - Texture object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaGetTextureObjectResourceDesc, + */ +CUresult CUDAAPI cuTexObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, CUtexObject texObject); + +/** + * \brief Returns a texture object's texture descriptor + * + * Returns the texture descriptor for the texture object specified by \p texObject. + * + * \param pTexDesc - Texture descriptor + * \param texObject - Texture object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaGetTextureObjectTextureDesc + */ +CUresult CUDAAPI cuTexObjectGetTextureDesc(CUDA_TEXTURE_DESC *pTexDesc, CUtexObject texObject); + +/** + * \brief Returns a texture object's resource view descriptor + * + * Returns the resource view descriptor for the texture object specified by \p texObject. + * If no resource view was set for \p texObject, the ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \param pResViewDesc - Resource view descriptor + * \param texObject - Texture object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuTexObjectCreate, + * ::cudaGetTextureObjectResourceViewDesc + */ +CUresult CUDAAPI cuTexObjectGetResourceViewDesc(CUDA_RESOURCE_VIEW_DESC *pResViewDesc, CUtexObject texObject); + +/** @} */ /* END CUDA_TEXOBJECT */ + +/** + * \defgroup CUDA_SURFOBJECT Surface Object Management + * + * ___MANBRIEF___ surface object management functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the surface object management functions of the + * low-level CUDA driver application programming interface. The surface + * object API is only supported on devices of compute capability 3.0 or higher. + * + * @{ + */ + +/** + * \brief Creates a surface object + * + * Creates a surface object and returns it in \p pSurfObject. \p pResDesc describes + * the data to perform surface load/stores on. ::CUDA_RESOURCE_DESC::resType must be + * ::CU_RESOURCE_TYPE_ARRAY and ::CUDA_RESOURCE_DESC::res::array::hArray + * must be set to a valid CUDA array handle. ::CUDA_RESOURCE_DESC::flags must be set to zero. + * + * Surface objects are only supported on devices of compute capability 3.0 or higher. + * Additionally, a surface object is an opaque value, and, as such, should only be + * accessed through CUDA API calls. + * + * \param pSurfObject - Surface object to create + * \param pResDesc - Resource descriptor + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuSurfObjectDestroy, + * ::cudaCreateSurfaceObject + */ +CUresult CUDAAPI cuSurfObjectCreate(CUsurfObject *pSurfObject, const CUDA_RESOURCE_DESC *pResDesc); + +/** + * \brief Destroys a surface object + * + * Destroys the surface object specified by \p surfObject. 
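 + *
 + * As an illustrative sketch (not part of the original header), creating and
 + * then destroying a surface object over a freshly allocated CUDA array might
 + * look as follows; the array dimensions are assumptions, <string.h> and an
 + * initialized CUDA context are assumed, and error checking is omitted:
 + * \code
 +    CUDA_ARRAY3D_DESCRIPTOR arrDesc;
 +    memset(&arrDesc, 0, sizeof(arrDesc));
 +    arrDesc.Width       = 64;                         // assumed 64x64 2D array
 +    arrDesc.Height      = 64;
 +    arrDesc.Depth       = 0;
 +    arrDesc.Format      = CU_AD_FORMAT_FLOAT;
 +    arrDesc.NumChannels = 1;
 +    arrDesc.Flags       = CUDA_ARRAY3D_SURFACE_LDST;  // required for surface load/store
 +
 +    CUarray hArray;
 +    cuArray3DCreate(&hArray, &arrDesc);
 +
 +    CUDA_RESOURCE_DESC resDesc;
 +    memset(&resDesc, 0, sizeof(resDesc));
 +    resDesc.resType          = CU_RESOURCE_TYPE_ARRAY;
 +    resDesc.res.array.hArray = hArray;
 +
 +    CUsurfObject surfObj;
 +    cuSurfObjectCreate(&surfObj, &resDesc);
 +    // ... launch kernels performing surface loads/stores through surfObj ...
 +    cuSurfObjectDestroy(surfObj);
 +    cuArrayDestroy(hArray);
 + * \endcode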
+ * + * \param surfObject - Surface object to destroy + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuSurfObjectCreate, + * ::cudaDestroySurfaceObject + */ +CUresult CUDAAPI cuSurfObjectDestroy(CUsurfObject surfObject); + +/** + * \brief Returns a surface object's resource descriptor + * + * Returns the resource descriptor for the surface object specified by \p surfObject. + * + * \param pResDesc - Resource descriptor + * \param surfObject - Surface object + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE + * + * \sa + * ::cuSurfObjectCreate, + * ::cudaGetSurfaceObjectResourceDesc + */ +CUresult CUDAAPI cuSurfObjectGetResourceDesc(CUDA_RESOURCE_DESC *pResDesc, CUsurfObject surfObject); + +/** @} */ /* END CUDA_SURFOBJECT */ + +/** + * \defgroup CUDA_PEER_ACCESS Peer Context Memory Access + * + * ___MANBRIEF___ direct peer context memory access functions of the low-level + * CUDA driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the direct peer context memory access functions + * of the low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Queries if a device may directly access a peer device's memory. + * + * Returns in \p *canAccessPeer a value of 1 if contexts on \p dev are capable of + * directly accessing memory from contexts on \p peerDev and 0 otherwise. + * If direct access of \p peerDev from \p dev is possible, then access may be + * enabled on two specific contexts by calling ::cuCtxEnablePeerAccess(). + * + * \param canAccessPeer - Returned access capability + * \param dev - Device from which allocations on \p peerDev are to + * be directly accessed. + * \param peerDev - Device on which the allocations to be directly accessed + * by \p dev reside. + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE + * \notefnerr + * + * \sa + * ::cuCtxEnablePeerAccess, + * ::cuCtxDisablePeerAccess, + * ::cudaDeviceCanAccessPeer + */ +CUresult CUDAAPI cuDeviceCanAccessPeer(int *canAccessPeer, CUdevice dev, CUdevice peerDev); + +/** + * \brief Enables direct access to memory allocations in a peer context. + * + * If both the current context and \p peerContext are on devices which support unified + * addressing (as may be queried using ::CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING) and same + * major compute capability, then on success all allocations from \p peerContext will + * immediately be accessible by the current context. See \ref CUDA_UNIFIED for additional + * details. + * + * Note that access granted by this call is unidirectional and that in order to access + * memory from the current context in \p peerContext, a separate symmetric call + * to ::cuCtxEnablePeerAccess() is required. + * + * Note that there are both device-wide and system-wide limitations per system + * configuration, as noted in the CUDA Programming Guide under the section + * "Peer-to-Peer Memory Access". + * + * Returns ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED if ::cuDeviceCanAccessPeer() indicates + * that the ::CUdevice of the current context cannot directly access memory + * from the ::CUdevice of \p peerContext. 
 + *
 + * Returns ::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED if direct access of
 + * \p peerContext from the current context has already been enabled.
 + *
 + * Returns ::CUDA_ERROR_TOO_MANY_PEERS if direct peer access is not possible
 + * because hardware resources required for peer access have been exhausted.
 + *
 + * Returns ::CUDA_ERROR_INVALID_CONTEXT if there is no current context, \p peerContext
 + * is not a valid context, or if the current context is \p peerContext.
 + *
 + * Returns ::CUDA_ERROR_INVALID_VALUE if \p Flags is not 0.
 + *
 + * \param peerContext - Peer context to enable direct access to from the current context
 + * \param Flags - Reserved for future use and must be set to 0
 + *
 + * \return
 + * ::CUDA_SUCCESS,
 + * ::CUDA_ERROR_DEINITIALIZED,
 + * ::CUDA_ERROR_NOT_INITIALIZED,
 + * ::CUDA_ERROR_PEER_ACCESS_ALREADY_ENABLED,
 + * ::CUDA_ERROR_TOO_MANY_PEERS,
 + * ::CUDA_ERROR_INVALID_CONTEXT,
 + * ::CUDA_ERROR_PEER_ACCESS_UNSUPPORTED,
 + * ::CUDA_ERROR_INVALID_VALUE
 + * \notefnerr
 + *
 + * \sa
 + * ::cuDeviceCanAccessPeer,
 + * ::cuCtxDisablePeerAccess,
 + * ::cudaDeviceEnablePeerAccess
 + */
 +CUresult CUDAAPI cuCtxEnablePeerAccess(CUcontext peerContext, unsigned int Flags);
 +
 +/**
 + * \brief Disables direct access to memory allocations in a peer context and
 + * unregisters any registered allocations.
 + *
 + * Returns ::CUDA_ERROR_PEER_ACCESS_NOT_ENABLED if direct peer access has
 + * not yet been enabled from \p peerContext to the current context.
 + *
 + * Returns ::CUDA_ERROR_INVALID_CONTEXT if there is no current context, or if
 + * \p peerContext is not a valid context.
 + *
 + * \param peerContext - Peer context to disable direct access to
 + *
 + * \return
 + * ::CUDA_SUCCESS,
 + * ::CUDA_ERROR_DEINITIALIZED,
 + * ::CUDA_ERROR_NOT_INITIALIZED,
 + * ::CUDA_ERROR_PEER_ACCESS_NOT_ENABLED,
 + * ::CUDA_ERROR_INVALID_CONTEXT,
 + * \notefnerr
 + *
 + * \sa
 + * ::cuDeviceCanAccessPeer,
 + * ::cuCtxEnablePeerAccess,
 + * ::cudaDeviceDisablePeerAccess
 + */
 +CUresult CUDAAPI cuCtxDisablePeerAccess(CUcontext peerContext);
 +
 +/**
 + * \brief Queries attributes of the link between two devices.
 + *
 + * Returns in \p *value the value of the requested attribute \p attrib of the
 + * link between \p srcDevice and \p dstDevice. The supported attributes are:
 + * - ::CU_DEVICE_P2P_ATTRIBUTE_PERFORMANCE_RANK: A relative value indicating the
 + * performance of the link between two devices.
 + * - ::CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED: 1 if P2P Access is enabled.
 + * - ::CU_DEVICE_P2P_ATTRIBUTE_NATIVE_ATOMIC_SUPPORTED: 1 if Atomic operations over
 + * the link are supported.
 + * - ::CU_DEVICE_P2P_ATTRIBUTE_CUDA_ARRAY_ACCESS_SUPPORTED: 1 if cudaArray can
 + * be accessed over the link.
 + *
 + * Returns ::CUDA_ERROR_INVALID_DEVICE if \p srcDevice or \p dstDevice are not valid
 + * or if they represent the same device.
 + *
 + * Returns ::CUDA_ERROR_INVALID_VALUE if \p attrib is not valid or if \p value is
 + * a null pointer.
 + *
 + * \param value - Returned value of the requested attribute
 + * \param attrib - The requested attribute of the link between \p srcDevice and \p dstDevice.
 + * \param srcDevice - The source device of the target link.
 + * \param dstDevice - The destination device of the target link.
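 + *
 + * As an illustrative sketch (not part of the original header), checking the
 + * link between two devices and enabling peer access from the context current
 + * on the first device might look as follows; device ordinals 0 and 1 and the
 + * use of primary contexts are assumptions, and error checking is omitted:
 + * \code
 +    CUdevice dev0, dev1;
 +    cuDeviceGet(&dev0, 0);
 +    cuDeviceGet(&dev1, 1);
 +
 +    int accessSupported = 0;
 +    cuDeviceGetP2PAttribute(&accessSupported,
 +                            CU_DEVICE_P2P_ATTRIBUTE_ACCESS_SUPPORTED, dev0, dev1);
 +
 +    int canAccess = 0;
 +    cuDeviceCanAccessPeer(&canAccess, dev0, dev1);
 +    if (canAccess) {
 +        CUcontext ctx0, ctx1;
 +        cuDevicePrimaryCtxRetain(&ctx0, dev0);        // context that will do the accessing
 +        cuDevicePrimaryCtxRetain(&ctx1, dev1);        // peer context owning the allocations
 +        cuCtxSetCurrent(ctx0);
 +        cuCtxEnablePeerAccess(ctx1, 0);               // Flags must be 0
 +        // ... allocations made in ctx1 are now accessible from ctx0 ...
 +        cuCtxDisablePeerAccess(ctx1);
 +        cuDevicePrimaryCtxRelease(dev1);
 +        cuDevicePrimaryCtxRelease(dev0);
 +    }
 + * \endcode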
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_DEVICE, + * ::CUDA_ERROR_INVALID_VALUE + * \notefnerr + * + * \sa + * ::cuCtxEnablePeerAccess, + * ::cuCtxDisablePeerAccess, + * ::cuDeviceCanAccessPeer, + * ::cudaDeviceGetP2PAttribute + */ +CUresult CUDAAPI cuDeviceGetP2PAttribute(int* value, CUdevice_P2PAttribute attrib, CUdevice srcDevice, CUdevice dstDevice); + +/** @} */ /* END CUDA_PEER_ACCESS */ + +/** + * \defgroup CUDA_GRAPHICS Graphics Interoperability + * + * ___MANBRIEF___ graphics interoperability functions of the low-level CUDA + * driver API (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the graphics interoperability functions of the + * low-level CUDA driver application programming interface. + * + * @{ + */ + +/** + * \brief Unregisters a graphics resource for access by CUDA + * + * Unregisters the graphics resource \p resource so it is not accessible by + * CUDA unless registered again. + * + * If \p resource is invalid then ::CUDA_ERROR_INVALID_HANDLE is + * returned. + * + * \param resource - Resource to unregister + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_UNKNOWN + * \notefnerr + * + * \sa + * ::cuGraphicsD3D9RegisterResource, + * ::cuGraphicsD3D10RegisterResource, + * ::cuGraphicsD3D11RegisterResource, + * ::cuGraphicsGLRegisterBuffer, + * ::cuGraphicsGLRegisterImage, + * ::cudaGraphicsUnregisterResource + */ +CUresult CUDAAPI cuGraphicsUnregisterResource(CUgraphicsResource resource); + +/** + * \brief Get an array through which to access a subresource of a mapped graphics resource. + * + * Returns in \p *pArray an array through which the subresource of the mapped + * graphics resource \p resource which corresponds to array index \p arrayIndex + * and mipmap level \p mipLevel may be accessed. The value set in \p *pArray may + * change every time that \p resource is mapped. + * + * If \p resource is not a texture then it cannot be accessed via an array and + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY is returned. + * If \p arrayIndex is not a valid array index for \p resource then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * If \p mipLevel is not a valid mipmap level for \p resource then + * ::CUDA_ERROR_INVALID_VALUE is returned. + * If \p resource is not mapped then ::CUDA_ERROR_NOT_MAPPED is returned. + * + * \param pArray - Returned array through which a subresource of \p resource may be accessed + * \param resource - Mapped resource to access + * \param arrayIndex - Array index for array textures or cubemap face + * index as defined by ::CUarray_cubemap_face for + * cubemap textures for the subresource to access + * \param mipLevel - Mipmap level for the subresource to access + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_MAPPED, + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY + * \notefnerr + * + * \sa + * ::cuGraphicsResourceGetMappedPointer, + * ::cudaGraphicsSubResourceGetMappedArray + */ +CUresult CUDAAPI cuGraphicsSubResourceGetMappedArray(CUarray *pArray, CUgraphicsResource resource, unsigned int arrayIndex, unsigned int mipLevel); + +/** + * \brief Get a mipmapped array through which to access a mapped graphics resource. 
 + *
 + * Returns in \p *pMipmappedArray a mipmapped array through which the mapped graphics
 + * resource \p resource may be accessed. The value set in \p *pMipmappedArray may change
 + * every time that \p resource is mapped.
 + *
 + * If \p resource is not a texture then it cannot be accessed via a mipmapped array and
 + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY is returned.
 + * If \p resource is not mapped then ::CUDA_ERROR_NOT_MAPPED is returned.
 + *
 + * \param pMipmappedArray - Returned mipmapped array through which \p resource may be accessed
 + * \param resource - Mapped resource to access
 + *
 + * \return
 + * ::CUDA_SUCCESS,
 + * ::CUDA_ERROR_DEINITIALIZED,
 + * ::CUDA_ERROR_NOT_INITIALIZED,
 + * ::CUDA_ERROR_INVALID_CONTEXT,
 + * ::CUDA_ERROR_INVALID_VALUE,
 + * ::CUDA_ERROR_INVALID_HANDLE,
 + * ::CUDA_ERROR_NOT_MAPPED,
 + * ::CUDA_ERROR_NOT_MAPPED_AS_ARRAY
 + * \notefnerr
 + *
 + * \sa
 + * ::cuGraphicsResourceGetMappedPointer,
 + * ::cudaGraphicsResourceGetMappedMipmappedArray
 + */
 +CUresult CUDAAPI cuGraphicsResourceGetMappedMipmappedArray(CUmipmappedArray *pMipmappedArray, CUgraphicsResource resource);
 +
 +/**
 + * \brief Get a device pointer through which to access a mapped graphics resource.
 + *
 + * Returns in \p *pDevPtr a pointer through which the mapped graphics resource
 + * \p resource may be accessed.
 + * Returns in \p pSize the size of the memory in bytes which may be accessed from that pointer.
 + * The value set in \p *pDevPtr may change every time that \p resource is mapped.
 + *
 + * If \p resource is not a buffer then it cannot be accessed via a pointer and
 + * ::CUDA_ERROR_NOT_MAPPED_AS_POINTER is returned.
 + * If \p resource is not mapped then ::CUDA_ERROR_NOT_MAPPED is returned.
 + *
 + * \param pDevPtr - Returned pointer through which \p resource may be accessed
 + * \param pSize - Returned size of the buffer accessible starting at \p *pDevPtr
 + * \param resource - Mapped resource to access
 + *
 + * \return
 + * ::CUDA_SUCCESS,
 + * ::CUDA_ERROR_DEINITIALIZED,
 + * ::CUDA_ERROR_NOT_INITIALIZED,
 + * ::CUDA_ERROR_INVALID_CONTEXT,
 + * ::CUDA_ERROR_INVALID_VALUE,
 + * ::CUDA_ERROR_INVALID_HANDLE,
 + * ::CUDA_ERROR_NOT_MAPPED,
 + * ::CUDA_ERROR_NOT_MAPPED_AS_POINTER
 + * \notefnerr
 + *
 + * \sa
 + * ::cuGraphicsMapResources,
 + * ::cuGraphicsSubResourceGetMappedArray,
 + * ::cudaGraphicsResourceGetMappedPointer
 + */
 +CUresult CUDAAPI cuGraphicsResourceGetMappedPointer(CUdeviceptr *pDevPtr, size_t *pSize, CUgraphicsResource resource);
 +
 +/**
 + * \brief Set usage flags for mapping a graphics resource
 + *
 + * Set \p flags for mapping the graphics resource \p resource.
 + *
 + * Changes to \p flags will take effect the next time \p resource is mapped.
 + * The \p flags argument may be any of the following:
 + *
 + * - ::CU_GRAPHICS_MAP_RESOURCE_FLAGS_NONE: Specifies no hints about how this
 + * resource will be used. It is therefore assumed that this resource will be
 + * read from and written to by CUDA kernels. This is the default value.
 + * - ::CU_GRAPHICS_MAP_RESOURCE_FLAGS_READONLY: Specifies that CUDA kernels which
 + * access this resource will not write to this resource.
 + * - ::CU_GRAPHICS_MAP_RESOURCE_FLAGS_WRITEDISCARD: Specifies that CUDA kernels
 + * which access this resource will not read from this resource and will
 + * write over the entire contents of the resource, so none of the data
 + * previously stored in the resource will be preserved.
 + *
 + * If \p resource is presently mapped for access by CUDA then
 + * ::CUDA_ERROR_ALREADY_MAPPED is returned.
+ * If \p flags is not one of the above values then ::CUDA_ERROR_INVALID_VALUE is returned. + * + * \param resource - Registered resource to set flags for + * \param flags - Parameters for resource mapping + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ALREADY_MAPPED + * \notefnerr + * + * \sa + * ::cuGraphicsMapResources, + * ::cudaGraphicsResourceSetMapFlags + */ +CUresult CUDAAPI cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, unsigned int flags); + +/** + * \brief Map graphics resources for access by CUDA + * + * Maps the \p count graphics resources in \p resources for access by CUDA. + * + * The resources in \p resources may be accessed by CUDA until they + * are unmapped. The graphics API from which \p resources were registered + * should not access any resources while they are mapped by CUDA. If an + * application does so, the results are undefined. + * + * This function provides the synchronization guarantee that any graphics calls + * issued before ::cuGraphicsMapResources() will complete before any subsequent CUDA + * work issued in \p stream begins. + * + * If \p resources includes any duplicate entries then ::CUDA_ERROR_INVALID_HANDLE is returned. + * If any of \p resources are presently mapped for access by CUDA then ::CUDA_ERROR_ALREADY_MAPPED is returned. + * + * \param count - Number of resources to map + * \param resources - Resources to map for CUDA usage + * \param hStream - Stream with which to synchronize + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_ALREADY_MAPPED, + * ::CUDA_ERROR_UNKNOWN + * \note_null_stream + * \notefnerr + * + * \sa + * ::cuGraphicsResourceGetMappedPointer, + * ::cuGraphicsSubResourceGetMappedArray, + * ::cuGraphicsUnmapResources, + * ::cudaGraphicsMapResources + */ +CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); + +/** + * \brief Unmap graphics resources. + * + * Unmaps the \p count graphics resources in \p resources. + * + * Once unmapped, the resources in \p resources may not be accessed by CUDA + * until they are mapped again. + * + * This function provides the synchronization guarantee that any CUDA work issued + * in \p stream before ::cuGraphicsUnmapResources() will complete before any + * subsequently issued graphics work begins. + * + * + * If \p resources includes any duplicate entries then ::CUDA_ERROR_INVALID_HANDLE is returned. + * If any of \p resources are not presently mapped for access by CUDA then ::CUDA_ERROR_NOT_MAPPED is returned. 
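 + *
 + * As an illustrative sketch (not part of the original header), a typical
 + * map / access / unmap sequence might look as follows; \c vboResource is an
 + * assumed handle previously registered through an interop API such as
 + * ::cuGraphicsGLRegisterBuffer, and error checking is omitted:
 + * \code
 +    CUgraphicsResource vboResource;   // assumed to be registered already
 +    CUstream stream;
 +    cuStreamCreate(&stream, CU_STREAM_DEFAULT);
 +
 +    cuGraphicsMapResources(1, &vboResource, stream);
 +
 +    CUdeviceptr devPtr;
 +    size_t size;
 +    cuGraphicsResourceGetMappedPointer(&devPtr, &size, vboResource);
 +    cuMemsetD8Async(devPtr, 0, size, stream);         // e.g. clear the buffer from CUDA
 +
 +    cuGraphicsUnmapResources(1, &vboResource, stream);
 +    cuStreamSynchronize(stream);
 +    cuStreamDestroy(stream);
 + * \endcode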
+ * + * \param count - Number of resources to unmap + * \param resources - Resources to unmap + * \param hStream - Stream with which to synchronize + * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_DEINITIALIZED, + * ::CUDA_ERROR_NOT_INITIALIZED, + * ::CUDA_ERROR_INVALID_CONTEXT, + * ::CUDA_ERROR_INVALID_HANDLE, + * ::CUDA_ERROR_NOT_MAPPED, + * ::CUDA_ERROR_UNKNOWN + * \note_null_stream + * \notefnerr + * + * \sa + * ::cuGraphicsMapResources, + * ::cudaGraphicsUnmapResources + */ +CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); + +/** @} */ /* END CUDA_GRAPHICS */ + +/** + * \defgroup CUDA_DRIVER_ENTRY_POINT Driver Entry Point Access + * + * ___MANBRIEF___ driver entry point access functions of the low-level CUDA driver API + * (___CURRENT_FILE___) ___ENDMANBRIEF___ + * + * This section describes the driver entry point access functions of the low-level CUDA + * driver application programming interface. + * + * @{ + */ + +/** + * \brief Returns the requested driver API function pointer + * + * Returns in \p **pfn the address of the CUDA driver function for the requested + * CUDA version and flags. + * + * The CUDA version is specified as (1000 * major + 10 * minor), so CUDA 11.2 + * should be specified as 11020. For a requested driver symbol, if the specified + * CUDA version is greater than or equal to the CUDA version in which the driver symbol + * was introduced, this API will return the function pointer to the corresponding + * versioned function. + * + * The pointer returned by the API should be cast to a function pointer matching the + * requested driver function's definition in the API header file. The function pointer + * typedef can be picked up from the corresponding typedefs header file. For example, + * cudaTypedefs.h consists of function pointer typedefs for driver APIs defined in cuda.h. + * + * The API will return ::CUDA_ERROR_NOT_FOUND if the requested driver function is not + * supported on the platform, no ABI compatible driver function exists for the specified + * \p cudaVersion or if the driver symbol is invalid. + * + * The requested flags can be: + * - ::CU_GET_PROC_ADDRESS_DEFAULT: This is the default mode. This is equivalent to + * ::CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM if the code is compiled with + * --default-stream per-thread compilation flag or the macro CUDA_API_PER_THREAD_DEFAULT_STREAM + * is defined; ::CU_GET_PROC_ADDRESS_LEGACY_STREAM otherwise. + * - ::CU_GET_PROC_ADDRESS_LEGACY_STREAM: This will enable the search for all driver symbols + * that match the requested driver symbol name except the corresponding per-thread versions. + * - ::CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM: This will enable the search for all + * driver symbols that match the requested driver symbol name including the per-thread + * versions. If a per-thread version is not found, the API will return the legacy version + * of the driver function. + * + * \param symbol - The base name of the driver API function to look for. As an example, + * for the driver API ::cuMemAlloc_v2, \p symbol would be cuMemAlloc and + * \p cudaVersion would be the ABI compatible CUDA version for the _v2 variant. + * \param pfn - Location to return the function pointer to the requested driver function + * \param cudaVersion - The CUDA version to look for the requested driver symbol + * \param flags - Flags to specify search options. 
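 + *
 + * As an illustrative sketch (not part of the original header), looking up the
 + * cuMemAlloc entry point for CUDA 11.2 might look as follows; the local
 + * function-pointer typedef is an assumption standing in for the one provided
 + * by cudaTypedefs.h, and error checking is omitted:
 + * \code
 +    typedef CUresult (CUDAAPI *cuMemAlloc_pfn)(CUdeviceptr *dptr, size_t bytesize);
 +
 +    void *fnPtr = NULL;
 +    cuGetProcAddress("cuMemAlloc", &fnPtr, 11020, CU_GET_PROC_ADDRESS_DEFAULT);
 +
 +    if (fnPtr != NULL) {
 +        cuMemAlloc_pfn myMemAlloc = (cuMemAlloc_pfn)fnPtr;
 +        CUdeviceptr dptr;
 +        myMemAlloc(&dptr, 256);                       // resolves to the _v2 variant for 11020
 +        cuMemFree(dptr);
 +    }
 + * \endcode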
+ * + * \return + * ::CUDA_SUCCESS, + * ::CUDA_ERROR_INVALID_VALUE, + * ::CUDA_ERROR_NOT_SUPPORTED, + * ::CUDA_ERROR_NOT_FOUND + * \note_version_mixing + * + * \sa + * ::cudaGetDriverEntryPoint + */ +CUresult CUDAAPI cuGetProcAddress(const char *symbol, void **pfn, int cudaVersion, cuuint64_t flags); + +/** @} */ /* END CUDA_DRIVER_ENTRY_POINT */ + +CUresult CUDAAPI cuGetExportTable(const void **ppExportTable, const CUuuid *pExportTableId); + +/** + * CUDA API versioning support + */ +#if defined(__CUDA_API_VERSION_INTERNAL) + #undef cuMemHostRegister + #undef cuGraphicsResourceSetMapFlags + #undef cuLinkCreate + #undef cuLinkAddData + #undef cuLinkAddFile + #undef cuDeviceTotalMem + #undef cuCtxCreate + #undef cuModuleGetGlobal + #undef cuMemGetInfo + #undef cuMemAlloc + #undef cuMemAllocPitch + #undef cuMemFree + #undef cuMemGetAddressRange + #undef cuMemAllocHost + #undef cuMemHostGetDevicePointer + #undef cuMemcpyHtoD + #undef cuMemcpyDtoH + #undef cuMemcpyDtoD + #undef cuMemcpyDtoA + #undef cuMemcpyAtoD + #undef cuMemcpyHtoA + #undef cuMemcpyAtoH + #undef cuMemcpyAtoA + #undef cuMemcpyHtoAAsync + #undef cuMemcpyAtoHAsync + #undef cuMemcpy2D + #undef cuMemcpy2DUnaligned + #undef cuMemcpy3D + #undef cuMemcpyHtoDAsync + #undef cuMemcpyDtoHAsync + #undef cuMemcpyDtoDAsync + #undef cuMemcpy2DAsync + #undef cuMemcpy3DAsync + #undef cuMemsetD8 + #undef cuMemsetD16 + #undef cuMemsetD32 + #undef cuMemsetD2D8 + #undef cuMemsetD2D16 + #undef cuMemsetD2D32 + #undef cuArrayCreate + #undef cuArrayGetDescriptor + #undef cuArray3DCreate + #undef cuArray3DGetDescriptor + #undef cuTexRefSetAddress + #undef cuTexRefSetAddress2D + #undef cuTexRefGetAddress + #undef cuGraphicsResourceGetMappedPointer + #undef cuCtxDestroy + #undef cuCtxPopCurrent + #undef cuCtxPushCurrent + #undef cuStreamDestroy + #undef cuEventDestroy + #undef cuMemcpy + #undef cuMemcpyAsync + #undef cuMemcpyPeer + #undef cuMemcpyPeerAsync + #undef cuMemcpy3DPeer + #undef cuMemcpy3DPeerAsync + #undef cuMemsetD8Async + #undef cuMemsetD16Async + #undef cuMemsetD32Async + #undef cuMemsetD2D8Async + #undef cuMemsetD2D16Async + #undef cuMemsetD2D32Async + #undef cuStreamGetPriority + #undef cuStreamGetFlags + #undef cuStreamGetCtx + #undef cuStreamWaitEvent + #undef cuStreamAddCallback + #undef cuStreamAttachMemAsync + #undef cuStreamQuery + #undef cuStreamSynchronize + #undef cuEventRecord + #undef cuEventRecordWithFlags + #undef cuLaunchKernel + + + + #undef cuLaunchHostFunc + #undef cuGraphicsMapResources + #undef cuGraphicsUnmapResources + #undef cuStreamWriteValue32 + #undef cuStreamWaitValue32 + #undef cuStreamWriteValue64 + #undef cuStreamWaitValue64 + #undef cuStreamBatchMemOp + #undef cuMemPrefetchAsync + #undef cuLaunchCooperativeKernel + #undef cuSignalExternalSemaphoresAsync + #undef cuWaitExternalSemaphoresAsync + #undef cuStreamBeginCapture + #undef cuStreamEndCapture + #undef cuStreamIsCapturing + #undef cuStreamGetCaptureInfo + #undef cuStreamGetCaptureInfo_v2 + #undef cuGraphUpload + #undef cuGraphLaunch + #undef cuDevicePrimaryCtxRelease + #undef cuDevicePrimaryCtxReset + #undef cuDevicePrimaryCtxSetFlags + #undef cuIpcOpenMemHandle + #undef cuStreamCopyAttributes + #undef cuStreamSetAttribute + #undef cuStreamGetAttribute + #undef cuGraphInstantiate + #undef cuMemMapArrayAsync + #undef cuMemFreeAsync + #undef cuMemAllocAsync + #undef cuMemAllocFromPoolAsync + #undef cuStreamUpdateCaptureDependencies + + CUresult CUDAAPI cuMemHostRegister(void *p, size_t bytesize, unsigned int Flags); + CUresult CUDAAPI 
cuGraphicsResourceSetMapFlags(CUgraphicsResource resource, unsigned int flags); + CUresult CUDAAPI cuLinkCreate(unsigned int numOptions, CUjit_option *options, void **optionValues, CUlinkState *stateOut); + CUresult CUDAAPI cuLinkAddData(CUlinkState state, CUjitInputType type, void *data, size_t size, const char *name, + unsigned int numOptions, CUjit_option *options, void **optionValues); + CUresult CUDAAPI cuLinkAddFile(CUlinkState state, CUjitInputType type, const char *path, + unsigned int numOptions, CUjit_option *options, void **optionValues); + CUresult CUDAAPI cuTexRefSetAddress2D_v2(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR *desc, CUdeviceptr dptr, size_t Pitch); + + typedef unsigned int CUdeviceptr_v1; + + typedef struct CUDA_MEMCPY2D_v1_st + { + unsigned int srcXInBytes; /**< Source X in bytes */ + unsigned int srcY; /**< Source Y */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr_v1 srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ + + unsigned int dstXInBytes; /**< Destination X in bytes */ + unsigned int dstY; /**< Destination Y */ + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr_v1 dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ + + unsigned int WidthInBytes; /**< Width of 2D memory copy in bytes */ + unsigned int Height; /**< Height of 2D memory copy */ + } CUDA_MEMCPY2D_v1; + + typedef struct CUDA_MEMCPY3D_v1_st + { + unsigned int srcXInBytes; /**< Source X in bytes */ + unsigned int srcY; /**< Source Y */ + unsigned int srcZ; /**< Source Z */ + unsigned int srcLOD; /**< Source LOD */ + CUmemorytype srcMemoryType; /**< Source memory type (host, device, array) */ + const void *srcHost; /**< Source host pointer */ + CUdeviceptr_v1 srcDevice; /**< Source device pointer */ + CUarray srcArray; /**< Source array reference */ + void *reserved0; /**< Must be NULL */ + unsigned int srcPitch; /**< Source pitch (ignored when src is array) */ + unsigned int srcHeight; /**< Source height (ignored when src is array; may be 0 if Depth==1) */ + + unsigned int dstXInBytes; /**< Destination X in bytes */ + unsigned int dstY; /**< Destination Y */ + unsigned int dstZ; /**< Destination Z */ + unsigned int dstLOD; /**< Destination LOD */ + CUmemorytype dstMemoryType; /**< Destination memory type (host, device, array) */ + void *dstHost; /**< Destination host pointer */ + CUdeviceptr_v1 dstDevice; /**< Destination device pointer */ + CUarray dstArray; /**< Destination array reference */ + void *reserved1; /**< Must be NULL */ + unsigned int dstPitch; /**< Destination pitch (ignored when dst is array) */ + unsigned int dstHeight; /**< Destination height (ignored when dst is array; may be 0 if Depth==1) */ + + unsigned int WidthInBytes; /**< Width of 3D memory copy in bytes */ + unsigned int Height; /**< Height of 3D memory copy */ + unsigned int Depth; /**< Depth of 3D memory copy */ + } CUDA_MEMCPY3D_v1; + + typedef struct CUDA_ARRAY_DESCRIPTOR_v1_st + { + unsigned int Width; /**< Width of array */ + unsigned int Height; /**< Height of array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + 
} CUDA_ARRAY_DESCRIPTOR_v1; + + typedef struct CUDA_ARRAY3D_DESCRIPTOR_v1_st + { + unsigned int Width; /**< Width of 3D array */ + unsigned int Height; /**< Height of 3D array */ + unsigned int Depth; /**< Depth of 3D array */ + + CUarray_format Format; /**< Array format */ + unsigned int NumChannels; /**< Channels per array element */ + unsigned int Flags; /**< Flags */ + } CUDA_ARRAY3D_DESCRIPTOR_v1; + + CUresult CUDAAPI cuDeviceTotalMem(unsigned int *bytes, CUdevice dev); + CUresult CUDAAPI cuCtxCreate(CUcontext *pctx, unsigned int flags, CUdevice dev); + CUresult CUDAAPI cuModuleGetGlobal(CUdeviceptr_v1 *dptr, unsigned int *bytes, CUmodule hmod, const char *name); + CUresult CUDAAPI cuMemGetInfo(unsigned int *free, unsigned int *total); + CUresult CUDAAPI cuMemAlloc(CUdeviceptr_v1 *dptr, unsigned int bytesize); + CUresult CUDAAPI cuMemAllocPitch(CUdeviceptr_v1 *dptr, unsigned int *pPitch, unsigned int WidthInBytes, unsigned int Height, unsigned int ElementSizeBytes); + CUresult CUDAAPI cuMemFree(CUdeviceptr_v1 dptr); + CUresult CUDAAPI cuMemGetAddressRange(CUdeviceptr_v1 *pbase, unsigned int *psize, CUdeviceptr_v1 dptr); + CUresult CUDAAPI cuMemAllocHost(void **pp, unsigned int bytesize); + CUresult CUDAAPI cuMemHostGetDevicePointer(CUdeviceptr_v1 *pdptr, void *p, unsigned int Flags); + CUresult CUDAAPI cuMemcpyHtoD(CUdeviceptr_v1 dstDevice, const void *srcHost, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyDtoH(void *dstHost, CUdeviceptr_v1 srcDevice, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyDtoD(CUdeviceptr_v1 dstDevice, CUdeviceptr_v1 srcDevice, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyDtoA(CUarray dstArray, unsigned int dstOffset, CUdeviceptr_v1 srcDevice, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyAtoD(CUdeviceptr_v1 dstDevice, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyHtoA(CUarray dstArray, unsigned int dstOffset, const void *srcHost, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyAtoH(void *dstHost, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyAtoA(CUarray dstArray, unsigned int dstOffset, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount); + CUresult CUDAAPI cuMemcpyHtoAAsync(CUarray dstArray, unsigned int dstOffset, const void *srcHost, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyAtoHAsync(void *dstHost, CUarray srcArray, unsigned int srcOffset, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy2D(const CUDA_MEMCPY2D_v1 *pCopy); + CUresult CUDAAPI cuMemcpy2DUnaligned(const CUDA_MEMCPY2D_v1 *pCopy); + CUresult CUDAAPI cuMemcpy3D(const CUDA_MEMCPY3D_v1 *pCopy); + CUresult CUDAAPI cuMemcpyHtoDAsync(CUdeviceptr_v1 dstDevice, const void *srcHost, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyDtoHAsync(void *dstHost, CUdeviceptr_v1 srcDevice, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyDtoDAsync(CUdeviceptr_v1 dstDevice, CUdeviceptr_v1 srcDevice, unsigned int ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy2DAsync(const CUDA_MEMCPY2D_v1 *pCopy, CUstream hStream); + CUresult CUDAAPI cuMemcpy3DAsync(const CUDA_MEMCPY3D_v1 *pCopy, CUstream hStream); + CUresult CUDAAPI cuMemsetD8(CUdeviceptr_v1 dstDevice, unsigned char uc, unsigned int N); + CUresult CUDAAPI cuMemsetD16(CUdeviceptr_v1 dstDevice, unsigned short us, unsigned int N); + CUresult CUDAAPI cuMemsetD32(CUdeviceptr_v1 dstDevice, unsigned int ui, unsigned int N); + 
CUresult CUDAAPI cuMemsetD2D8(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, unsigned char uc, unsigned int Width, unsigned int Height); + CUresult CUDAAPI cuMemsetD2D16(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, unsigned short us, unsigned int Width, unsigned int Height); + CUresult CUDAAPI cuMemsetD2D32(CUdeviceptr_v1 dstDevice, unsigned int dstPitch, unsigned int ui, unsigned int Width, unsigned int Height); + CUresult CUDAAPI cuArrayCreate(CUarray *pHandle, const CUDA_ARRAY_DESCRIPTOR_v1 *pAllocateArray); + CUresult CUDAAPI cuArrayGetDescriptor(CUDA_ARRAY_DESCRIPTOR_v1 *pArrayDescriptor, CUarray hArray); + CUresult CUDAAPI cuArray3DCreate(CUarray *pHandle, const CUDA_ARRAY3D_DESCRIPTOR_v1 *pAllocateArray); + CUresult CUDAAPI cuArray3DGetDescriptor(CUDA_ARRAY3D_DESCRIPTOR_v1 *pArrayDescriptor, CUarray hArray); + CUresult CUDAAPI cuTexRefSetAddress(unsigned int *ByteOffset, CUtexref hTexRef, CUdeviceptr_v1 dptr, unsigned int bytes); + CUresult CUDAAPI cuTexRefSetAddress2D(CUtexref hTexRef, const CUDA_ARRAY_DESCRIPTOR_v1 *desc, CUdeviceptr_v1 dptr, unsigned int Pitch); + CUresult CUDAAPI cuTexRefGetAddress(CUdeviceptr_v1 *pdptr, CUtexref hTexRef); + CUresult CUDAAPI cuGraphicsResourceGetMappedPointer(CUdeviceptr_v1 *pDevPtr, unsigned int *pSize, CUgraphicsResource resource); + + CUresult CUDAAPI cuCtxDestroy(CUcontext ctx); + CUresult CUDAAPI cuCtxPopCurrent(CUcontext *pctx); + CUresult CUDAAPI cuCtxPushCurrent(CUcontext ctx); + CUresult CUDAAPI cuStreamDestroy(CUstream hStream); + CUresult CUDAAPI cuEventDestroy(CUevent hEvent); + CUresult CUDAAPI cuDevicePrimaryCtxRelease(CUdevice dev); + CUresult CUDAAPI cuDevicePrimaryCtxReset(CUdevice dev); + CUresult CUDAAPI cuDevicePrimaryCtxSetFlags(CUdevice dev, unsigned int flags); + + CUresult CUDAAPI cuMemcpyHtoD_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount); + CUresult CUDAAPI cuMemcpyDtoH_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount); + CUresult CUDAAPI cuMemcpyDtoD_v2(CUdeviceptr dstDevice, CUdeviceptr srcDevice, size_t ByteCount); + CUresult CUDAAPI cuMemcpyDtoA_v2(CUarray dstArray, size_t dstOffset, CUdeviceptr srcDevice, size_t ByteCount); + CUresult CUDAAPI cuMemcpyAtoD_v2(CUdeviceptr dstDevice, CUarray srcArray, size_t srcOffset, size_t ByteCount); + CUresult CUDAAPI cuMemcpyHtoA_v2(CUarray dstArray, size_t dstOffset, const void *srcHost, size_t ByteCount); + CUresult CUDAAPI cuMemcpyAtoH_v2(void *dstHost, CUarray srcArray, size_t srcOffset, size_t ByteCount); + CUresult CUDAAPI cuMemcpyAtoA_v2(CUarray dstArray, size_t dstOffset, CUarray srcArray, size_t srcOffset, size_t ByteCount); + CUresult CUDAAPI cuMemcpyHtoAAsync_v2(CUarray dstArray, size_t dstOffset, const void *srcHost, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyAtoHAsync_v2(void *dstHost, CUarray srcArray, size_t srcOffset, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy2D_v2(const CUDA_MEMCPY2D *pCopy); + CUresult CUDAAPI cuMemcpy2DUnaligned_v2(const CUDA_MEMCPY2D *pCopy); + CUresult CUDAAPI cuMemcpy3D_v2(const CUDA_MEMCPY3D *pCopy); + CUresult CUDAAPI cuMemcpyHtoDAsync_v2(CUdeviceptr dstDevice, const void *srcHost, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyDtoHAsync_v2(void *dstHost, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyDtoDAsync_v2(CUdeviceptr dstDevice, CUdeviceptr srcDevice, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy2DAsync_v2(const CUDA_MEMCPY2D *pCopy, CUstream hStream); + CUresult CUDAAPI 
cuMemcpy3DAsync_v2(const CUDA_MEMCPY3D *pCopy, CUstream hStream); + CUresult CUDAAPI cuMemsetD8_v2(CUdeviceptr dstDevice, unsigned char uc, size_t N); + CUresult CUDAAPI cuMemsetD16_v2(CUdeviceptr dstDevice, unsigned short us, size_t N); + CUresult CUDAAPI cuMemsetD32_v2(CUdeviceptr dstDevice, unsigned int ui, size_t N); + CUresult CUDAAPI cuMemsetD2D8_v2(CUdeviceptr dstDevice, size_t dstPitch, unsigned char uc, size_t Width, size_t Height); + CUresult CUDAAPI cuMemsetD2D16_v2(CUdeviceptr dstDevice, size_t dstPitch, unsigned short us, size_t Width, size_t Height); + CUresult CUDAAPI cuMemsetD2D32_v2(CUdeviceptr dstDevice, size_t dstPitch, unsigned int ui, size_t Width, size_t Height); + CUresult CUDAAPI cuMemcpy(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount); + CUresult CUDAAPI cuMemcpyAsync(CUdeviceptr dst, CUdeviceptr src, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpyPeer(CUdeviceptr dstDevice, CUcontext dstContext, CUdeviceptr srcDevice, CUcontext srcContext, size_t ByteCount); + CUresult CUDAAPI cuMemcpyPeerAsync(CUdeviceptr dstDevice, CUcontext dstContext, CUdeviceptr srcDevice, CUcontext srcContext, size_t ByteCount, CUstream hStream); + CUresult CUDAAPI cuMemcpy3DPeer(const CUDA_MEMCPY3D_PEER *pCopy); + CUresult CUDAAPI cuMemcpy3DPeerAsync(const CUDA_MEMCPY3D_PEER *pCopy, CUstream hStream); + + CUresult CUDAAPI cuMemsetD8Async(CUdeviceptr dstDevice, unsigned char uc, size_t N, CUstream hStream); + CUresult CUDAAPI cuMemsetD16Async(CUdeviceptr dstDevice, unsigned short us, size_t N, CUstream hStream); + CUresult CUDAAPI cuMemsetD32Async(CUdeviceptr dstDevice, unsigned int ui, size_t N, CUstream hStream); + CUresult CUDAAPI cuMemsetD2D8Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned char uc, size_t Width, size_t Height, CUstream hStream); + CUresult CUDAAPI cuMemsetD2D16Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned short us, size_t Width, size_t Height, CUstream hStream); + CUresult CUDAAPI cuMemsetD2D32Async(CUdeviceptr dstDevice, size_t dstPitch, unsigned int ui, size_t Width, size_t Height, CUstream hStream); + + CUresult CUDAAPI cuStreamGetPriority(CUstream hStream, int *priority); + CUresult CUDAAPI cuStreamGetFlags(CUstream hStream, unsigned int *flags); + CUresult CUDAAPI cuStreamGetCtx(CUstream hStream, CUcontext *pctx); + CUresult CUDAAPI cuStreamWaitEvent(CUstream hStream, CUevent hEvent, unsigned int Flags); + CUresult CUDAAPI cuStreamAddCallback(CUstream hStream, CUstreamCallback callback, void *userData, unsigned int flags); + CUresult CUDAAPI cuStreamAttachMemAsync(CUstream hStream, CUdeviceptr dptr, size_t length, unsigned int flags); + CUresult CUDAAPI cuStreamQuery(CUstream hStream); + CUresult CUDAAPI cuStreamSynchronize(CUstream hStream); + CUresult CUDAAPI cuEventRecord(CUevent hEvent, CUstream hStream); + CUresult CUDAAPI cuEventRecordWithFlags(CUevent hEvent, CUstream hStream, unsigned int flags); + CUresult CUDAAPI cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams, void **extra); + + + + CUresult CUDAAPI cuLaunchHostFunc(CUstream hStream, CUhostFn fn, void *userData); + CUresult CUDAAPI cuGraphicsMapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); + CUresult CUDAAPI cuGraphicsUnmapResources(unsigned int count, CUgraphicsResource *resources, CUstream hStream); + CUresult CUDAAPI cuStreamWriteValue32(CUstream 
stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWaitValue32(CUstream stream, CUdeviceptr addr, cuuint32_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWriteValue64(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + CUresult CUDAAPI cuStreamWaitValue64(CUstream stream, CUdeviceptr addr, cuuint64_t value, unsigned int flags); + CUresult CUDAAPI cuStreamBatchMemOp(CUstream stream, unsigned int count, CUstreamBatchMemOpParams *paramArray, unsigned int flags); + CUresult CUDAAPI cuMemPrefetchAsync(CUdeviceptr devPtr, size_t count, CUdevice dstDevice, CUstream hStream); + CUresult CUDAAPI cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void **kernelParams); + CUresult CUDAAPI cuSignalExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_SIGNAL_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); + CUresult CUDAAPI cuWaitExternalSemaphoresAsync(const CUexternalSemaphore *extSemArray, const CUDA_EXTERNAL_SEMAPHORE_WAIT_PARAMS *paramsArray, unsigned int numExtSems, CUstream stream); + CUresult CUDAAPI cuStreamBeginCapture(CUstream hStream); + CUresult CUDAAPI cuStreamBeginCapture_ptsz(CUstream hStream); + CUresult CUDAAPI cuStreamBeginCapture_v2(CUstream hStream, CUstreamCaptureMode mode); + CUresult CUDAAPI cuStreamEndCapture(CUstream hStream, CUgraph *phGraph); + CUresult CUDAAPI cuStreamIsCapturing(CUstream hStream, CUstreamCaptureStatus *captureStatus); + CUresult CUDAAPI cuStreamGetCaptureInfo(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out); + CUresult CUDAAPI cuStreamGetCaptureInfo_v2(CUstream hStream, CUstreamCaptureStatus *captureStatus_out, cuuint64_t *id_out, CUgraph *graph_out, const CUgraphNode **dependencies_out, size_t *numDependencies_out); + CUresult CUDAAPI cuGraphUpload(CUgraphExec hGraph, CUstream hStream); + CUresult CUDAAPI cuGraphLaunch(CUgraphExec hGraph, CUstream hStream); + CUresult CUDAAPI cuStreamCopyAttributes(CUstream dstStream, CUstream srcStream); + CUresult CUDAAPI cuStreamGetAttribute(CUstream hStream, CUstreamAttrID attr, CUstreamAttrValue *value); + CUresult CUDAAPI cuStreamSetAttribute(CUstream hStream, CUstreamAttrID attr, const CUstreamAttrValue *param); + + CUresult CUDAAPI cuIpcOpenMemHandle(CUdeviceptr *pdptr, CUipcMemHandle handle, unsigned int Flags); + CUresult CUDAAPI cuGraphInstantiate(CUgraphExec *phGraphExec, CUgraph hGraph, CUgraphNode *phErrorNode, char *logBuffer, size_t bufferSize); + CUresult CUDAAPI cuMemMapArrayAsync(CUarrayMapInfo *mapInfoList, unsigned int count, CUstream hStream); + + CUresult CUDAAPI cuMemFreeAsync(CUdeviceptr dptr, CUstream hStream); + CUresult CUDAAPI cuMemAllocAsync(CUdeviceptr *dptr, size_t bytesize, CUstream hStream); + CUresult CUDAAPI cuMemAllocFromPoolAsync(CUdeviceptr *dptr, size_t bytesize, CUmemoryPool pool, CUstream hStream); + + CUresult CUDAAPI cuStreamUpdateCaptureDependencies(CUstream hStream, CUgraphNode *dependencies, size_t numDependencies, unsigned int flags); +#elif defined(__CUDA_API_PER_THREAD_DEFAULT_STREAM) +static inline CUresult cuGetProcAddress_ptsz(const char *symbol, void **funcPtr, int driverVersion, cuuint64_t flags) { + const int procAddressMask = (CU_GET_PROC_ADDRESS_LEGACY_STREAM| + CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM); + if ((flags & 
procAddressMask) == 0) { + flags |= CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM; + } + return cuGetProcAddress(symbol, funcPtr, driverVersion, flags); +} +#define cuGetProcAddress cuGetProcAddress_ptsz +#endif + +#ifdef __cplusplus +} +#endif + +#if defined(__GNUC__) + #if defined(__CUDA_API_PUSH_VISIBILITY_DEFAULT) + #pragma GCC visibility pop + #endif +#endif + +#undef __CUDA_DEPRECATED + +#endif /* __cuda_cuda_h__ */ diff --git a/pkgs/triton/third_party/cuda/lib/libdevice.10.bc b/pkgs/triton/third_party/cuda/lib/libdevice.10.bc new file mode 100755 index 0000000..b2c75a5 Binary files /dev/null and b/pkgs/triton/third_party/cuda/lib/libdevice.10.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/asanrtl.bc b/pkgs/triton/third_party/rocm/lib/bitcode/asanrtl.bc new file mode 100644 index 0000000..14e6011 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/asanrtl.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.bc b/pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.bc new file mode 100755 index 0000000..fadc5b4 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.patch b/pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.patch new file mode 100644 index 0000000..3d8f633 --- /dev/null +++ b/pkgs/triton/third_party/rocm/lib/bitcode/cuda2gcn.patch @@ -0,0 +1,24 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index b65f1b5..19cc5a9 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -62,12 +62,13 @@ include(OCL) + set(AMDGCN_LIB_LIST) + set(AMDGCN_DEP_LIST) + add_subdirectory(irif) +-add_subdirectory(oclc) +-add_subdirectory(ocml) +-add_subdirectory(ockl) +-add_subdirectory(opencl) +-add_subdirectory(hip) +-add_subdirectory(asanrtl) ++#add_subdirectory(oclc) ++#add_subdirectory(ocml) ++#add_subdirectory(ockl) ++#add_subdirectory(opencl) ++#add_subdirectory(hip) ++#add_subdirectory(asanrtl) ++add_subdirectory(cuda2gcn) + + enable_testing() + add_subdirectory(test/compile) diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/hip.bc b/pkgs/triton/third_party/rocm/lib/bitcode/hip.bc new file mode 100644 index 0000000..d23b145 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/hip.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh b/pkgs/triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh new file mode 100755 index 0000000..a3fc04c --- /dev/null +++ b/pkgs/triton/third_party/rocm/lib/bitcode/make_cuda2gcn.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +set -e + +pushd . + +git clone https://github.com/dfukalov/ROCm-Device-Libs.git +cd ROCm-Device-Libs +git apply ../cuda2gcn.patch +mkdir build +cd build +cmake .. -DCMAKE_PREFIX_PATH=$HOME/.triton/llvm/clang+llvm-14.0.0-x86_64-linux-gnu-ubuntu-18.04 +make -j4 + +popd +cp ROCm-Device-Libs/build/amdgcn/bitcode/cuda2gcn.bc . 
+rm -rf ROCm-Device-Libs diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/ockl.bc b/pkgs/triton/third_party/rocm/lib/bitcode/ockl.bc new file mode 100644 index 0000000..e662d2b Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/ockl.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_abi_version_400.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_abi_version_400.bc new file mode 100644 index 0000000..cd9bf55 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_abi_version_400.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_abi_version_500.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_abi_version_500.bc new file mode 100644 index 0000000..8d04848 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_abi_version_500.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_off.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_off.bc new file mode 100644 index 0000000..c07272c Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_off.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_on.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_on.bc new file mode 100644 index 0000000..8ffaa29 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_correctly_rounded_sqrt_on.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_daz_opt_off.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_daz_opt_off.bc new file mode 100644 index 0000000..2918ed1 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_daz_opt_off.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_daz_opt_on.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_daz_opt_on.bc new file mode 100644 index 0000000..67b2dbe Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_daz_opt_on.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_finite_only_off.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_finite_only_off.bc new file mode 100644 index 0000000..363d637 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_finite_only_off.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_finite_only_on.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_finite_only_on.bc new file mode 100644 index 0000000..266dff9 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_finite_only_on.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1010.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1010.bc new file mode 100644 index 0000000..032c767 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1010.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1011.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1011.bc new file mode 100644 index 0000000..5cb3121 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1011.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1012.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1012.bc new file mode 100644 index 0000000..615e65b Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1012.bc differ diff --git 
a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1013.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1013.bc new file mode 100644 index 0000000..5548378 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1013.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1030.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1030.bc new file mode 100644 index 0000000..89517b8 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1030.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1031.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1031.bc new file mode 100644 index 0000000..fb2252f Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1031.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1032.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1032.bc new file mode 100644 index 0000000..ed9244f Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1032.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1033.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1033.bc new file mode 100644 index 0000000..644b57e Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1033.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1034.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1034.bc new file mode 100644 index 0000000..cd15cc6 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1034.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1035.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1035.bc new file mode 100644 index 0000000..c8a8168 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1035.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1036.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1036.bc new file mode 100644 index 0000000..c7a95bc Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1036.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1100.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1100.bc new file mode 100644 index 0000000..565e6cc Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1100.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1101.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1101.bc new file mode 100644 index 0000000..7d0914f Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1101.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1102.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1102.bc new file mode 100644 index 0000000..c73d351 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1102.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1103.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1103.bc new file mode 100644 index 0000000..0b4fd7b Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_1103.bc differ diff --git 
a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_600.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_600.bc new file mode 100644 index 0000000..7358c08 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_600.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_601.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_601.bc new file mode 100644 index 0000000..327c6a9 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_601.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_602.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_602.bc new file mode 100644 index 0000000..351b59f Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_602.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_700.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_700.bc new file mode 100644 index 0000000..ae16156 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_700.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_701.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_701.bc new file mode 100644 index 0000000..32f5915 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_701.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_702.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_702.bc new file mode 100644 index 0000000..65b9b95 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_702.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_703.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_703.bc new file mode 100644 index 0000000..3e2978c Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_703.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_704.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_704.bc new file mode 100644 index 0000000..3df6bff Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_704.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_705.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_705.bc new file mode 100644 index 0000000..9a60964 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_705.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_801.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_801.bc new file mode 100644 index 0000000..34b1c9d Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_801.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_802.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_802.bc new file mode 100644 index 0000000..7b4d7df Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_802.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_803.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_803.bc new file mode 100644 index 0000000..f208441 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_803.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_805.bc 
b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_805.bc new file mode 100644 index 0000000..a03a8af Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_805.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_810.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_810.bc new file mode 100644 index 0000000..3e36122 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_810.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_900.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_900.bc new file mode 100644 index 0000000..7664808 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_900.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_902.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_902.bc new file mode 100644 index 0000000..cba8471 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_902.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_904.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_904.bc new file mode 100644 index 0000000..f07daf2 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_904.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_906.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_906.bc new file mode 100644 index 0000000..21b41e6 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_906.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_908.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_908.bc new file mode 100644 index 0000000..7be206f Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_908.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_909.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_909.bc new file mode 100644 index 0000000..5f74426 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_909.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_90a.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_90a.bc new file mode 100644 index 0000000..1bb0d46 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_90a.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_90c.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_90c.bc new file mode 100644 index 0000000..641535f Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_90c.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_940.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_940.bc new file mode 100644 index 0000000..b3a97fd Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_isa_version_940.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_off.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_off.bc new file mode 100644 index 0000000..623906d Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_off.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_on.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_on.bc new file 
mode 100644 index 0000000..e78c15c Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_unsafe_math_on.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_off.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_off.bc new file mode 100644 index 0000000..d50a050 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_off.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_on.bc b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_on.bc new file mode 100644 index 0000000..1d84e11 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/oclc_wavefrontsize64_on.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/ocml.bc b/pkgs/triton/third_party/rocm/lib/bitcode/ocml.bc new file mode 100644 index 0000000..1cca507 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/ocml.bc differ diff --git a/pkgs/triton/third_party/rocm/lib/bitcode/opencl.bc b/pkgs/triton/third_party/rocm/lib/bitcode/opencl.bc new file mode 100644 index 0000000..5a13437 Binary files /dev/null and b/pkgs/triton/third_party/rocm/lib/bitcode/opencl.bc differ diff --git a/pkgs/triton/tools/__init__.py b/pkgs/triton/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/triton/tools/__pycache__/__init__.cpython-310.pyc b/pkgs/triton/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..71651b9 Binary files /dev/null and b/pkgs/triton/tools/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/triton/tools/__pycache__/aot.cpython-310.pyc b/pkgs/triton/tools/__pycache__/aot.cpython-310.pyc new file mode 100644 index 0000000..a491742 Binary files /dev/null and b/pkgs/triton/tools/__pycache__/aot.cpython-310.pyc differ diff --git a/pkgs/triton/tools/__pycache__/build_extern.cpython-310.pyc b/pkgs/triton/tools/__pycache__/build_extern.cpython-310.pyc new file mode 100644 index 0000000..f49d0e7 Binary files /dev/null and b/pkgs/triton/tools/__pycache__/build_extern.cpython-310.pyc differ diff --git a/pkgs/triton/tools/__pycache__/disasm.cpython-310.pyc b/pkgs/triton/tools/__pycache__/disasm.cpython-310.pyc new file mode 100644 index 0000000..e83201c Binary files /dev/null and b/pkgs/triton/tools/__pycache__/disasm.cpython-310.pyc differ diff --git a/pkgs/triton/tools/aot.py b/pkgs/triton/tools/aot.py new file mode 100644 index 0000000..f158674 --- /dev/null +++ b/pkgs/triton/tools/aot.py @@ -0,0 +1,122 @@ +import argparse +import sys + +import triton._C.libtriton.triton as libtriton +import triton.compiler.compiler as tc + +if __name__ == '__main__': + + # valid source and target formats + VALID_FORMATS = ['triton-ir', 'triton-gpu-ir', 'llvm-ir', 'ptx', 'amdgcn'] + + # set up the argument parser + # TODO: conditional requirements + parser = argparse.ArgumentParser() + parser.add_argument('src', help="Source file to compile") + parser.add_argument('--target', required=True, + help="Target format, one of: " + ', '.join(VALID_FORMATS)) + parser.add_argument('--sm', type=int, help="Compute capability to compile for") + parser.add_argument('--ptx-version', type=int, help="PTX version to compile for") + parser.add_argument('--gfx', type=str, help="AMDGPU target to compile for") + parser.add_argument('--triple', type=str, help="target triple, for example: amdgcn-amd-amdhsa") + parser.add_argument('--features', type=str, help="target features, for example: +sramecc,-xnack") + 
parser.add_argument('--num_warps', type=int, help="number of warps to compile ttgir for") + + # parse the args + args = parser.parse_args() + + # TODO: clean-up and re-use triton.compiler primitive functions + # check for validity of format arguments + if args.target not in VALID_FORMATS: + print("Invalid target format: " + args.target) + sys.exit(0) + + # parse source file to MLIR module + context = libtriton.ir.context() + module = libtriton.ir.parse_mlir_module(args.src, context) + module.context = context + + # optimizer triton-ir + module = tc.optimize_ttir(module, arch=args.sm) + if args.target == 'triton-ir': + print(module.str()) + sys.exit(0) + + if not args.num_warps: + args.num_warps = 4 + + # llvm-ir -> amdgcn + if args.target == 'amdgcn': + # auto detect available architecture and features + # if nothing detected, set with default values + arch_details = tc.get_amdgpu_arch_fulldetails() + if not arch_details: + arch_name = "" + arch_triple = "amdgcn-amd-amdhsa" + arch_features = "" + arch_warpsize = 64 + else: + arch_triple, arch_name, arch_features, arch_warpsize = arch_details + + # stop processing if architecture name is not automatically detected and is not set manually + if not args.gfx and not arch_name: + raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation") + + # rewrite default and automatically detected values with manually provided data + if args.gfx: + arch_name = args.gfx + if args.triple: + arch_triple = args.triple + if args.features: + arch_features = args.features + + # triton-ir -> triton-gpu-ir + # use compute_capability == 80 + module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=arch_warpsize) # num_stages=3, compute_capability=80) + module = tc.optimize_ttgir(module, num_stages=3, arch=args.gfx) + # triton-gpu-ir -> llvm-ir + # use compute_capability == 80 + module = tc.ttgir_to_llir(module, extern_libs=None, arch=args.gfx) + # llvm-ir -> amdgcn asm, hsaco binary + module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, arch_name, arch_triple, arch_features) + + print(hsaco_path) + print(module) + sys.exit(0) + + # set arch depending on platform + if args.gfx: + arch = args.gfx + elif args.sm: + arch = args.sm + else: + raise argparse.ArgumentError(None, "Must specify --sm or --gfx for ttgir compilation") + + # triton-ir -> triton-gpu-ir + module = tc.ttir_to_ttgir(module, num_warps=args.num_warps, warpsize=tc.CUDA_DEFAULT_WARP_SIZE) + module = tc.optimize_ttgir(module, num_stages=3, arch=arch) + if args.target == 'triton-gpu-ir': + print(module.str()) + sys.exit(0) + + # triton-gpu-ir -> llvm-ir + module = tc.ttgir_to_llir(module, extern_libs=None, arch=arch) + if args.target == 'llvm-ir': + print(module) + sys.exit(0) + + # llvm-ir -> ptx + if args.target == 'ptx': + if not args.sm: + raise argparse.ArgumentError(None, "Must specify --sm for PTX compilation") + if not args.ptx_version: + raise argparse.ArgumentError(None, "Must specify --ptx-version for PTX compilation") + module = tc.llir_to_ptx(module, arch=args.sm, ptx_version=args.ptx_version) + + # llvm-ir -> amdgcn + if args.target == 'amdgcn': + if not args.gfx: + raise argparse.ArgumentError(None, "Must specify --gfx for AMDGCN compilation") + module, hsaco_path = tc.llir_to_amdgcn_and_hsaco(module, args.gfx) + + print(module) diff --git a/pkgs/triton/tools/build_extern.py b/pkgs/triton/tools/build_extern.py new file mode 100644 index 0000000..f19fbd5 --- /dev/null +++ b/pkgs/triton/tools/build_extern.py @@ -0,0 +1,398 @@ +import argparse +import 
subprocess +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + + +class Symbol: + _name: str + _op_name: str + _ret_type: str + _arg_names: List[str] + _arg_types: List[str] + + def __init__( + self, + name: str, + op_name: str, + ret_type: str, + arg_names: List[str], + arg_types: List[str], + ) -> None: + ''' + A symbol is a function declaration. + :param name: name of the symbol + :param op_name: name of the operation + :param ret_type: return type of the operation + :param arg_names: names of the arguments + :param arg_types: types of the arguments + ''' + self._name = name + self._op_name = op_name + self._ret_type = ret_type + self._arg_names = list(arg_names) + self._arg_types = list(arg_types) + + @property + def name(self) -> str: + return self._name + + @property + def op_name(self) -> str: + return self._op_name + + @property + def ret_type(self) -> str: + return self._ret_type + + @property + def arg_names(self) -> List[str]: + return self._arg_names + + @property + def arg_types(self) -> List[str]: + return self._arg_types + + +def convert_type(type_str) -> Optional[str]: + if type_str == "i32": + return "int32" + elif type_str == "u32": + return "uint32" + elif type_str == "i64": + return "int64" + elif type_str == "u64": + return "uint64" + elif type_str == "float": + return "fp32" + elif type_str == "double": + return "fp64" + else: + # ignore other types, such as pointer types + return None + + +def to_unsigned(type_str) -> str: + if type_str == "int32": + return "uint32" + elif type_str == "int64": + return "uint64" + else: + return type_str + + +class ExternLibrary(ABC): + _name: str + _path: str + _symbols: Dict[str, Symbol] + _format: bool + _grouping: bool + + def __init__( + self, + name: str, + path: str, + format: bool = True, + grouping: bool = True, + ) -> None: + ''' + Abstract class for extern library. + :param name: name of the library + :param path: path of the library + :param format: whether to format the generated stub file + ''' + self._name = name + self._path = path + self._symbols = {} + self._format = format + self._grouping = grouping + + @property + def name(self) -> str: + return self._name + + @property + def path(self) -> str: + return self._path + + @property + def symbols(self) -> Dict[str, Symbol]: + return self._symbols + + @property + def grouping(self) -> bool: + return self._grouping + + @abstractmethod + def parse_symbols(self, input_file) -> None: + pass + + @abstractmethod + def _output_stubs(self) -> str: + pass + + def generate_stub_file(self, output_dir) -> None: + file_str = self._output_stubs() + if file_str is None or len(file_str) == 0: + raise Exception("file_str is empty") + + output_file = f"{output_dir}/{self._name}.py" + with open(output_file, "w") as f: + f.write(file_str) + f.close() + if self._format: + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], + stdout=subprocess.PIPE).communicate() + subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() + + +class Libdevice(ExternLibrary): + _symbol_groups: Dict[str, List[Symbol]] + + def __init__(self, path) -> None: + ''' + Constructor for Libdevice. 
+ :param path: path of the libdevice library + ''' + super().__init__("libdevice", path) + self._symbol_groups = {} + self.is_pure = True + + @staticmethod + def _extract_symbol(line) -> Optional[Symbol]: + # Extract symbols from line in the following format: + # "define [internal] @(,)" + entries = line.split("@") + ret_str = entries[0] + func_str = entries[1] + # Get ret_type, skip internal symbols + ret_strs = ret_str.split() + if ret_strs[1] == "internal": + return None + ret_type = convert_type(ret_strs[1]) + if ret_type is None: + return None + # Get function name + func_strs = func_str.split("(") + func_name = func_strs[0].replace("@", "") + op_name = func_name.replace("__nv_", "") + if 'ieee' in op_name: + return None + # Get arg_types + arg_strs = func_strs[1].split(",") + arg_types = [] + arg_names = [] + for i, arg_str in enumerate(arg_strs): + arg_type = convert_type(arg_str.split()[0]) + if arg_type is None: + return None + arg_name = 'arg' + str(i) + arg_types.append(arg_type) + arg_names.append(arg_name) + if op_name == "sad": + # Special case for sad, where the last argument is an unsigned int + arg_types[-1] = to_unsigned(arg_types[-1]) + elif op_name.startswith("u"): + # LLVM does not differentiate between signed and unsigned integer type. + # We have to convert the types to unsigned + ret_type = to_unsigned(ret_type) + for i, arg_type in enumerate(arg_types): + arg_types[i] = to_unsigned(arg_type) + return Symbol(func_name, op_name, ret_type, arg_names, arg_types) + + def _group_symbols(self) -> None: + symbol_set = {} + for symbol in self._symbols.values(): + op_name = symbol.op_name + symbol_set[op_name] = symbol + + # Group functions together by renaming. + renaming = { + 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', + 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn', + 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', + 'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin', + 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', + 'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt', + 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', + 'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi', + 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', + 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru', + 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', + 'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx', + 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', + 'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs', + 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', + 'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor', + 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', + 'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', + 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', + 'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', + 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', + 'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10', + 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', + 'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max', + 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', + 'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', + 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru', + 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', + 'umul24': 
'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi', + 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter', + 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', + 'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', + 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd', + 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', + 'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', + 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', + 'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn', + 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', + 'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh', + 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', + 'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', + 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz', + 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', + 'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', + 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', + 'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn' + } + + for symbol in self._symbols.values(): + op_name = symbol.op_name + if op_name in renaming: + op_name = renaming[op_name] + symbol._op_name = op_name + if op_name in self._symbol_groups: + self._symbol_groups[op_name].append(symbol) + else: + self._symbol_groups[op_name] = [symbol] + + def parse_symbols(self, input_file) -> None: + if len(self.symbols) > 0: + return + output = subprocess.check_output(["grep", "define", input_file]).decode().splitlines() + for line in output: + symbol = self._extract_symbol(line) + if symbol is None: + continue + self._symbols[symbol.name] = symbol + + self._group_symbols() + + def _output_stubs(self) -> str: + # Generate python functions in the following format: + # @extern.extern + # def <op_name>(<args>, _builder=None): + # arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}} + # return core.extern_elementwise("libdevice", <path>, <args>, <arg_type_symbol_dict>, _builder) + import_str = "from . 
import core\n" + import_str += "import os\n" + import_str += "import functools\n" + + header_str = "" + header_str += "@functools.lru_cache()\n" + header_str += "def libdevice_path():\n" + header_str += " import torch\n" + header_str += " third_party_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), \"..\", \"third_party\")\n" + header_str += " if torch.version.hip is None:\n" + header_str += " default = os.path.join(third_party_dir, \"cuda\", \"lib\", \"libdevice.10.bc\")\n" + header_str += " else:\n" + header_str += " default = ''\n" + header_str += " return os.getenv(\"TRITON_LIBDEVICE_PATH\", default)\n" + func_str = "" + for symbols in self._symbol_groups.values(): + func_str += "@core.extern\n" + func_name_str = f"def {symbols[0].op_name}(" + for arg_name in symbols[0].arg_names: + func_name_str += f"{arg_name}, " + func_name_str += "_builder=None):\n" + + return_str = f"\treturn core.extern_elementwise(\"{self._name}\", libdevice_path(), [" + for arg_name in symbols[0].arg_names: + return_str += f"{arg_name}, " + return_str += "], \n" + + arg_type_symbol_dict_str = "{" + for symbol in symbols: + arg_type_symbol_dict_str += "(" + for arg_type in symbol.arg_types: + arg_type_symbol_dict_str += f'core.dtype("{arg_type}"),' + ret_type = f'core.dtype("{symbol.ret_type}")' + arg_type_symbol_dict_str += "): (\"" + symbol.name + "\", " + ret_type + "),\n" + arg_type_symbol_dict_str += "}" + + return_str += arg_type_symbol_dict_str + return_str += f", is_pure={self.is_pure}" + return_str += ", _builder=_builder)\n" + + func_str += func_name_str + return_str + "\n" + file_str = import_str + header_str + func_str + + return file_str + + +class LLVMDisassembler: + _path: str + _ll_file: str + + def __init__(self, path) -> None: + ''' + Invoke llvm-dis to disassemble the given file. + :param path: path to llvm-dis + ''' + self._path = path + self._ll_file = "/tmp/extern_lib.ll" + + def disasm(self, lib_path: str) -> None: + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], + stdout=subprocess.PIPE).communicate() + + @property + def ll_file(self) -> str: + return self._ll_file + + @property + def path(self) -> str: + return self._path + + +extern_libs = ["libdevice"] + + +def build( + llvm_dis_path: str, + lib_path: str, + lib_name: str, + output_dir: str, +) -> None: + ''' + Interface function to build the library file. 
+ :param llvm_dis_path: path to the llvm-dis binary + :param lib_path: path to the external library file + :param lib_name: name of the library + :param output_dir: path to the output directory + ''' + if lib_name == "libdevice": + extern_lib = Libdevice(lib_path) + else: + raise Exception(f"Unknown extern library: {lib_name}") + + llvm_disassembler = LLVMDisassembler(llvm_dis_path) + llvm_disassembler.disasm(lib_path) + + extern_lib.parse_symbols(llvm_disassembler.ll_file) + extern_lib.generate_stub_file(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--llvm-dis", dest="llvm_dis_path", help="Path to llvm-dis", default="llvm-dis") + parser.add_argument("--lib-path", dest="lib_path", help="Path to the extern library") + parser.add_argument("--lib-name", dest="lib_name", help="Name of the extern library") + parser.add_argument("--output", dest="output_dir", help="Output file path", default="/tmp/") + args = parser.parse_args() + + build(args.llvm_dis_path, args.lib_path, args.lib_name, args.output_dir) diff --git a/pkgs/triton/tools/disasm.py b/pkgs/triton/tools/disasm.py new file mode 100644 index 0000000..24a0787 --- /dev/null +++ b/pkgs/triton/tools/disasm.py @@ -0,0 +1,122 @@ +# MIT License + +# Copyright (c) 2020 Da Yan @ HKUST + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import re +import subprocess + +FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') +SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') +FNAME_RE = re.compile(r'\s*Function : (\w+)\s*') +BRA_RE = re.compile(r'(.*BRA(?:\.U)? 
)(0x\w+);') + + +def parseCtrl(sline): + enc = int(SLINE_RE.match(sline).group(1), 16) + stall = (enc >> 41) & 0xf + yld = (enc >> 45) & 0x1 + wrtdb = (enc >> 46) & 0x7 + readb = (enc >> 49) & 0x7 + watdb = (enc >> 52) & 0x3f + + yld_str = 'Y' if yld == 0 else '-' + wrtdb_str = '-' if wrtdb == 7 else str(wrtdb) + readb_str = '-' if readb == 7 else str(readb) + watdb_str = '--' if watdb == 0 else f'{watdb:02d}' + return f'{watdb_str}:{readb_str}:{wrtdb_str}:{yld_str}:{stall:x}' + + +def processSassLines(fline, sline, labels): + asm = FLINE_RE.match(fline).group(1) + # Remove trailing space + if asm.endswith(" ;"): + asm = asm[:-2] + ";" + ctrl = parseCtrl(sline) + # BRA target address + if BRA_RE.match(asm) is not None: + target = int(BRA_RE.match(asm).group(2), 16) + if target in labels: + pass + else: + labels[target] = len(labels) + return (f'{ctrl}', f'{asm}') + + +def extract(file_path, fun): + if fun is None: + sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) + else: + sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) + sass_lines = sass_str.splitlines() + line_idx = 0 + while line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + # format: + # function : <function_name> + # .headerflags: ... + # /*0000*/ asmstr /*0x...*/ + # /*0x...*/ + + # Looking for new function header (function: <name>) + while FNAME_RE.match(line) is None: + line_idx += 1 + if line_idx < len(sass_lines): + line = sass_lines[line_idx].decode() + else: + return + + fname = FNAME_RE.match(line).group(1) + ret = '' + ret += f'Function:{fname}\n' + line_idx += 2 # bypass .headerflags + line = sass_lines[line_idx].decode() + # Remapping address to label + labels = {} # address -> label_idx + # store sass asm in buffer and then print them (for labels) + # (ctrl, asm) + asm_buffer = [] + while FLINE_RE.match(line) is not None: + # First line (Offset ASM Encoding) + fline = sass_lines[line_idx].decode() + line_idx += 1 + # Second line (Encoding) + sline = sass_lines[line_idx].decode() + line_idx += 1 + asm_buffer.append(processSassLines(fline, sline, labels)) + # peek the next line + line = sass_lines[line_idx].decode() + # Print sass + # label naming convention: LBB#i + for idx, (ctrl, asm) in enumerate(asm_buffer): + # Print label if this is BRA target + offset = idx * 16 + if offset in labels: + label_name = f'LBB{labels[offset]}' + ret += f'{label_name}:\n' + ret += ctrl + '\t' + # if this is BRA, remap offset to label + if BRA_RE.match(asm): + target = int(BRA_RE.match(asm).group(2), 16) + target_name = f'LBB{labels[target]}' + asm = BRA_RE.sub(rf'\1{target_name};', asm) + ret += asm + '\n' + ret += '\n' + return ret diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/INSTALLER b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/INSTALLER new file mode 100644 index 0000000..a1b589e --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/LICENSE b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/LICENSE new file mode 100644 index 0000000..4555634 --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/LICENSE @@ -0,0 +1,35 @@ +From xFormers: + +Copyright (c) Facebook, Inc. and its affiliates + + +=== + +BSD 3-Clause License + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. 
Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + and IDIAP Research Institute nor the names of its contributors may be + used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/METADATA b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/METADATA new file mode 100644 index 0000000..5dd758b --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/METADATA @@ -0,0 +1,21 @@ +Metadata-Version: 2.1 +Name: xformers +Version: 0.0.22+corex.4.1.2 +Summary: XFormers: A collection of composable Transformer building blocks. +Home-page: https://facebookresearch.github.io/xformers/ +Author: Facebook AI Research +Author-email: oncall+xformers@xmail.facebook.com +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: License :: OSI Approved :: BSD License +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Classifier: Operating System :: OS Independent +Requires-Python: >=3.7 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: torch>=1.12 +Requires-Dist: numpy + +XFormers: A collection of composable Transformer building blocks.XFormers aims at being able to reproduce most architectures in the Transformer-family SOTA,defined as compatible and combined building blocks as opposed to monolithic models diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/RECORD b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/RECORD new file mode 100644 index 0000000..b8bb516 --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/RECORD @@ -0,0 +1,363 @@ +xformers-0.0.22+corex.4.1.2.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +xformers-0.0.22+corex.4.1.2.dist-info/LICENSE,sha256=iwaVhtHd3wMc1GkkT2Y_GzmnThI1eC5iVcJJQ-ACt-o,1610 +xformers-0.0.22+corex.4.1.2.dist-info/METADATA,sha256=fTmXl3kZ0VQI_7QR-U9dx0ES647Q0KZLVfJanihmlwk,1017 +xformers-0.0.22+corex.4.1.2.dist-info/RECORD,, +xformers-0.0.22+corex.4.1.2.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers-0.0.22+corex.4.1.2.dist-info/WHEEL,sha256=jkyzK-PMDGe6NpviSR1tpgon24csQEXDXqxhDgOHeYM,105 
+xformers-0.0.22+corex.4.1.2.dist-info/direct_url.json,sha256=svOh1m0vzTEhcbA3XTXraZREXNRvU_UWMLUqCS4R57Y,287 +xformers-0.0.22+corex.4.1.2.dist-info/top_level.txt,sha256=4Px1VcGhKk0j3XhKXjA8HtTm6EQOb0hazeJ5nQsNlKk,9 +xformers/_C.so,sha256=LYl7l0RtMmGbnj8bT0Ator75GYQNUqZzHhv1ptVvrc4,4601048 +xformers/_C_flashattention.so,sha256=wZ9qztIaJz4LZaeDLygaWGtRD_2lqp6v2na_udk8rNk,15603472 +xformers/__init__.py,sha256=zy8xg8uRhpu-EnSJhetKTsdDO2w-KlErxq2EB6TT-ec,1402 +xformers/__pycache__/__init__.cpython-310.pyc,, +xformers/__pycache__/_cpp_lib.cpython-310.pyc,, +xformers/__pycache__/_deprecation_warning.cpython-310.pyc,, +xformers/__pycache__/checkpoint.cpython-310.pyc,, +xformers/__pycache__/info.cpython-310.pyc,, +xformers/__pycache__/test.cpython-310.pyc,, +xformers/__pycache__/utils.cpython-310.pyc,, +xformers/__pycache__/version.cpython-310.pyc,, +xformers/_cpp_lib.py,sha256=OlWl4d6dAzFdacL9Yy3-8JCX7jGc6iEs-2IU3-Z8-cc,4592 +xformers/_deprecation_warning.py,sha256=a2-JfNsT7EGLG5ZFarz7GTdnMe6EA3cli4BNfr3Wtog,456 +xformers/_flash_attn/__init__.py,sha256=zS-gWdX4IxFM25nuBGkPOUh9g3APKuG3t443yPojQt0,442 +xformers/_flash_attn/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/bert_padding.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-310.pyc,, +xformers/_flash_attn/__pycache__/fused_softmax.cpython-310.pyc,, +xformers/_flash_attn/bert_padding.py,sha256=BQOzn0hGBKdxAaJF8DGE1Z-RTiL_vfvtwYZodkUn9Po,5898 +xformers/_flash_attn/flash_attn_interface.py,sha256=wyrdPFEcSpq-f8vcQKTCJfX9xCubNaU1SgwVTaI2pyY,57278 +xformers/_flash_attn/flash_attn_triton.py,sha256=8dGUEA_MPBYwJSGCVVR1UWsIxEreiGtuXKf1sfuonR0,38148 +xformers/_flash_attn/flash_attn_triton_og.py,sha256=YK6eAKdIjUu8qrs17KVtXEBGpkUBCS7W8jqZXIRmO24,10593 +xformers/_flash_attn/flash_blocksparse_attention.py,sha256=rpGkY5gyc91-j_WlX32cySJsOA7IIRno6yo1LPYTtwU,6819 +xformers/_flash_attn/flash_blocksparse_attn_interface.py,sha256=BNmjVPk_DegsNRF0FjRam5K4aMdqtJpKtNrp9oTZXQI,7036 +xformers/_flash_attn/fused_softmax.py,sha256=qn8cNitRkNp0sybbMIG6Vr2qNthVhgiwYH6KamVvx9o,7902 +xformers/_flash_attn/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers/_flash_attn/layers/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-310.pyc,, +xformers/_flash_attn/layers/__pycache__/rotary.cpython-310.pyc,, +xformers/_flash_attn/layers/patch_embed.py,sha256=To1s8Y1v-PodYQTeDe6V9bSS6-YmMpLv21Qc2e41MS0,2039 +xformers/_flash_attn/layers/rotary.py,sha256=gAfh-ADlroxyR86Mbe6o8JyPLh8M9KDy6j6SvfJSuzY,15657 +xformers/_flash_attn/losses/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers/_flash_attn/losses/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-310.pyc,, +xformers/_flash_attn/losses/cross_entropy.py,sha256=_wBa5izjAGVilYzhA66FIKyC9L9vkI_rYXNWj5DrCW4,6697 +xformers/_flash_attn/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers/_flash_attn/models/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/bert.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/falcon.cpython-310.pyc,, 
+xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/gptj.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/llama.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc,, +xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc,, +xformers/_flash_attn/models/bert.py,sha256=u3_DEPvdANnr5BuCMlYA12-PU7dVsiz9GCZs85ofS0g,26979 +xformers/_flash_attn/models/falcon.py,sha256=PYGFSEvN3Pev4wq3dFX2bsIDyD4OLisi4_YSYcDnQoY,5940 +xformers/_flash_attn/models/gpt.py,sha256=xjDh9V_QUU-uF087EmultEkjKwxt8KTpvJCLXugR0EI,40471 +xformers/_flash_attn/models/gpt_neox.py,sha256=krpBeWnG2jwj9tEBwsqONILbcRV16K2KfZlhvFbNnFw,5025 +xformers/_flash_attn/models/gptj.py,sha256=8nk3LUIwWXaj9IqcEznhxZIrHvbO4L78wvx-H5F2dbg,4365 +xformers/_flash_attn/models/llama.py,sha256=f_ZyXFlGvBRwnY8dNUjnJ25XjKnUlO2pfs6b9zqSYN0,5761 +xformers/_flash_attn/models/opt.py,sha256=fS2Ij55wl3YHLJMJs9KS_7IgAd_7uLLrRpd2nnxlfHw,5130 +xformers/_flash_attn/models/vit.py,sha256=HdbPeiPxMm-GKdRXbzsbIw_kY4-KcaAq7leIjx8LRgs,13621 +xformers/_flash_attn/modules/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers/_flash_attn/modules/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/modules/__pycache__/block.cpython-310.pyc,, +xformers/_flash_attn/modules/__pycache__/embedding.cpython-310.pyc,, +xformers/_flash_attn/modules/__pycache__/mha.cpython-310.pyc,, +xformers/_flash_attn/modules/__pycache__/mlp.cpython-310.pyc,, +xformers/_flash_attn/modules/block.py,sha256=2f-FzDwlHwOIO8weguKUVPOZWuoR28HGF9KgdfJK6i0,16615 +xformers/_flash_attn/modules/embedding.py,sha256=0jKUwCGmTj5QTrxDW4WtkEMJOWMR38bIhM2m8d095k4,8620 +xformers/_flash_attn/modules/mha.py,sha256=H8dZNfMOwEyEUvh7RpRQcOouFYD9jbESPMnojum36sw,37920 +xformers/_flash_attn/modules/mlp.py,sha256=qFk2AUaqfFhoSR6KDxtw9THsvbYvHCCvbJj-sz_T4c0,3596 +xformers/_flash_attn/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers/_flash_attn/ops/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/ops/__pycache__/activations.cpython-310.pyc,, +xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-310.pyc,, +xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-310.pyc,, +xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-310.pyc,, +xformers/_flash_attn/ops/activations.py,sha256=nJAeKTW8GMqruDbLViQtpAUyGGuTgjao1Hz-j1D60dc,3002 +xformers/_flash_attn/ops/fused_dense.py,sha256=4IGbm3acZEF7Ale4xoTp9gGizRZKaIhEV0sF1X-xuYU,26087 +xformers/_flash_attn/ops/layer_norm.py,sha256=J_EIoi1wcTSUsFC2lhD7y4Tf9yd4YlgbvChh8HXLFGg,19306 +xformers/_flash_attn/ops/rms_norm.py,sha256=6w7dJ96WLRHyGwRWdACEg3rhJKZ18o_XitpSspDAepA,3672 +xformers/_flash_attn/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +xformers/_flash_attn/utils/__pycache__/__init__.cpython-310.pyc,, +xformers/_flash_attn/utils/__pycache__/benchmark.cpython-310.pyc,, +xformers/_flash_attn/utils/__pycache__/distributed.cpython-310.pyc,, +xformers/_flash_attn/utils/__pycache__/generation.cpython-310.pyc,, +xformers/_flash_attn/utils/__pycache__/pretrained.cpython-310.pyc,, +xformers/_flash_attn/utils/benchmark.py,sha256=52C9sDyjBsdw8iUWP0wqVFsuWpB7DZqkGECtibQaezA,5909 +xformers/_flash_attn/utils/distributed.py,sha256=d3Uc_h4gQO7I-LimfexsXfIyS8i_jRIPMkC1nyI1SQ8,5545 +xformers/_flash_attn/utils/generation.py,sha256=Y4ezDnl63zHqcwB6OXUxKbTJ3CuzZJpgNEsp2Z1SlLM,14105 
+xformers/_flash_attn/utils/pretrained.py,sha256=aPWK8RLD3Wh4JCHgEZlYNFhYP69HtS3Og67xO2BS_ZA,1824 +xformers/benchmarks/LRA/__init__.py,sha256=BBoPvMfJjZ0Fi25JCRS8sPYUfLbHneLncV3JeGDIGHg,198 +xformers/benchmarks/LRA/__pycache__/__init__.cpython-310.pyc,, +xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-310.pyc,, +xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-310.pyc,, +xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-310.pyc,, +xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-310.pyc,, +xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-310.pyc,, +xformers/benchmarks/LRA/batch_fetch_results.py,sha256=QsTgzFU45VP636TPiewnVCKD57bOR_YkjZq2KjMLzqc,3508 +xformers/benchmarks/LRA/batch_submit.py,sha256=7BgzEgSCJNN4vaQ8m-hEY2x5DxSkaPzp_Jorr2fS5p4,1685 +xformers/benchmarks/LRA/code/__init__.py,sha256=BBoPvMfJjZ0Fi25JCRS8sPYUfLbHneLncV3JeGDIGHg,198 +xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-310.pyc,, +xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-310.pyc,, +xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-310.pyc,, +xformers/benchmarks/LRA/code/dataset.py,sha256=6iMOhcM8rY75Ub_t3fopBxGjo1zCWwSIRAC3eHqAsIs,1403 +xformers/benchmarks/LRA/code/model_wrapper.py,sha256=18YKbVAASeoaGZqhM-Ps3scvblLbVoImHWKJqY7CQjw,9777 +xformers/benchmarks/LRA/run_grid_search.py,sha256=yLhNnQv4HDbEdQVApCdJiZzFes9VeGZjPOafHAaZ5Ew,5327 +xformers/benchmarks/LRA/run_tasks.py,sha256=EK8MfMPa5hMRTJIW3EwMUtZj_dfrTkxlGa0tnIMLyx4,9230 +xformers/benchmarks/LRA/run_with_submitit.py,sha256=WUtQJLYWnVVcp6mbmBQuqiA3pdvayEpZzqEpW5FZRz0,4612 +xformers/benchmarks/__init__.py,sha256=BBoPvMfJjZ0Fi25JCRS8sPYUfLbHneLncV3JeGDIGHg,198 +xformers/benchmarks/__pycache__/__init__.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_blocksparse_transformers.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_causal_blocksparse.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_core.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_indexing.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_mem_eff_attn_decoder.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_mlp.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_revnet.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_transformer.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_triton_blocksparse.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_triton_dropout.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_triton_fused_linear.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_triton_layernorm.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_triton_softmax.cpython-310.pyc,, +xformers/benchmarks/__pycache__/benchmark_triton_stride_sum.cpython-310.pyc,, +xformers/benchmarks/__pycache__/utils.cpython-310.pyc,, +xformers/benchmarks/benchmark_attn_decoding.py,sha256=0VhKGyWE7mmDhuwi_Q0LZoBOvys2GIUhrkRW6B2qrd8,4489 +xformers/benchmarks/benchmark_blocksparse_transformers.py,sha256=EW_Nvjn2v6ILwMRpfa4qt8LqfgMwwlri4ZP0DtITGjY,37484 
+xformers/benchmarks/benchmark_causal_blocksparse.py,sha256=TjTib239LH9RFkseuV_Ks0lUtAiSi8OlMgI27bXPLcw,4885 +xformers/benchmarks/benchmark_core.py,sha256=nH9mlneAh1bk2rXtqvPttiUL2BwnSmw7SAh4VKRJaP4,8527 +xformers/benchmarks/benchmark_indexing.py,sha256=y5xzpoZrgtujWZ9HMmLCNv9MwKhC-LYFS6Tw3Ns5dbE,6921 +xformers/benchmarks/benchmark_mem_eff_attention.py,sha256=XWfw19UyMGscAuFLDru9NS1uUgrlBLH7eAWII0Pkn4w,8966 +xformers/benchmarks/benchmark_mem_eff_attn_decoder.py,sha256=3gfOP1xpEDryMXPb5kPjZrIaSikExMlrEGQmPSkmhNw,5102 +xformers/benchmarks/benchmark_mlp.py,sha256=-3cvmgWAqPGEMpU6hUzffxmg7uDBxeJs6VRENHNw7tI,4199 +xformers/benchmarks/benchmark_multi_head_dispatch.py,sha256=1ZAnVLx9rIRzRxgOx50tSdGZSHveH-n4aTU7wjJnlUc,3365 +xformers/benchmarks/benchmark_nystrom_utils.py,sha256=vJnyXhAmZOd4N_H4g8ZRxRsQgEwjTQX-jRGoBTdTuqk,3174 +xformers/benchmarks/benchmark_revnet.py,sha256=qG7SxXpoJ-bi5cajWukyft7qVprmVfqqtx9PRQqIdI8,2671 +xformers/benchmarks/benchmark_sddmm.py,sha256=zX8Ytv2EzTezjGUlb74T24QiGIxfMtD3A8NQL3NJDeo,3658 +xformers/benchmarks/benchmark_swiglu.py,sha256=J_bopnCezQck4ohtWXBErSTQNStkbuSiMmRqCHPrYXo,4181 +xformers/benchmarks/benchmark_transformer.py,sha256=qcM64c4xHA97fcUoz_s4itWhyAjlxPAcv0Q6OtzX6ew,4385 +xformers/benchmarks/benchmark_triton_blocksparse.py,sha256=TaDxcc9GTRnOnJNe4EkItvNwWG9amgK1AxPQD9i2Zbg,5099 +xformers/benchmarks/benchmark_triton_dropout.py,sha256=BiuyT-UJ3h-i35g2uCcjzqxuOBFj3XgXeqpnQ8biJLs,3477 +xformers/benchmarks/benchmark_triton_fused_linear.py,sha256=Bd7NDc4APpEit0_WLxtjw9_I3qeT-evotEcFMCnX6y8,5639 +xformers/benchmarks/benchmark_triton_layernorm.py,sha256=8ZxQ6FpIyddRCBak6q9QznLjLVYmpUGS511q9YfPa9w,2683 +xformers/benchmarks/benchmark_triton_softmax.py,sha256=OpgYre8x-BbbxLiuk6RGi3EUo2pDrO7EjLM61THa0RE,2182 +xformers/benchmarks/benchmark_triton_stride_sum.py,sha256=GQbaltBoxR-OW4tJ-l-nosO83buVTxQkMjPomXGUmw4,1823 +xformers/benchmarks/utils.py,sha256=Df9__YGAsCljtpYe-93_4D745UMHe-PZmi09gz70UBw,21722 +xformers/checkpoint.py,sha256=Inroxy2JZwC-KKDoG2ZDGJ6hlHEBKwqJ1GY9c6t3AOU,9403 +xformers/components/__init__.py,sha256=TsGNaORppkTOq5fhl7ndjY1KseTQ4v3tQ6cIxRzdndk,3173 +xformers/components/__pycache__/__init__.cpython-310.pyc,, +xformers/components/__pycache__/activations.cpython-310.pyc,, +xformers/components/__pycache__/input_projection.cpython-310.pyc,, +xformers/components/__pycache__/multi_head_dispatch.cpython-310.pyc,, +xformers/components/__pycache__/patch_embedding.cpython-310.pyc,, +xformers/components/__pycache__/residual.cpython-310.pyc,, +xformers/components/__pycache__/reversible.cpython-310.pyc,, +xformers/components/__pycache__/simplicial_embedding.cpython-310.pyc,, +xformers/components/activations.py,sha256=GcWgrvnOLsFjqK_Vw1Yx6UAMSCS_-Ntn-rGnwfaQjjE,1837 +xformers/components/attention/__init__.py,sha256=YsokN3HCoH15zR2vanKpgVrOhSUvdthwiP4l9woUGlw,3983 +xformers/components/attention/__pycache__/__init__.cpython-310.pyc,, +xformers/components/attention/__pycache__/_sputnik_sparse.cpython-310.pyc,, +xformers/components/attention/__pycache__/attention_mask.cpython-310.pyc,, +xformers/components/attention/__pycache__/attention_patterns.cpython-310.pyc,, +xformers/components/attention/__pycache__/base.cpython-310.pyc,, +xformers/components/attention/__pycache__/blocksparse.cpython-310.pyc,, +xformers/components/attention/__pycache__/compositional.cpython-310.pyc,, +xformers/components/attention/__pycache__/core.cpython-310.pyc,, +xformers/components/attention/__pycache__/favor.cpython-310.pyc,, 
+xformers/components/attention/__pycache__/fourier_mix.cpython-310.pyc,, +xformers/components/attention/__pycache__/global_tokens.cpython-310.pyc,, +xformers/components/attention/__pycache__/lambda_layer.cpython-310.pyc,, +xformers/components/attention/__pycache__/linformer.cpython-310.pyc,, +xformers/components/attention/__pycache__/local.cpython-310.pyc,, +xformers/components/attention/__pycache__/nystrom.cpython-310.pyc,, +xformers/components/attention/__pycache__/ortho.cpython-310.pyc,, +xformers/components/attention/__pycache__/pooling.cpython-310.pyc,, +xformers/components/attention/__pycache__/random.cpython-310.pyc,, +xformers/components/attention/__pycache__/scaled_dot_product.cpython-310.pyc,, +xformers/components/attention/__pycache__/sparsity_config.cpython-310.pyc,, +xformers/components/attention/__pycache__/utils.cpython-310.pyc,, +xformers/components/attention/__pycache__/visual.cpython-310.pyc,, +xformers/components/attention/_sputnik_sparse.py,sha256=bibk2S-Coatt45gO9fzLTaAgUMKXbRi5U0eR_-QAbnw,3164 +xformers/components/attention/attention_mask.py,sha256=Y0zyM5JDnFzfE0TGPx4BxEeQlqIPX-XnZpO_lskQn4w,4585 +xformers/components/attention/attention_patterns.py,sha256=mZ-s8R-DS6XrOP_-ExZZku7XWVKUOEVfUdHpB9D5kek,9945 +xformers/components/attention/base.py,sha256=bO9JlE9Kyl7UkZEAhFy7vdIvVwmq5ObAUb7sOkpQ2o8,3209 +xformers/components/attention/blocksparse.py,sha256=R61zSk1x1hejwEXBDRraoByDwVo2Sib6oUSuvLvO_E0,6725 +xformers/components/attention/compositional.py,sha256=H1prPE_o3Ymd_HbxBnjNWnz7vzYzA_9L5zgixpyWAlM,12954 +xformers/components/attention/core.py,sha256=QylojyuoHVyMRMCRAY-Twj_znR8dN64JTmWwDKNxKek,10782 +xformers/components/attention/favor.py,sha256=2IHlocaKOtgD1IIN1fGAcJsFkHqB02l1LQQBH1et6WE,6191 +xformers/components/attention/feature_maps/__init__.py,sha256=YT-iC15A-WTj9uy6etXp6tWNBDo3tolDjGDc-DYmzvA,592 +xformers/components/attention/feature_maps/__pycache__/__init__.cpython-310.pyc,, +xformers/components/attention/feature_maps/__pycache__/base.cpython-310.pyc,, +xformers/components/attention/feature_maps/__pycache__/softmax.cpython-310.pyc,, +xformers/components/attention/feature_maps/base.py,sha256=u4jzwUZludL_mFLYL6SzGFUFJGKkbHWYKE9iO5BzFzA,1672 +xformers/components/attention/feature_maps/softmax.py,sha256=tpI2yz8wk1fC1mv9zCHQpvc54ozSQTGJVBA37pvufqE,10606 +xformers/components/attention/fourier_mix.py,sha256=2HIZ0P7lWEPOgt4vGit9Ouc250wVOQES9kV_Y09rfoU,1181 +xformers/components/attention/global_tokens.py,sha256=JkN_jrNesrgjd6G_JzFdokGZLEYKDTnYHHc5IOmzYAw,4091 +xformers/components/attention/lambda_layer.py,sha256=a15tl5a9VMi1NaKCrEps2-Tb46RaA0R_M6L8zAA0YiI,2814 +xformers/components/attention/linformer.py,sha256=U-lqxbOdbx3CXnX7swpUvYaZXKJUNrGxJT4_jiyNH1c,2491 +xformers/components/attention/local.py,sha256=MDUq0UDzP4NpUxJZ5gOtPeLT0bRcE6vsA3iMnylqg7g,3815 +xformers/components/attention/nystrom.py,sha256=Za5veGImsv9j-rnA0fAHd61TGyNnfdY6I1O4AKWgje0,11405 +xformers/components/attention/ortho.py,sha256=2FWSaeqZ5m3QdyL_-BYZtNNXSL4xQVo9LXLN0si7bOQ,12124 +xformers/components/attention/pooling.py,sha256=-KGQxK9gSPxILByF_OsEaQfJqLr9GwbwjcgbYXDmOwM,2490 +xformers/components/attention/random.py,sha256=0bcvaYreWLNeW3bA3fvuJXIZINxXW0hqw4hsFauQ--M,4131 +xformers/components/attention/scaled_dot_product.py,sha256=vUcRVzQT3ol6hIqrTcCudlaCru1W1ZcZgjaTIkl7Lr4,4504 +xformers/components/attention/sparsity_config.py,sha256=lZLGeKVcqnKVGKqkjE5Av2yGIHB5zZvU9IwrR3UrvZo,41608 +xformers/components/attention/utils.py,sha256=04aEVvDG9LjzD9NfsQ3Kj-kZSnMFCppCFrJx6NPoV8g,3803 
+xformers/components/attention/visual.py,sha256=EiMRHtf10q7EnwQF_gKdXCp-BsMHj0yHeDCC6SKFx2Q,2929 +xformers/components/feedforward/__init__.py,sha256=F0BzT32-YfREtcAes9gigsnsqB_-g2W04agpAZGrX-k,2556 +xformers/components/feedforward/__pycache__/__init__.cpython-310.pyc,, +xformers/components/feedforward/__pycache__/base.cpython-310.pyc,, +xformers/components/feedforward/__pycache__/conv_mlp.cpython-310.pyc,, +xformers/components/feedforward/__pycache__/fused_mlp.cpython-310.pyc,, +xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-310.pyc,, +xformers/components/feedforward/__pycache__/mlp.cpython-310.pyc,, +xformers/components/feedforward/base.py,sha256=eErhev0_tkEATf_6SUjbJa9rK3PMfTEQSa2VSkeIAtM,1498 +xformers/components/feedforward/conv_mlp.py,sha256=iZwK8xy8FN9KhqXde8WcbNitoEDA4ertlYo6E1ua6cE,3011 +xformers/components/feedforward/fused_mlp.py,sha256=JRJOFKC-cDmCmZ0dJz_uLd4yIEHz_C4yMMVP6VPmIjs,2579 +xformers/components/feedforward/mixture_of_experts.py,sha256=Zci7y55TvQRGpmARVzsj7db8Y-xs1NVPioXP3nEkhr0,5343 +xformers/components/feedforward/mlp.py,sha256=u5z2uYfgyddrz4o-xdylqyNo8onRg34CGacRGtkCsAQ,1308 +xformers/components/input_projection.py,sha256=p1iBJYRxvn95deoBBDafcEO9XvfyouahZHadoTGfn9U,3024 +xformers/components/multi_head_dispatch.py,sha256=Wc1sB6cO9LM0N2nHhq2C4rYYT3jYPhweK_DmyiWB5Zk,10169 +xformers/components/patch_embedding.py,sha256=TFRCB6L41cUJAe6egBYGeT5l138E-d1prVEYIeDNg9A,2269 +xformers/components/positional_embedding/__init__.py,sha256=sN-9Mk3SeY8snW70vqSedHYqmAwzi5hZjcOfHFNPJos,2755 +xformers/components/positional_embedding/__pycache__/__init__.cpython-310.pyc,, +xformers/components/positional_embedding/__pycache__/base.cpython-310.pyc,, +xformers/components/positional_embedding/__pycache__/param.cpython-310.pyc,, +xformers/components/positional_embedding/__pycache__/rotary.cpython-310.pyc,, +xformers/components/positional_embedding/__pycache__/sine.cpython-310.pyc,, +xformers/components/positional_embedding/__pycache__/vocab.cpython-310.pyc,, +xformers/components/positional_embedding/base.py,sha256=EUNGSNKYqRnr0za8TsfkSXTHW1Jkq_EZPf71ERhD4RA,972 +xformers/components/positional_embedding/param.py,sha256=mvvt_ZKrFL9ZogeuXO8RxFlKGIiaLBF1YE-XcwiZRfg,1544 +xformers/components/positional_embedding/rotary.py,sha256=_mYWMVhwNUdl62aMfrTOfY2W8P8fjqjXLcbrwB-aLEo,3274 +xformers/components/positional_embedding/sine.py,sha256=kOeakSCWG4KcI05C5Qouh9fpBHIunEI_-8EGDD-U_g8,1365 +xformers/components/positional_embedding/vocab.py,sha256=T-kvJyuU5HpB98_zuLUm2NLfhYsjW8zNnTCUqXhKi_4,1768 +xformers/components/residual.py,sha256=vte37ZjxEsqRZP55jYASM2aAW3EJUzjfRBnqPHhcJ7g,6125 +xformers/components/reversible.py,sha256=FnqbmAKiYbywxe1gyKvALUEvnhIggvVPWQp2hm0txIA,5260 +xformers/components/simplicial_embedding.py,sha256=G1RL4pNvZfCJ0iL3boao_V-KycCJUEjSClgoRvAZdLA,2323 +xformers/cpp_lib.json,sha256=Dy1tn61dm-CVYtEW5gKgzzjnDe3iwxBux_F7mgcqwKA,277 +xformers/factory/__init__.py,sha256=4DY_EmTQiTJLO1Ttk8H7cVNx92e6ZuFm8kW9E5Xod0A,617 +xformers/factory/__pycache__/__init__.cpython-310.pyc,, +xformers/factory/__pycache__/block_configs.cpython-310.pyc,, +xformers/factory/__pycache__/block_factory.cpython-310.pyc,, +xformers/factory/__pycache__/hydra_helper.cpython-310.pyc,, +xformers/factory/__pycache__/model_factory.cpython-310.pyc,, +xformers/factory/__pycache__/weight_init.cpython-310.pyc,, +xformers/factory/block_configs.py,sha256=Cy_lI_lDp8yfqnRxLrNpszGU1UzUljuw-1bIdNEgEpY,8175 +xformers/factory/block_factory.py,sha256=i6VSVQoku9qShZmdosawlLFyKCrrYEmP81ujGP9Fflc,12960 
+xformers/factory/hydra_helper.py,sha256=5bP3myNgAai_J_Se69wVSkE1C5HGE2HZU_TW2xstgCA,1312 +xformers/factory/model_factory.py,sha256=1S4l9G-LHJXMokNOisd7u654xYQ8SrlvEKnz9Ye-fXw,11941 +xformers/factory/weight_init.py,sha256=DUZBhDIbkbtdk4GYzqp-Xu-yXjvWFueD3fOx4OhvqTE,9988 +xformers/helpers/__init__.py,sha256=3zcfxEo1NykSJOddz1xp2I_ftiTsyaT0xPe-zpRZP3c,263 +xformers/helpers/__pycache__/__init__.cpython-310.pyc,, +xformers/helpers/__pycache__/hierarchical_configs.cpython-310.pyc,, +xformers/helpers/__pycache__/test_utils.cpython-310.pyc,, +xformers/helpers/__pycache__/timm_sparse_attention.cpython-310.pyc,, +xformers/helpers/hierarchical_configs.py,sha256=sWNlYaxOI905dszrRlD8NNoj-gvXgtQWpygMyeZKsYM,3924 +xformers/helpers/test_utils.py,sha256=OMD0494dmlWHmU-xH-l7ssRimZkufofxrbeZj0qsly0,800 +xformers/helpers/timm_sparse_attention.py,sha256=E22N74lKwmAkwRX1SI4SZEpvxgBN1ZeJIYoVLvHnmro,1528 +xformers/info.py,sha256=nyZZsnPPt3NYbnlTUqX_Ua1_McvGePDz_5RBGqvBXSo,2442 +xformers/ops/__init__.py,sha256=ekHzhR4eKZCR39RbdSRXEnM4IBJh8a9g0jRXIReEkvI,2573 +xformers/ops/__pycache__/__init__.cpython-310.pyc,, +xformers/ops/__pycache__/common.cpython-310.pyc,, +xformers/ops/__pycache__/indexing.cpython-310.pyc,, +xformers/ops/__pycache__/rmsnorm.cpython-310.pyc,, +xformers/ops/__pycache__/rope_padded.cpython-310.pyc,, +xformers/ops/__pycache__/swiglu_op.cpython-310.pyc,, +xformers/ops/__pycache__/unbind.cpython-310.pyc,, +xformers/ops/common.py,sha256=z7MUx2jnGiAfX1wlZ0LMJOeVZ-niN5RvUbJyQUo77jw,3752 +xformers/ops/fmha/__init__.py,sha256=9JD1-A9zTv9ml3c6pKXmxam6QwNoug77XJBloA-ZpwE,16319 +xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/common.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/flash.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/triton.cpython-310.pyc,, +xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc,, +xformers/ops/fmha/attn_bias.py,sha256=eNnm33ejYy8F_f2nGmWaKBMD58vZY-iQoEUp2Fx1naE,28159 +xformers/ops/fmha/common.py,sha256=ltWJ5qhZa49EZt46RvBNrIioQMsobJ0BRaPiRfhlSLQ,19635 +xformers/ops/fmha/cutlass.py,sha256=xVbql2Eh2mv0jfS1Zgj8-65HvuQDCHcjX4rmf4eZZlQ,16798 +xformers/ops/fmha/decoder.py,sha256=onOwhfXCm3yfsgWf-hALoOq7yjZ80225NMMLPHxOwMA,3821 +xformers/ops/fmha/dispatch.py,sha256=fZoTGN3h4nLXKNDovmRk3j_5zD16o6jyvF3xWIvt_4E,5148 +xformers/ops/fmha/flash.py,sha256=yAShHyq1RbEzhdBkTkMm458xuKLEQ185VLuZYAU0VgM,21649 +xformers/ops/fmha/small_k.py,sha256=8sLm0muZTfo1b660FgPWKh6P9XhgoxxOYKnvrMq1UgE,6337 +xformers/ops/fmha/triton.py,sha256=bzeUSBUPMTW2_sc3tTtco_sJI6PA6dk__-gW92T4FVM,7552 +xformers/ops/fmha/triton_splitk.py,sha256=qGEviCXlOg5FudIiLoN-ahw-qUkZpijX8IrFti2bcHU,26789 +xformers/ops/indexing.py,sha256=inqijffB3P3OmoU33vNlX1An3Vn7jl_lOkZBoefYc6Q,6654 +xformers/ops/rmsnorm.py,sha256=4ZtculAz9gzbWNBaa7R4rx6yLCOhCod6RMBi2NF4lYY,3356 +xformers/ops/rope_padded.py,sha256=MpNviOcChMH5gQcbMWI7JBi4BMLwRsy-6AAZaP1E0m0,6728 +xformers/ops/swiglu_op.py,sha256=ztKyC6736x_nUFTeBgEwnT0YlRwynAYbjwiF2uIe0L4,15168 +xformers/ops/triton/__init__.py,sha256=BBoPvMfJjZ0Fi25JCRS8sPYUfLbHneLncV3JeGDIGHg,198 +xformers/ops/triton/__pycache__/__init__.cpython-310.pyc,, +xformers/ops/triton/__pycache__/rmsnorm_kernels.cpython-310.pyc,, 
+xformers/ops/triton/__pycache__/rope_padded_kernels.cpython-310.pyc,, +xformers/ops/triton/rmsnorm_kernels.py,sha256=KSrZSW5HtCOhtX2fkHg5gqV_fxYl5aGgFRGbaoitjos,4929 +xformers/ops/triton/rope_padded_kernels.py,sha256=nMjjE1d3KD_g9BcMea8kRqrZjQmNGYnsMEiVqCJDl30,5380 +xformers/ops/unbind.py,sha256=UU4H3aQgCDI5LBzAQxeGa23RnXWvntdFG2fi7oEFHXo,3556 +xformers/profiler/__init__.py,sha256=hati8pa77wUdoL0ouyGOxK1bdQQo-mXKp656ftZcmgo,503 +xformers/profiler/__pycache__/__init__.cpython-310.pyc,, +xformers/profiler/__pycache__/api.cpython-310.pyc,, +xformers/profiler/__pycache__/device_limits.cpython-310.pyc,, +xformers/profiler/__pycache__/profiler.cpython-310.pyc,, +xformers/profiler/__pycache__/slow_ops_profiler.cpython-310.pyc,, +xformers/profiler/api.py,sha256=1qNHFJfvbLOQspAa9JcsO5xGMFijQGICi7Xk-87cDwI,2826 +xformers/profiler/device_limits.py,sha256=5eKdURXQVX5nEkwh546bzMuXpSAuB4xfp0krBHLkvO0,3577 +xformers/profiler/profiler.py,sha256=e_VzR69NBr7QFxlsjuw-Yh5Reh5dvcy5MJzlTmFYm-M,12106 +xformers/profiler/slow_ops_profiler.py,sha256=AvkJFy9A_g2zQBk6iSAMCEuqC8CmNvmJ0IVxuP05ErY,16378 +xformers/sparse/__init__.py,sha256=NpDMLawDg0ra3mR7LpNnkua8KOpeY6h2AL4or88rGws,317 +xformers/sparse/__pycache__/__init__.cpython-310.pyc,, +xformers/sparse/__pycache__/_csr_ops.cpython-310.pyc,, +xformers/sparse/__pycache__/blocksparse_tensor.cpython-310.pyc,, +xformers/sparse/__pycache__/csr_tensor.cpython-310.pyc,, +xformers/sparse/__pycache__/utils.cpython-310.pyc,, +xformers/sparse/_csr_ops.py,sha256=p8bGD91Y52LuYieY97ZCtlOpYFNnnPRjXDuBoK5_Bp0,4873 +xformers/sparse/blocksparse_tensor.py,sha256=7VvyFZysozfdsMDyFr5HS-zjqy9zOzc5VlRz9KndOXQ,11557 +xformers/sparse/csr_tensor.py,sha256=jBE1ABhJm8y_zU-LbdwYOrxCxvqB6UuQ1XdgV-uEwZw,14276 +xformers/sparse/utils.py,sha256=f52XNuTj8gsi5YZABkYeMykMWRGmBnHjIkXM4gzXTdM,4274 +xformers/test.py,sha256=BBoPvMfJjZ0Fi25JCRS8sPYUfLbHneLncV3JeGDIGHg,198 +xformers/triton/__init__.py,sha256=ru7vvZUC_6fG5aEUA2wZygmsybzvCYnNqkLaS-AV0uM,803 +xformers/triton/__pycache__/__init__.cpython-310.pyc,, +xformers/triton/__pycache__/dropout.cpython-310.pyc,, +xformers/triton/__pycache__/fused_linear_layer.cpython-310.pyc,, +xformers/triton/__pycache__/k_activations.cpython-310.pyc,, +xformers/triton/__pycache__/k_dropout.cpython-310.pyc,, +xformers/triton/__pycache__/k_fused_matmul_bw.cpython-310.pyc,, +xformers/triton/__pycache__/k_fused_matmul_fw.cpython-310.pyc,, +xformers/triton/__pycache__/k_layer_norm.cpython-310.pyc,, +xformers/triton/__pycache__/k_softmax.cpython-310.pyc,, +xformers/triton/__pycache__/k_sum.cpython-310.pyc,, +xformers/triton/__pycache__/layer_norm.cpython-310.pyc,, +xformers/triton/__pycache__/softmax.cpython-310.pyc,, +xformers/triton/__pycache__/sum_strided.cpython-310.pyc,, +xformers/triton/__pycache__/utils.cpython-310.pyc,, +xformers/triton/__pycache__/vararg_kernel.cpython-310.pyc,, +xformers/triton/dropout.py,sha256=91rbONmCbgzdaW6PfLkbh0inTzIaji77nM0HlbqP5No,7615 +xformers/triton/fused_linear_layer.py,sha256=aJkIfOGnfT7q0vWwjbck5aYdwnBGYrCeh0weVFYhBHY,3931 +xformers/triton/k_activations.py,sha256=ac_e4GJW7_slFRMDi5eCZgXKsKH_MtZW6UXrvEPk1II,3552 +xformers/triton/k_dropout.py,sha256=vByszmQ8qL-MTieTcfQt3NDYoUcOYUl_9S5vE_h39tY,5941 +xformers/triton/k_fused_matmul_bw.py,sha256=6fbJbjocX2Dp5uTkJR49rQNiOkSwjwordKr3BQVBLvE,5157 +xformers/triton/k_fused_matmul_fw.py,sha256=0iCuhFlH7dooChe7eTDZm0ajPMxXg_r7uiPH4rYUwMY,8426 +xformers/triton/k_layer_norm.py,sha256=A-ZZWpR6UaUVkA_yDa4zCpE76WCL4RUTVxyoLdf5wpo,5059 
+xformers/triton/k_softmax.py,sha256=uBzhrE7fjk-zbA5_COrcHWmAfQ3svithZevjFVxKdfQ,5068 +xformers/triton/k_sum.py,sha256=zXh4xBOwTP8oVVSt2NgFYJliK-jzxZEsX0IiAFhhOvw,1570 +xformers/triton/layer_norm.py,sha256=OKe-5jfidsW0HjI3i91QjDZNzLcRYqhN_kgVuN77mNA,7659 +xformers/triton/softmax.py,sha256=9ZTmNozqs36VJak6aIwgIYUa1uT23IoDusBmLzd-tCw,6377 +xformers/triton/sum_strided.py,sha256=-2nP-YKqeTxlPVNL-EUQC8-hdBM7Hcl2RzTK__XGysk,1457 +xformers/triton/utils.py,sha256=a8jKA5rnmT7ji24IjPi-lH2ddBDmxu_kZdUdG_Vi3wU,1187 +xformers/triton/vararg_kernel.py,sha256=rNRVcq1AvgRG0o7Fw9w-CLs7_MQ6bQfGibcjRt2L47I,5870 +xformers/utils.py,sha256=CPX2p-5yQSMnQhpkxqaLqlaiXoFjnraZot1TYSpaLzM,2748 +xformers/version.py,sha256=59lN4YV7CZl9SfIHHZdTuHukT2WgB33mXdw-RwSjIn4,36 diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/REQUESTED b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/REQUESTED new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/WHEEL b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/WHEEL new file mode 100644 index 0000000..591b31a --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.44.0) +Root-Is-Purelib: false +Tag: cp310-cp310-linux_x86_64 + diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/direct_url.json b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/direct_url.json new file mode 100644 index 0000000..5f8a8fa --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/direct_url.json @@ -0,0 +1 @@ +{"archive_info": {"hash": "sha256=12af5aa28cd7e2780815f88f9012d8729614d6e544c131b874f21901989a4f2f", "hashes": {"sha256": "12af5aa28cd7e2780815f88f9012d8729614d6e544c131b874f21901989a4f2f"}}, "url": "file:///usr/local/apps_whl/xformers-0.0.22%2Bcorex.4.1.2-cp310-cp310-linux_x86_64.whl"} \ No newline at end of file diff --git a/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/top_level.txt b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/top_level.txt new file mode 100644 index 0000000..56e3c82 --- /dev/null +++ b/pkgs/xformers-0.0.22+corex.4.1.2.dist-info/top_level.txt @@ -0,0 +1 @@ +xformers diff --git a/pkgs/xformers/_C.so b/pkgs/xformers/_C.so new file mode 100755 index 0000000..3e7046b Binary files /dev/null and b/pkgs/xformers/_C.so differ diff --git a/pkgs/xformers/_C_flashattention.so b/pkgs/xformers/_C_flashattention.so new file mode 100755 index 0000000..8fd18c2 Binary files /dev/null and b/pkgs/xformers/_C_flashattention.so differ diff --git a/pkgs/xformers/__init__.py b/pkgs/xformers/__init__.py new file mode 100644 index 0000000..dee15fd --- /dev/null +++ b/pkgs/xformers/__init__.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import torch + +from . 
import _cpp_lib +from .checkpoint import checkpoint, list_operators # noqa: E402, F401 + +try: + from .version import __version__ # noqa: F401 +except ImportError: + __version__ = "0.0.0" + + +logger = logging.getLogger("xformers") + +_has_cpp_library: bool = _cpp_lib._cpp_library_load_exception is None + +_is_opensource: bool = True + + +def compute_once(func): + value = None + + def func_wrapper(): + nonlocal value + if value is None: + value = func() + return value + + return func_wrapper + + +@compute_once +def _is_triton_available(): + if not torch.cuda.is_available(): + return False + if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1": + return False + try: + from xformers.triton.softmax import softmax as triton_softmax # noqa + + return True + except (ImportError, AttributeError) as e: + # logger.warning( + # f"A matching Triton is not available, some optimizations will not be enabled.\nError caught was: {e}" + # ) + return False + + +@compute_once +def get_python_lib(): + return torch.library.Library("xformers_python", "DEF") + + +# end of file diff --git a/pkgs/xformers/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..ef2a6be Binary files /dev/null and b/pkgs/xformers/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/_cpp_lib.cpython-310.pyc b/pkgs/xformers/__pycache__/_cpp_lib.cpython-310.pyc new file mode 100644 index 0000000..9eeb26a Binary files /dev/null and b/pkgs/xformers/__pycache__/_cpp_lib.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/_deprecation_warning.cpython-310.pyc b/pkgs/xformers/__pycache__/_deprecation_warning.cpython-310.pyc new file mode 100644 index 0000000..7af4077 Binary files /dev/null and b/pkgs/xformers/__pycache__/_deprecation_warning.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/checkpoint.cpython-310.pyc b/pkgs/xformers/__pycache__/checkpoint.cpython-310.pyc new file mode 100644 index 0000000..e0bf8f3 Binary files /dev/null and b/pkgs/xformers/__pycache__/checkpoint.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/info.cpython-310.pyc b/pkgs/xformers/__pycache__/info.cpython-310.pyc new file mode 100644 index 0000000..9244e4a Binary files /dev/null and b/pkgs/xformers/__pycache__/info.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/test.cpython-310.pyc b/pkgs/xformers/__pycache__/test.cpython-310.pyc new file mode 100644 index 0000000..77da7d6 Binary files /dev/null and b/pkgs/xformers/__pycache__/test.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/utils.cpython-310.pyc b/pkgs/xformers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..597659b Binary files /dev/null and b/pkgs/xformers/__pycache__/utils.cpython-310.pyc differ diff --git a/pkgs/xformers/__pycache__/version.cpython-310.pyc b/pkgs/xformers/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000..ef8089b Binary files /dev/null and b/pkgs/xformers/__pycache__/version.cpython-310.pyc differ diff --git a/pkgs/xformers/_cpp_lib.py b/pkgs/xformers/_cpp_lib.py new file mode 100644 index 0000000..4eb6fd9 --- /dev/null +++ b/pkgs/xformers/_cpp_lib.py @@ -0,0 +1,144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import dataclasses +import json +import logging +import os +import platform +from typing import Any, Dict, Optional + +import torch + +logger = logging.getLogger("xformers") + +UNAVAILABLE_FEATURES_MSG = ( + " Memory-efficient attention, SwiGLU, sparse and more won't be available." +) + + +@dataclasses.dataclass +class _BuildInfo: + metadata: Dict[str, Any] + + @property + def cuda_version(self) -> Optional[int]: + return self.metadata["version"]["cuda"] + + @property + def torch_version(self) -> str: + return self.metadata["version"]["torch"] + + @property + def python_version(self) -> str: + return self.metadata["version"]["python"] + + @property + def flash_version(self) -> str: + return self.metadata["version"].get("flash", "0.0.0") + + @property + def build_env(self) -> Dict[str, Any]: + return self.metadata["env"] + + +class xFormersWasNotBuiltException(Exception): + def __str__(self) -> str: + return ( + "Need to compile C++ extensions to use all xFormers features.\n" + " Please install xformers properly " + "(see https://github.com/facebookresearch/xformers#installing-xformers)\n" + + UNAVAILABLE_FEATURES_MSG + ) + + +class xFormersInvalidLibException(Exception): + def __init__(self, build_info: Optional[_BuildInfo]) -> None: + self.build_info = build_info + + def __str__(self) -> str: + if self.build_info is None: + msg = "xFormers was built for a different version of PyTorch or Python." + else: + msg = f"""xFormers was built for: + PyTorch {self.build_info.torch_version} with CUDA {self.build_info.cuda_version} (you have {torch.__version__}) + Python {self.build_info.python_version} (you have {platform.python_version()})""" + return ( + "xFormers can't load C++/CUDA extensions. " + + msg + + "\n Please reinstall xformers " + "(see https://github.com/facebookresearch/xformers#installing-xformers)\n" + + UNAVAILABLE_FEATURES_MSG + ) + + +def _register_extensions(): + import importlib + import os + + import torch + + # load the custom_op_library and register the custom ops + lib_dir = os.path.dirname(__file__) + if os.name == "nt": + # Register the main torchvision library location on the default DLL path + import ctypes + import sys + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + if with_load_library_flags: + kernel32.AddDllDirectory.restype = ctypes.c_void_p + + if sys.version_info >= (3, 8): + os.add_dll_directory(lib_dir) + elif with_load_library_flags: + res = kernel32.AddDllDirectory(lib_dir) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' 
+ raise err + + kernel32.SetErrorMode(prev_error_mode) + + loader_details = ( + importlib.machinery.ExtensionFileLoader, + importlib.machinery.EXTENSION_SUFFIXES, + ) + + extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) + ext_specs = extfinder.find_spec("_C") + if ext_specs is None: + raise xFormersWasNotBuiltException() + cpp_lib_json = os.path.join(lib_dir, "cpp_lib.json") + with open(cpp_lib_json, "r") as fp: + build_metadata = _BuildInfo(json.load(fp)) + try: + torch.ops.load_library(ext_specs.origin) + except OSError as exc: + raise xFormersInvalidLibException(build_metadata) from exc + return build_metadata + + +_cpp_library_load_exception = None +_build_metadata: Optional[_BuildInfo] = None + +try: + _build_metadata = _register_extensions() +except (xFormersInvalidLibException, xFormersWasNotBuiltException) as e: + ENV_VAR_FOR_DETAILS = "XFORMERS_MORE_DETAILS" + if os.environ.get(ENV_VAR_FOR_DETAILS, False): + logger.warning(f"WARNING[XFORMERS]: {e}", exc_info=e) + else: + logger.warning( + f"WARNING[XFORMERS]: {e}\n Set {ENV_VAR_FOR_DETAILS}=1 for more details" + ) + _cpp_library_load_exception = e + +_built_with_cuda = ( + _build_metadata is not None and _build_metadata.cuda_version is not None +) diff --git a/pkgs/xformers/_deprecation_warning.py b/pkgs/xformers/_deprecation_warning.py new file mode 100644 index 0000000..505ef15 --- /dev/null +++ b/pkgs/xformers/_deprecation_warning.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import warnings + + +def deprecated_function(self): + name = repr(self) # self.__name__ + msg = f"{name} is deprecated and is not maintained anymore. 
It might be removed in a future version of xFormers" + warnings.warn(msg, FutureWarning, stacklevel=2) diff --git a/pkgs/xformers/_flash_attn/__init__.py b/pkgs/xformers/_flash_attn/__init__.py new file mode 100644 index 0000000..9785dc0 --- /dev/null +++ b/pkgs/xformers/_flash_attn/__init__.py @@ -0,0 +1,8 @@ +__version__ = "2.5.8" + +from flash_attn.flash_attn_interface import flash_attn_func +from flash_attn.flash_attn_interface import flash_attn_kvpacked_func +from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func +from flash_attn.flash_attn_interface import flash_attn_varlen_func diff --git a/pkgs/xformers/_flash_attn/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3b44c9f Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/bert_padding.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/bert_padding.cpython-310.pyc new file mode 100644 index 0000000..92ee143 Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/bert_padding.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-310.pyc new file mode 100644 index 0000000..1cd4091 Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/flash_attn_interface.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-310.pyc new file mode 100644 index 0000000..fdb35fb Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/flash_attn_triton.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-310.pyc new file mode 100644 index 0000000..fbbe5e2 Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/flash_attn_triton_og.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-310.pyc new file mode 100644 index 0000000..5909c14 Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/flash_blocksparse_attention.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-310.pyc new file mode 100644 index 0000000..0ab1e83 Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/flash_blocksparse_attn_interface.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/__pycache__/fused_softmax.cpython-310.pyc b/pkgs/xformers/_flash_attn/__pycache__/fused_softmax.cpython-310.pyc new file mode 100644 index 0000000..6c5a0a3 Binary files /dev/null and b/pkgs/xformers/_flash_attn/__pycache__/fused_softmax.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/bert_padding.py b/pkgs/xformers/_flash_attn/bert_padding.py new file mode 100644 index 0000000..b071416 --- /dev/null +++ b/pkgs/xformers/_flash_attn/bert_padding.py @@ -0,0 +1,132 @@ +# Adapted from 
https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py + +import torch +import torch.nn.functional as F + +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, 'b ... -> b (...)'), 0, + repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, 'b ... -> b (...)') + grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, dtype=grad_output.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, + dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + indices, = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. 
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + indices, = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices, + cu_seqlens, max_seqlen_in_batch) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, '(b s) ... 
-> b s ...', b=batch) diff --git a/pkgs/xformers/_flash_attn/flash_attn_interface.py b/pkgs/xformers/_flash_attn/flash_attn_interface.py new file mode 100644 index 0000000..4771954 --- /dev/null +++ b/pkgs/xformers/_flash_attn/flash_attn_interface.py @@ -0,0 +1,966 @@ +import os +import torch +import torch.nn as nn + +import flash_attn_2_cuda as flash_attn_cuda +from einops import rearrange +import warnings + +global enable_ixdnn +enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0' + +def _get_block_size(device, head_dim, is_dropout, is_causal): + # This should match the block sizes in the CUDA kernel + assert head_dim <= 256 + major, minor = torch.cuda.get_device_capability(device) + is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) + is_sm80 = major == 8 and minor == 0 + is_sm90 = major == 9 and minor == 0 + if head_dim <= 32: + return 128, 128 + if head_dim <= 64: + return (128, 128) if not is_dropout else (128, 64) + elif head_dim <= 96: + return (64, 64) if (is_sm8x and is_causal) else (128, 64) + elif head_dim <= 128: + if is_sm8x: + return (64, 64) if (not is_dropout and is_causal) else (128, 32) + else: + return 128, (64 if not is_dropout else 32) + elif head_dim <= 160: + if is_sm8x: + return (128, 64) if not is_causal else (64, 64) + else: + return 128, 32 + elif head_dim <= 192: + return (128, 64) if not is_dropout else (64, 64) + elif head_dim <= 224: + return (128, 64) if (is_sm80 or is_sm90) else (64, 64) + elif head_dim <= 256: + return (128, 64) if is_sm80 else (64, 64) + + +def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd( + q, k, v, None, dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode, None + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask + + +def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax, use_alibi, alibi_mode): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, use_alibi, alibi_mode, None + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, q, k, v, out_padded, softmax_lse, S_dmask + + +def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, + dropout_p, softmax_scale, causal, use_alibi, alibi_mode): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( + dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, use_alibi, alibi_mode, None + ) + return dq, dk, dv, softmax_d + + +def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, use_alibi, alibi_mode): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are 
allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, use_alibi, alibi_mode, None + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dq, dk, dv, softmax_d + +def _flash_attn_forward_ixdnn(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p, + causal, return_softmax, use_alibi, alibi_mode, imp_mode, window_size): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.fwd_ixdnn( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, dropout_p, causal, return_softmax, + use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask + +def _flash_attn_backward_ixdnn(dout, q, k, v, out, softmax_lse, dq, dk, dv, + cu_seqlens_q, cu_seqlens_k, dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, = flash_attn_cuda.bwd_ixdnn( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, dropout_p, + causal, use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None + ) + return dq, dk, dv, softmax_d + +def _flash_attn_varlen_forward_ixdnn(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_cuda.varlen_fwd_ixdnn( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + causal, use_alibi, alibi_mode, imp_mode, window_size[0], window_size[1], None + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask + +def _flash_attn_varlen_backward_ixdnn(dout, q, k, v, out, softmax_lse, dq, dk, dv, + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, causal, use_alibi, alibi_mode, imp_mode, window_size): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d = flash_attn_cuda.varlen_bwd_ixdnn( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, dropout_p, causal, use_alibi, alibi_mode, imp_mode, + window_size[0], window_size[1], None + ) + return dq, dk, dv, softmax_d + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if alibi_slopes is not None and use_alibi == False: + warnings.warn("Parameter 'alibi_slopes' is not supported, 
automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning) + use_alibi = True + imp_mode = 1 + if enable_ixdnn: + cu_seqlens = torch.full((qkv.shape[0],), qkv.shape[1], dtype=torch.int32, device=qkv.device) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], cu_seqlens, cu_seqlens, dropout_p, + causal=causal, return_softmax=return_softmax and dropout_p > 0, + use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode, + window_size=window_size + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + else: + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale, + causal=causal, return_softmax=return_softmax and dropout_p > 0, + use_alibi=use_alibi, alibi_mode=alibi_mode + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.use_alibi = use_alibi + ctx.alibi_mode = alibi_mode + ctx.imp_mode = imp_mode + ctx.window_size = window_size + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + if enable_ixdnn: + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + if enable_ixdnn: + _flash_attn_backward_ixdnn( + dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + cu_seqlens, cu_seqlens, ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, + ctx.window_size + ) + else: + _flash_attn_backward( + dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode + ) + dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if alibi_slopes is not None and use_alibi == False: + warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning) + use_alibi = True + imp_mode = 1 + if enable_ixdnn: + cu_seqlens = torch.full((cu_seqlens.numel(),), max_seqlen, dtype=torch.int32, device=cu_seqlens.device) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn( + qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode, + window_size=window_size + ) + else: + if softmax_scale is None: + softmax_scale = 
qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + use_alibi=use_alibi, alibi_mode=alibi_mode + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.use_alibi = use_alibi + ctx.alibi_mode = alibi_mode + ctx.imp_mode = imp_mode + ctx.window_size = window_size + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + if enable_ixdnn: + _flash_attn_varlen_backward_ixdnn( + dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], + cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen, + ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, + ctx.window_size + ) + else: + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2], + cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen, + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode + ) + dqkv = dqkv[..., :dout.shape[-1]] # We could have padded the head dimension + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if alibi_slopes is not None and use_alibi == False: + warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning) + use_alibi = True + imp_mode = 1 + if enable_ixdnn: + cu_seqlens_q = torch.full((q.shape[0],), q.shape[1], dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.full((kv.shape[0],), kv.shape[1], dtype=torch.int32, device=kv.device) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn( + q, kv[:, :, 0], kv[:, :, 1], cu_seqlens_q, cu_seqlens_k, dropout_p, causal=causal, + return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi, + alibi_mode=alibi_mode, imp_mode=imp_mode, window_size=window_size + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + else: + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal, + return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi, + alibi_mode=alibi_mode + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.use_alibi = 
use_alibi + ctx.alibi_mode = alibi_mode + ctx.imp_mode = imp_mode + ctx.window_size = window_size + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + if enable_ixdnn: + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + if enable_ixdnn: + _flash_attn_backward_ixdnn( + dout, q, k, v, out, softmax_lse, dq, dkv[:, :, 0], dkv[:, :, 1], + cu_seqlens_q, cu_seqlens_k, ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, ctx.window_size + ) + else: + _flash_attn_backward( + dout, q, k, v, out, softmax_lse, + dq, dkv[:, :, 0], dkv[:, :, 1], ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode + ) + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, dkv, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, + use_alibi, alibi_mode, imp_mode): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if alibi_slopes is not None and use_alibi == False: + warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning) + use_alibi = True + imp_mode = 1 + if enable_ixdnn: + cu_seqlens_q = torch.full((cu_seqlens_q.numel(),), max_seqlen_q, dtype=torch.int32, device=cu_seqlens_q.device) + cu_seqlens_k = torch.full((cu_seqlens_k.numel(),), max_seqlen_k, dtype=torch.int32, device=cu_seqlens_k.device) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn( + q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode, + window_size=window_size + ) + else: + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + use_alibi=use_alibi, alibi_mode=alibi_mode + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.use_alibi = use_alibi + ctx.alibi_mode = alibi_mode + ctx.imp_mode = imp_mode + ctx.window_size = window_size + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = 
torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + if enable_ixdnn: + _flash_attn_varlen_backward_ixdnn( + dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1], + cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, + ctx.dropout_p, ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, + ctx.window_size + ) + else: + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dq, dkv[:, 0], dkv[:, 1], + cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k, + ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode + ) + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, use_alibi, alibi_mode, imp_mode): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if alibi_slopes is not None and use_alibi == False: + warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning) + use_alibi = True + imp_mode = 1 + if enable_ixdnn: + cu_seqlens_q = torch.full((q.shape[0],), q.shape[1], dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.full((k.shape[0],), k.shape[1], dtype=torch.int32, device=k.device) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward_ixdnn( + q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p, causal=causal, + return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi, + alibi_mode = alibi_mode, imp_mode = imp_mode, window_size=window_size + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + else: + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + q, k, v, dropout_p, softmax_scale, causal=causal, + return_softmax=return_softmax and dropout_p > 0, use_alibi=use_alibi, + alibi_mode = alibi_mode + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.use_alibi = use_alibi + ctx.alibi_mode = alibi_mode + ctx.imp_mode = imp_mode + ctx.window_size = window_size + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + if enable_ixdnn: + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + if enable_ixdnn: + _flash_attn_backward_ixdnn( + dout, q, k, v, out, softmax_lse, + dq, dk, dv, cu_seqlens_q, cu_seqlens_k, ctx.dropout_p, + ctx.causal, ctx.use_alibi, ctx.alibi_mode, ctx.imp_mode, + ctx.window_size + ) + else: + _flash_attn_backward( + dout, q, k, v, out, 
softmax_lse, + dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal, ctx.use_alibi, ctx.alibi_mode + ) + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., :dout.shape[-1]] + dv = dv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, window_size, alibi_slopes, deterministic, return_softmax, + use_alibi, alibi_mode, imp_mode): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if alibi_slopes is not None and use_alibi == False: + warnings.warn("Parameter 'alibi_slopes' is not supported, automatically switching to the supported parameters 'use_alibi=True' and 'alibi_mode=1'", UserWarning) + use_alibi = True + imp_mode = 1 + if enable_ixdnn: + cu_seqlens_q = torch.full((cu_seqlens_q.numel(),), max_seqlen_q, dtype=torch.int32, device=cu_seqlens_q.device) + cu_seqlens_k = torch.full((cu_seqlens_k.numel(),), max_seqlen_k, dtype=torch.int32, device=cu_seqlens_k.device) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward_ixdnn( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, causal=causal, use_alibi=use_alibi, alibi_mode = alibi_mode, imp_mode=imp_mode, + window_size=window_size + ) + else: + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, + use_alibi = use_alibi, alibi_mode = alibi_mode + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.use_alibi = use_alibi + ctx.alibi_mode = alibi_mode + ctx.imp_mode = imp_mode + ctx.window_size = window_size + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + if enable_ixdnn: + _flash_attn_varlen_backward_ixdnn( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.causal, ctx.use_alibi, + ctx.alibi_mode, ctx.imp_mode, ctx.window_size + ) + else: + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, + ctx.use_alibi, ctx.alibi_mode + ) + + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., :dout.shape[-1]] + dv = dv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, 
None, None, None
+
+
+def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None,
+                              deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0):
+    """dropout_p should be set to 0.0 during evaluation.
+    If Q, K, V are already stacked into 1 tensor, this function will be faster than
+    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
+    of the gradients of Q, K, V.
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive.
+
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Defaults to 1 / sqrt(headdim).
+        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+            the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
+            be True. The calculation formula (QK^T + bias * slopes_m) is shown as follows:
+            q1*k1                              0
+            q2*k1 q2*k2                       -1  0
+            q3*k1 q3*k2 q3*k3           +     -2 -1  0           * slopes_m
+            q4*k1 q4*k2 q4*k3 q4*k4           -3 -2 -1  0
+            q5*k1 q5*k2 q5*k3 q5*k4 q5*k5     -4 -3 -2 -1  0
+
+        alibi_mode: int. The bias mode of ALiBi; defaults to 1.
+            alibi_mode=0:                alibi_mode=1:
+             0                           -√0
+            -1  0                        -√1 -√0
+            -2 -1  0                     -√2 -√1 -√0
+            -3 -2 -1  0                  -√3 -√2 -√1 -√0
+            -4 -3 -2 -1  0               -√4 -√3 -√2 -√1 -√0
+            -5 -4 -3 -2 -1  0            -√5 -√4 -√3 -√2 -√1 -√0
+        imp_mode: int. Supports two modes of backward implementation; defaults to 0.
+            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): Implements the dQ reduction by launching a new kernel,
+                which occupies additional memory. This mode usually has better performance.
+            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): Implements the dQ reduction within the block,
+                which does not occupy additional memory. There is a slight performance loss in this mode.
+                This mode can be considered in memory-limited scenarios.
+    Return:
+        out: (batch_size, seqlen, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
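+
+    Example (a minimal illustrative sketch; the shapes, fp16 dtype, and CUDA device
+    below are arbitrary assumptions, not requirements of this function):
+        >>> import torch
+        >>> # batch=2, seqlen=1024, nheads=16, headdim=64 -- arbitrary supported sizes
+        >>> qkv = torch.randn(2, 1024, 3, 16, 64, device="cuda", dtype=torch.float16)
+        >>> out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
+        >>> out.shape
+        torch.Size([2, 1024, 16, 64])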
+ """ + return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode) + + +def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, + deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in KV must be divisible by the number of heads in Q. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + kv: (batch_size, seqlen, 2, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must + be True. The calculation formula(QK^T + bias * slops_m) is shown as follows, + q1*k1 0 + q2*k1 q2*k2 -1 0 + q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m + q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0 + q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0 + + alibi_mode: int. The bias mode of ALiBi, default to 1. + alibi_mode=0: alibi_mode=1: + 0 -√0 + -1 0 -√1 -√0 + -2 -1 0 -√2 -√1 -√0 + -3 -2 -1 0 -√3 -√2 -√1 -√0 + -4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0 + -5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0 + imp_mode: int. Support two modes of backward implementation, default to 0. + imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel, + which will occupy additional memory. This mode usually has better performance. 
+ imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block, + which will not occupy additional memory. There will be a slight performance loss in this mode. + This mode can be considered in memory-limited scenarios. + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode) + +def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), alibi_slopes=None, + deterministic=False, return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in KV must be divisible by the number of heads in Q. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must + be True. The calculation formula(QK^T + bias * slops_m) is shown as follows, + q1*k1 0 + q2*k1 q2*k2 -1 0 + q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m + q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0 + q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0 + + alibi_mode: int. The bias mode of ALiBi, default to 1. + alibi_mode=0: alibi_mode=1: + 0 -√0 + -1 0 -√1 -√0 + -2 -1 0 -√2 -√1 -√0 + -3 -2 -1 0 -√3 -√2 -√1 -√0 + -4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0 + -5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0 + imp_mode: int. Support two modes of backward implementation, default to 0. + imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel, + which will occupy additional memory. This mode usually has better performance. 
+ imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block, + which will not occupy additional memory. There will be a slight performance loss in this mode. + This mode can be considered in memory-limited scenarios. + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode) + + +def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, + causal=False, window_size=(-1, -1), alibi_slopes=None, deterministic=False, + return_attn_probs=False, use_alibi=False, alibi_mode=1, imp_mode=0): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + use_alibi: bool. Whether to apply attention with linear bias(ALiBi). if True, causal must + be True. The calculation formula(QK^T + bias * slops_m) is shown as follows, + q1*k1 0 + q2*k1 q2*k2 -1 0 + q3*k1 q3*k2 q3*k3 + -2 -1 0 * slops_m + q4*k1 q4*k2 q4*k3 q4*k4 -3 -2 -1 0 + q5*k1 q5*k2 q5*k3 q5*k4 q5*k5 -4 -3 -2 -1 0 + + alibi_mode: int. The bias mode of ALiBi, default to 1. + alibi_mode=0: alibi_mode=1: + 0 -√0 + -1 0 -√1 -√0 + -2 -1 0 -√2 -√1 -√0 + -3 -2 -1 0 -√3 -√2 -√1 -√0 + -4 -3 -2 -1 0 -√4 -√3 -√2 -√1 -√0 + -5 -4 -3 -2 -1 0 -√5 -√4 -√3 -√2 -√1 -√0 + imp_mode: int. Support two modes of backward implementation, default to 0. 
+ imp_mode=0(CUDNN_FATTN_BALANCE_MODE): Implement dQ reduction calculation by launching a new kernel, + which will occupy additional memory. This mode usually has better performance. + imp_mode=1(CUDNN_FATTN_LEAST_MEM_MODE): Implement dQ reduction calculation within the block, + which will not occupy additional memory. There will be a slight performance loss in this mode. + This mode can be considered in memory-limited scenarios. + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenQKVPackedFunc.apply( + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, window_size, + alibi_slopes, deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode + ) + + +def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), + alibi_slopes=None, deterministic=False, return_attn_probs=False, + use_alibi=False, alibi_mode=1, imp_mode=0): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in KV must be divisible by the number of heads in Q. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). 
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
+            be True. The calculation formula (QK^T + bias * slopes_m) is shown as follows:
+            q1*k1                              0
+            q2*k1 q2*k2                       -1  0
+            q3*k1 q3*k2 q3*k3           +     -2 -1  0           * slopes_m
+            q4*k1 q4*k2 q4*k3 q4*k4           -3 -2 -1  0
+            q5*k1 q5*k2 q5*k3 q5*k4 q5*k5     -4 -3 -2 -1  0
+
+        alibi_mode: int. The bias mode of ALiBi; defaults to 1.
+            alibi_mode=0:                alibi_mode=1:
+             0                           -√0
+            -1  0                        -√1 -√0
+            -2 -1  0                     -√2 -√1 -√0
+            -3 -2 -1  0                  -√3 -√2 -√1 -√0
+            -4 -3 -2 -1  0               -√4 -√3 -√2 -√1 -√0
+            -5 -4 -3 -2 -1  0            -√5 -√4 -√3 -√2 -√1 -√0
+        imp_mode: int. Supports two modes of backward implementation; defaults to 0.
+            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): Implements the dQ reduction by launching a new kernel,
+                which occupies additional memory. This mode usually has better performance.
+            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): Implements the dQ reduction within the block,
+                which does not occupy additional memory. There is a slight performance loss in this mode.
+                This mode can be considered in memory-limited scenarios.
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
+    """
+    return FlashAttnVarlenKVPackedFunc.apply(
+        q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+        dropout_p, softmax_scale, causal, window_size, alibi_slopes,
+        deterministic, return_attn_probs, use_alibi, alibi_mode, imp_mode
+    )
+
+
+def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
+                           dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1),
+                           alibi_slopes=None, deterministic=False, return_attn_probs=False,
+                           use_alibi=False, alibi_mode=1, imp_mode=0):
+    """dropout_p should be set to 0.0 during evaluation.
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in K, V.
+    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into k and v.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Defaults to 1 / sqrt(headdim).
+        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+        use_alibi: bool. Whether to apply attention with linear bias (ALiBi). If True, causal must
+            be True. The calculation formula (QK^T + bias * slopes_m) is shown as follows:
+            q1*k1                              0
+            q2*k1 q2*k2                       -1  0
+            q3*k1 q3*k2 q3*k3           +     -2 -1  0           * slopes_m
+            q4*k1 q4*k2 q4*k3 q4*k4           -3 -2 -1  0
+            q5*k1 q5*k2 q5*k3 q5*k4 q5*k5     -4 -3 -2 -1  0
+
+        alibi_mode: int. The bias mode of ALiBi; defaults to 1.
+            alibi_mode=0:                alibi_mode=1:
+             0                           -√0
+            -1  0                        -√1 -√0
+            -2 -1  0                     -√2 -√1 -√0
+            -3 -2 -1  0                  -√3 -√2 -√1 -√0
+            -4 -3 -2 -1  0               -√4 -√3 -√2 -√1 -√0
+            -5 -4 -3 -2 -1  0            -√5 -√4 -√3 -√2 -√1 -√0
+        imp_mode: int. Supports two modes of backward implementation; defaults to 0.
+            imp_mode=0 (CUDNN_FATTN_BALANCE_MODE): Implements the dQ reduction by launching a new kernel,
+                which occupies additional memory. This mode usually has better performance.
+            imp_mode=1 (CUDNN_FATTN_LEAST_MEM_MODE): Implements the dQ reduction within the block,
+                which does not occupy additional memory. There is a slight performance loss in this mode.
+                This mode can be considered in memory-limited scenarios.
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
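+
+    Example (a minimal illustrative sketch; the sequence lengths, head count, head
+    dimension, fp16 dtype, and CUDA device are arbitrary assumptions, not requirements):
+        >>> import torch
+        >>> # two packed sequences of lengths 3 and 5 -> 8 tokens total
+        >>> cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
+        >>> q = torch.randn(8, 16, 64, device="cuda", dtype=torch.float16)
+        >>> k = torch.randn(8, 16, 64, device="cuda", dtype=torch.float16)
+        >>> v = torch.randn(8, 16, 64, device="cuda", dtype=torch.float16)
+        >>> out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
+        ...                              max_seqlen_q=5, max_seqlen_k=5, causal=True)
+        >>> out.shape
+        torch.Size([8, 16, 64])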
+ """ + return FlashAttnVarlenFunc.apply( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, window_size, alibi_slopes, deterministic, + return_attn_probs, use_alibi, alibi_mode, imp_mode + ) diff --git a/pkgs/xformers/_flash_attn/flash_attn_triton.py b/pkgs/xformers/_flash_attn/flash_attn_triton.py new file mode 100644 index 0000000..78b7588 --- /dev/null +++ b/pkgs/xformers/_flash_attn/flash_attn_triton.py @@ -0,0 +1,832 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. +""" + +import math + +import torch + +import triton +import triton.language as tl + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. 
+# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, K, V, Bias, Out, + Lse, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_ob, stride_oh, stride_om, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? + # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :]) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! 
+ if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + k = tl.load(k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != 'none': + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load(b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + other=0.0).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. 
+ qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0) + else: + v = tl.load(v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :]) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store(out_ptrs, acc_o, + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, DO, Delta, + stride_ob, stride_oh, stride_om, + stride_dob, stride_doh, stride_dom, + nheads, seqlen_q, seqlen_q_rounded, headdim, + BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. 
In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == 'vector': + b_ptrs = Bias + offs_n + elif BIAS_TYPE == 'matrix': + b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! 
+ if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != 'none': + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == 'vector': + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == 'matrix': + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_n[None, :] < seqlen_k), + other=0.0).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == 'none': + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. + if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + & (offs_d[None, :] < headdim), other=0.0) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. 
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not (EVEN_M & EVEN_HEADDIM): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load(dq_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0, + eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last") + else: + dq = tl.load(dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last") + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add(dq_ptrs, dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim)) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == 'matrix': + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, 
pre_hook=init_to_zero('DQ')), + ], + key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qb, stride_qh, stride_qm, + stride_kb, stride_kh, stride_kn, + stride_vb, stride_vh, stride_vn, + stride_bb, stride_bh, stride_bm, + stride_dob, stride_doh, stride_dom, + stride_dqb, stride_dqh, stride_dqm, + stride_dkb, stride_dkh, stride_dkn, + stride_dvb, stride_dvh, stride_dvn, + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim, + CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != 'none': + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, K, V, Bias, + DO, DQ, DK, DV, + LSE, D, + softmax_scale, + stride_qm, stride_kn, stride_vn, stride_bm, + stride_dom, stride_dqm, stride_dkn, stride_dvn, + seqlen_q, seqlen_k, headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + + +def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, 'FlashAttention only support head dimensions up to 128' + assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' + assert q.dtype in [torch.float16, torch.bfloat16], 'Only support fp16 and bf16' + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = 
bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, k, v, bias, o, + lse, tmp, + softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + o.stride(0), o.stride(2), o.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, do, delta, + o.stride(0), o.stride(2), o.stride(1), + do.stride(0), do.stride(2), do.stride(1), + nheads, seqlen_q, seqlen_q_rounded, d, + BLOCK_M=128, BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = 'none' + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = 'vector' + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = 'matrix' + else: + raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)' + ' or (seqlen_q, seqlen_k)') + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: (triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads) + _bwd_kernel[grid]( + q, k, v, bias, + do, dq_accum, dk, dv, + lse, delta, + 
softmax_scale, + q.stride(0), q.stride(2), q.stride(1), + k.stride(0), k.stride(2), k.stride(1), + v.stride(0), v.stride(2), v.stride(1), + *bias_strides, + do.stride(0), do.stride(2), do.stride(1), + dq_accum.stride(0), dq_accum.stride(2), dq_accum.stride(1), + dk.stride(0), dk.stride(2), dk.stride(1), + dv.stride(0), dv.stride(2), dv.stride(1), + nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d, + seqlen_q // 32, seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, causal, BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): + """ + qkv: (batch, seqlen, 3, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). + ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) + """ + # Make sure that the last dimension is contiguous + if qkv.stride(-1) != 1: + qkv = qkv.contiguous() + o, lse, ctx.softmax_scale = _flash_attn_forward( + qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], bias=bias, causal=causal, + softmax_scale=softmax_scale + ) + ctx.save_for_backward(qkv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + qkv, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[1], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dqkv = torch.empty_like(qkv) + _flash_attn_backward(do, qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], o, lse, + dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2], + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dqkv, None, None, None + + +flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply + + +class FlashAttnKVPackedFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None): + """ + q: (batch, seqlen_q, nheads, headdim) + kv: (batch, seqlen_k, 2, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, kv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, kv, o, lse, bias = ctx.saved_tensors + if len(ctx.needs_input_grad) >= 3: + assert not ctx.needs_input_grad[2], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
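+        # Note for readers (descriptive comment, not upstream): dq and dkv are allocated
+        # below and _flash_attn_backward writes the gradients into them in place; the
+        # trailing None values returned correspond to the bias, causal and softmax_scale
+        # inputs, which do not receive gradients.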
+ with torch.inference_mode(): + dq = torch.empty_like(q) + dkv = torch.empty_like(kv) + _flash_attn_backward(do, q, kv[:, :, 0], kv[:, :, 1], o, lse, + dq, dkv[:, :, 0], dkv[:, :, 1], + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dkv, None, None, None + + +flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply + + +class FlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): + """ + q: (batch_size, seqlen_q, nheads, headdim) + k, v: (batch_size, seqlen_k, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). + ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) + """ + # Make sure that the last dimension is contiguous + q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]] + o, lse, ctx.softmax_scale = _flash_attn_forward( + q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale + ) + ctx.save_for_backward(q, k, v, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[3], 'FlashAttention does not support bias gradient yet' + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. + with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, + bias=bias, causal=ctx.causal, softmax_scale=ctx.softmax_scale) + return dq, dk, dv, None, None, None + + +flash_attn_func = FlashAttnFunc.apply diff --git a/pkgs/xformers/_flash_attn/flash_attn_triton_og.py b/pkgs/xformers/_flash_attn/flash_attn_triton_og.py new file mode 100644 index 0000000..fb165d5 --- /dev/null +++ b/pkgs/xformers/_flash_attn/flash_attn_triton_og.py @@ -0,0 +1,276 @@ +# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py +# for benchmarking. 
+# We fixed a few dtype cast to make it work for bf16 + +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +import pytest +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + TMP, L, M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +@triton.jit +def _bwd_preprocess( + Out, DO, L, + NewDO, Delta, + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, 
:]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, K, V, sm_scale, Out, DO, + DQ, DK, DV, + L, M, + D, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + Z, H, N_CTX, + num_block, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(q.dtype), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(k.dtype), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, 
sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, k, v, sm_scale, + tmp, L, m, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK, BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, num_warps=num_warps, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( + o, do, l, + do_scaled, delta, + BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + + # NOTE: kernel currently buggy for other values of `num_warps` + num_warps = 8 + _bwd_kernel[(ctx.grid[1],)]( + q, k, v, ctx.sm_scale, + o, do_scaled, + dq, dk, dv, + l, m, + delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + q.shape[0], q.shape[1], q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps, + num_stages=1, + ) + return dq.to(q.dtype), dk, dv, None + + +attention = _attention.apply diff --git a/pkgs/xformers/_flash_attn/flash_blocksparse_attention.py b/pkgs/xformers/_flash_attn/flash_blocksparse_attention.py new file mode 100644 index 0000000..70d15fb --- /dev/null +++ b/pkgs/xformers/_flash_attn/flash_blocksparse_attention.py @@ -0,0 +1,136 @@ +import math +import torch +import torch.nn as nn + +from einops import rearrange + +import hydra + +from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func +from flash_attn.flash_blocksparse_attn_interface import convert_blockmask +from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis + + +class FlashBlocksparseAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + """ + def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0, + max_seq_length=2048, device=None, dtype=None): + super().__init__() + self.sparsity_config = hydra.utils.instantiate(sparsity_config) + self.softmax_temp = softmax_temp + self.dropout_p = attention_dropout + + # initialize sparse layout and register as buffer + max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256 + layout = self.sparsity_config.make_layout(max_seq_length) + self.register_buffer("layout", layout) + blockmask_converted = convert_blockmask(self.layout, causal=False) + self.register_buffer("blockmask_converted", blockmask_converted) + # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}') + + def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None, + max_s=None, need_weights=False, convert_mask=True): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + key_padding_mask: An implementation of BaseMask that encodes how + many query each sequence in the batch consists of + """ + assert not need_weights + assert attn_mask is None + assert qkv.dtype == torch.float16 + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + # Convert mask to take a subset + seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 + assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1] + blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256] + if key_padding_mask is None: + qkv = rearrange(qkv, 'b s ... -> (b s) ...') + max_s = seqlen + cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=qkv.device) + output = flash_blocksparse_attn_func( + qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, + max_s, softmax_scale=self.softmax_temp, causal=causal + ) + output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) + else: + key_padding_mask_bool = key_padding_mask.bool_matrix + nheads = qkv.shape[-2] + x = rearrange(qkv, 'b s three h d -> b s (three h d)') + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) + x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) + output_unpad = flash_blocksparse_attn_func( + x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, + max_s, softmax_scale=self.softmax_temp, causal=causal + ) + output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), + indices, batch_size, seqlen), + 'b s (h d) -> b s h d', h=nheads) + else: + assert max_s is not None + seqlen = max_s + # Convert mask to take a subset + seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 + assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1] + blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256] + if convert_mask: + output = flash_blocksparse_attn_func( + qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, + max_s, softmax_scale=self.softmax_temp, causal=causal + ) + else: + output = flash_blocksparse_attn_func( + qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0, + max_s, softmax_scale=self.softmax_temp, causal=causal, + convert_mask=False, + ) + + return output, None + + +class FlashBlocksparseMHA(nn.Module): + + def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True, + attention_dropout=0.0, causal=False, max_seq_length=2048, + device=None, dtype=None, **kwargs) -> None: + assert batch_first + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64" + + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.inner_attn = FlashBlocksparseAttention( + sparsity_config, attention_dropout=attention_dropout, + max_seq_length=max_seq_length, **factory_kwargs + ) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, + need_weights=False): + qkv = self.Wqkv(x) + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) + context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask, + need_weights=need_weights, causal=self.causal) + return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights diff --git a/pkgs/xformers/_flash_attn/flash_blocksparse_attn_interface.py b/pkgs/xformers/_flash_attn/flash_blocksparse_attn_interface.py new file mode 100644 index 0000000..34a111a --- /dev/null +++ b/pkgs/xformers/_flash_attn/flash_blocksparse_attn_interface.py @@ -0,0 +1,142 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py +import torch +import torch.nn as nn + +import flash_attn_cuda + + +def convert_blockmask(blockmask, causal): + """Convert from the 0-1 format to the format used by the CUDA code. + 0 means the block is skipped. + nonzero means the block is not skipped. 
+ Argument: + blockmask: (row, col): a 0-1 tensor + Return: + blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row + indices of the nonzero blocks, padded with -1 to reach length @row. + The indices are multiplied by 4, with the smallest bit used to encode whether + it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is + the last nonzero in its row.. + """ + assert not causal + # TD [2022-05-13]: The indexing and sorting is very tricky + nrow, ncol = blockmask.shape + # Sort does not support bool on CUDA + blockmask = blockmask.to(dtype=torch.uint8) + nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True) + nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0) + last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1] + last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ + torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row + ] + first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0] + first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ + torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row + ] + nonzero_idx = nonzero_sorted_rowidx * 4 + nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2 + nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1 + nonzero_idx[nonzero_val == 0] = -1 + return nonzero_idx.T.contiguous().to(dtype=torch.int32) + + +def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, + causal, return_softmax): + context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p, + max_s, softmax_scale, causal, + return_softmax, None) + # if context.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + S_dmask = rest[0] if return_softmax else None + return context, softmax_lse, S_dmask + + +def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask, + dropout_p, max_s, softmax_scale, causal): + dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, + blockmask, dropout_p, softmax_scale, max_s, + causal, None) + # if dqkv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dqkv + + +class FlashBlocksparseAttnFun(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( + qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal, + return_softmax=False + ) + ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) + ctx.dropout_p = dropout_p + ctx.max_s = max_s + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return context + + @staticmethod + def backward(ctx, dout): + qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + # S_dmask is None, temporarily use another tensor just to get it running + dqkv = _flash_blocksparse_attn_backward( + dout, qkv, context, context, 
softmax_lse, cu_seqlens, blockmask, ctx.dropout_p, + ctx.max_s, ctx.softmax_scale, ctx.causal + ) + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None, None + + +# We duplicate code to return both the output and the softmax for testing +# Returning both makes backward a bit slower, so we want to keep using the other version for speed. +class FlashBlocksparseAttnFunWithS(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): + # Save rng_state because the backward pass is gonna regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( + qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal, + return_softmax=True + ) + ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) + ctx.dropout_p = dropout_p + ctx.max_s = max_s + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return context, S_dmask, softmax_lse + + @staticmethod + def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored): + qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dqkv = _flash_blocksparse_attn_backward( + dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p, + ctx.max_s, ctx.softmax_scale, ctx.causal + ) + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None + + +def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None, + causal=False, return_attn_probs=False, convert_mask=True): + """dropout_p should be set to 0.0 during evaluation + """ + func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS + if convert_mask: + blockmask = convert_blockmask(blockmask, causal=causal) + return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal) diff --git a/pkgs/xformers/_flash_attn/fused_softmax.py b/pkgs/xformers/_flash_attn/fused_softmax.py new file mode 100644 index 0000000..5335481 --- /dev/null +++ b/pkgs/xformers/_flash_attn/fused_softmax.py @@ -0,0 +1,205 @@ +# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py +# for benchmarking. +# We added support for seqlen=2k and seqlen=4k + +# coding=utf-8 +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
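+# Clarifying note (added here, not part of the upstream apex file): the fused CUDA
+# kernels imported below compute, in a single pass, roughly what the unfused fallback
+# FusedScaleMaskSoftmax.forward_torch_softmax spells out step by step:
+#     probs = torch.nn.Softmax(dim=-1)(mask_func(input * scale, mask))
+# with the causal ("upper triangular") variant constructing the mask internally.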
+import torch + +from apex._autocast_utils import _cast_if_autocast_enabled +from apex.transformer.enums import AttnMaskType + +from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward +from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block +from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. + """ + + @staticmethod + def forward(ctx, inputs, scale): + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax_forward( + inputs, scale_t[0] + ) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax_backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None + + +def scaled_upper_triang_masked_softmax(inputs, _, scale): + b, np, sq, sk = inputs.size() + assert sq == sk, "causal mask is only for self attention" + # Reshaping input to 3D tensor (attn_batches, sq, sk) + inputs = inputs.view(-1, sq, sk) + args = _cast_if_autocast_enabled(inputs, scale) + with torch.cuda.amp.autocast(enabled=False): + probs = ScaledUpperTriangMaskedSoftmax.apply(*args) + return probs.view(b, np, sq, sk) + + +# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. +# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. +# So I needed to manually write two `torch.autograd.Function` inheritances. +# Fused operation which performs following three operations in sequence +# 1. Scale the tensor. +# 2. Apply the mask. +# 3. Perform softmax. +class ScaledMaskedSoftmax(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_masked_softmax_backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None, None + + +def scaled_masked_softmax(inputs, mask, scale): + # input is 4D tensor (b, np, sq, sk) + args = _cast_if_autocast_enabled(inputs, mask, scale) + with torch.cuda.amp.autocast(enabled=False): + return ScaledMaskedSoftmax.apply(*args) + + +class FusedScaleMaskSoftmax(torch.nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. 
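+
+    If the fused kernel cannot be used for the given mask type, shapes, or dtype
+    (see is_kernel_available), forward falls back to the equivalent plain PyTorch
+    path in forward_torch_softmax.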
+ """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super().__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + if self.input_in_fp16 and self.input_in_bf16: + raise RuntimeError( + "both fp16 and bf16 flags cannot be active at the same time." + ) + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + if not (self.scale is None or softmax_in_fp32): + raise RuntimeError("softmax should be in fp32 when scaled") + + if self.scaled_masked_softmax_fusion: + if self.attn_mask_type == AttnMaskType.causal: + self.fused_softmax_func = scaled_upper_triang_masked_softmax + elif self.attn_mask_type == AttnMaskType.padding: + self.fused_softmax_func = scaled_masked_softmax + else: + raise ValueError("Invalid attn_mask_type.") + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and ( + self.attn_mask_type == AttnMaskType.causal + or (self.attn_mask_type == AttnMaskType.padding and mask is not None) + ) + and 16 < sk <= 8192 # sk must be 16 ~ 8192 + and sq % 4 == 0 # sq must be divisor of 4 + and sk % 4 == 0 # sk must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 8192: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + # input.shape = [b, np, sq, sk] + scale = self.scale if self.scale is not None else 1.0 + return self.fused_softmax_func(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) diff --git a/pkgs/xformers/_flash_attn/layers/__init__.py b/pkgs/xformers/_flash_attn/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers/_flash_attn/layers/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..5afba24 Binary files /dev/null and b/pkgs/xformers/_flash_attn/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-310.pyc b/pkgs/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-310.pyc new file mode 100644 index 0000000..14c0fb8 Binary 
files /dev/null and b/pkgs/xformers/_flash_attn/layers/__pycache__/patch_embed.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/layers/__pycache__/rotary.cpython-310.pyc b/pkgs/xformers/_flash_attn/layers/__pycache__/rotary.cpython-310.pyc new file mode 100644 index 0000000..e7eff5c Binary files /dev/null and b/pkgs/xformers/_flash_attn/layers/__pycache__/rotary.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/layers/patch_embed.py b/pkgs/xformers/_flash_attn/layers/patch_embed.py new file mode 100644 index 0000000..a68dd6d --- /dev/null +++ b/pkgs/xformers/_flash_attn/layers/patch_embed.py @@ -0,0 +1,56 @@ +# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py +# But we use nn.Linear instead of Conv2d and it's about 8x faster. + +from functools import partial + +import torch.nn as nn +from torch import _assert +from torch.nn.modules.utils import _pair + +from einops import rearrange + +try: + from flash_attn.ops.fused_dense import FusedDense +except ImportError: + FusedDense = None + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + fused_bias_fc=False, + ): + super().__init__() + img_size = _pair(img_size) + patch_size = _pair(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + if fused_bias_fc and FusedDense is None: + raise ImportError('fused_dense is not installed') + + linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense + self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + _, _, H, W = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)', + p1=self.patch_size[0], p2=self.patch_size[1])) + if self.flatten: + x = rearrange(x, 'b h w c -> b (h w) c') + x = self.norm(x) + return x diff --git a/pkgs/xformers/_flash_attn/layers/rotary.py b/pkgs/xformers/_flash_attn/layers/rotary.py new file mode 100644 index 0000000..f009440 --- /dev/null +++ b/pkgs/xformers/_flash_attn/layers/rotary.py @@ -0,0 +1,336 @@ +# Copyright (c) 2023, Tri Dao. + +from typing import Tuple, Optional +import math + +import torch + +from einops import rearrange, repeat + +import rotary_emb + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... 
(d two)', two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, 's d -> s 1 (2 d)') + sin = repeat(sin, 's d -> s 1 (2 d)') + return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:]], dim=-1) + + +class ApplyRotaryEmb(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, cos, sin, interleaved=False, inplace=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + batch, seqlen, nheads, headdim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + x_ro = x[..., :rotary_dim] + x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) + out = torch.empty_like(x) if not inplace else x + out_ro = out[..., :rotary_dim] + if inplace: + o1, o2 = x1, x2 + else: + o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved + else (out_ro[..., ::2], out_ro[..., 1::2])) + rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False) + if not inplace and rotary_dim < headdim: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + ctx.inplace = inplace + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + _, seqlen, _, headdim = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + inplace = ctx.inplace + do_ro = do[..., :rotary_dim] + do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved + else (do_ro[..., ::2], do_ro[..., 1::2])) + dx = torch.empty_like(do) if not inplace else do + if inplace: + dx1, dx2 = do1, do2 + else: + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dx_ro[..., ::2], dx_ro[..., 1::2])) + rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True) + if not inplace and rotary_dim < headdim: + dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) + return dx, None, None, None, None + + +apply_rotary_emb_func = ApplyRotaryEmb.apply + + +class ApplyRotaryEmbQKV_(torch.autograd.Function): + + @staticmethod + def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): + """ + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of q and k. 
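+        Returns the same qkv tensor: q and k are rotated in place, v is left untouched.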
+ """ + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) + q_ro = qkv[:, :, 0, :, :rotary_dim] + q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) + rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False) + k_ro = qkv[:, :, 1, :, :rotary_dim] + k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) + rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), + rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + cos, sin, cos_k, sin_k = ctx.saved_tensors + _, seqlen, _, _, headdim = dqkv.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + dq_ro = dqkv[:, :, 0, :, :rotary_dim] + dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dq_ro[..., ::2], dq_ro[..., 1::2])) + rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True) + dk_ro = dqkv[:, :, 1, :, :rotary_dim] + dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dk_ro[..., ::2], dk_ro[..., 1::2])) + rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'), + rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True) + return dqkv, None, None, None, None, None + + +apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply + + +class ApplyRotaryEmbKV_(torch.autograd.Function): + + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False): + """ + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of k. + """ + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim + assert seqlen <= rotary_seqlen + k_ro = kv[:, :, 0, :, :rotary_dim] + k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) + rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, + False) # conj=False since this is the forward pass + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + cos, sin = ctx.saved_tensors + _, seqlen, _, _, headdim = dkv.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + dk_ro = dkv[:, :, 0, :, :rotary_dim] + dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved + else (dk_ro[..., ::2], dk_ro[..., 1::2])) + rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'), + rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, + True) # conj=True since this is the backward pass + return dkv, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). 
+ A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None, + pos_idx_in_fp32=True, device=None): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) + / (1.4 * dim) if scale_base is not None else None) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, + dtype=torch.float32) / self.dim)) + + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if (seqlen > self._seq_len_cached or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference())): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. 
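+                # (bf16 keeps only 8 significand bits, so consecutive integer positions
+                # above 256 can no longer all be represented exactly.)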
+ # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) + - seqlen // 2) / self.scale_base) + scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1') + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, + seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: can be used in generation where the qkv being passed in is only the last + token in the batch. + """ + seqlen = qkv.shape[1] + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], + None, None, self.interleaved + ) + else: + return apply_rotary_emb_qkv_( + qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], + self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:], + self.interleaved + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], + self.interleaved, True + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:], + self.interleaved + ) + else: + kv = apply_rotary_emb_kv_( + kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:], + self.interleaved + ) + return q, kv diff --git a/pkgs/xformers/_flash_attn/losses/__init__.py b/pkgs/xformers/_flash_attn/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers/_flash_attn/losses/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/losses/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..8f60068 Binary files /dev/null and b/pkgs/xformers/_flash_attn/losses/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-310.pyc b/pkgs/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-310.pyc new file mode 100644 index 0000000..369d582 Binary files /dev/null and b/pkgs/xformers/_flash_attn/losses/__pycache__/cross_entropy.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/losses/cross_entropy.py b/pkgs/xformers/_flash_attn/losses/cross_entropy.py new file mode 100644 index 0000000..bc2df84 --- /dev/null +++ b/pkgs/xformers/_flash_attn/losses/cross_entropy.py @@ -0,0 +1,129 @@ +# Inspired by 
https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/cross_entropy.py +# But we make it much faster: we compute the local loss and the LSE, and by exchanging the LSE and +# the losses we can get the global loss. There's no need to do it step by step +# (compute local max, exchange, compute exp, compute local sum, exchange, etc.) +# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py +import torch +import torch.nn as nn + +import xentropy_cuda_lib + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 2 lines are for backward compatibility with +# older PyTorch. +if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base + + +class SoftmaxCrossEntropyLossFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False, + process_group=None): + """ + logits: (batch, vocab_size) + labels: (batch,) + If process_group is not None, we're doing Tensor Parallel: each process is responsible for + one part of the vocab. The loss needs to be aggregated across processes. + """ + batch, vocab_size = logits.shape + assert labels.shape == (batch,) + world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group) + ctx.total_classes = world_size * vocab_size + + if world_size == 1: + losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing) + losses.masked_fill_(labels==ignored_index, 0) + labels_local = labels + else: + rank = torch.distributed.get_rank(process_group) + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + labels_mask = (labels < vocab_start_index) | (labels >= vocab_end_index) + ignored_mask = labels == ignored_index + labels_local = torch.where(ignored_mask, labels, labels - vocab_start_index) + + # For tensor parallel cross entropy with smoothing, we want to pass in the total number + # of classes so that smoothing can be applied correctly. If total_classes=-1, use the + # last dimension of the input tensor. + losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing, + world_size * vocab_size) + assert lse_local.shape == (batch,) + assert losses.shape == (batch,) + losses.masked_fill_(ignored_mask, 0) + # For labels == ignored_index, the loss is always 0. + # If there's no smoothing, if labels are in the vocab of this partition, losses contains + # lse_local - predicted logit, and 0 otherwise. + # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains + # 0.9 * (lse_local - predicted logit) + 0.1 * (lse_local - sum logit / total_classes) + # For labels not in the vocab of this partition, losses contains + # 0.1 * (lse_local - sum logit / total_classes). 
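+            # The global log-sum-exp can be recovered from the per-rank values because
+            # exp(lse_global) = sum_r exp(lse_r); hence the torch.logsumexp over the
+            # gathered lse_local below.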
+ + lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype, + device=lse_local.device) + torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(), + group=process_group) + handle_losses = torch.distributed.all_reduce( + losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True + ) + lse = torch.logsumexp(lse_allgather, dim=0) + # If there's no smoothing, the total losses are lse_local - predicted_logit, + # we just have to subtract the lse_local and add the lse (global). + # If there's smoothing=0.1, the total losses are + # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes) + # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes). + rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor') + lse_local = lse_allgather[rank_per_sample, + torch.arange(batch, device=lse_allgather.device)] + + handle_losses.wait() + if smoothing == 0.0: + losses += lse - lse_local + else: + losses += ((1 - smoothing) * (lse - lse_local) + + smoothing * (lse - lse_allgather.sum(dim=0))) + losses.masked_fill_(ignored_mask, 0) + + ctx.save_for_backward(logits, lse, labels_local) + ctx.smoothing = smoothing + ctx.ignored_index = ignored_index + ctx.inplace_backward = inplace_backward + return losses + + @staticmethod + def backward(ctx, grad_loss): + logits, lse, labels = ctx.saved_tensors + grad_loss = grad_loss.contiguous() + grad_loss.masked_fill_(labels==ctx.ignored_index, 0) + grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels, + ctx.smoothing, ctx.inplace_backward, + ctx.total_classes) + return grad_logits, None, None, None, None, None, None + + +class CrossEntropyLoss(nn.Module): + + def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0, + inplace_backward=False, process_group=None): + super().__init__() + if reduction not in ['mean', 'none']: + raise NotImplementedError("Only support reduction = 'mean' or 'none'") + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.inplace_backward = inplace_backward + self.process_group = process_group + + def forward(self, input, target): + assert input.is_cuda and target.is_cuda + # SoftmaxCrossEntropyLoss implicitly casts to float + loss = SoftmaxCrossEntropyLossFn.apply( + input, target, self.label_smoothing, self.ignore_index, self.inplace_backward, + self.process_group + ) + if self.reduction == 'mean': + return loss.sum() / (target != self.ignore_index).sum() + else: + return loss diff --git a/pkgs/xformers/_flash_attn/models/__init__.py b/pkgs/xformers/_flash_attn/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3930b22 Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/bert.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/bert.cpython-310.pyc new file mode 100644 index 0000000..bddc6a9 Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/bert.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/falcon.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/falcon.cpython-310.pyc new file mode 100644 index 0000000..0fcfe9b Binary files /dev/null 
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/falcon.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc
new file mode 100644
index 0000000..35407bd
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/gpt.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-310.pyc
new file mode 100644
index 0000000..96f040a
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/gpt_neox.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/gptj.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/gptj.cpython-310.pyc
new file mode 100644
index 0000000..e7a424d
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/gptj.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/llama.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/llama.cpython-310.pyc
new file mode 100644
index 0000000..14e35e3
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/llama.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc
new file mode 100644
index 0000000..ba08f2a
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/opt.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc b/pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc
new file mode 100644
index 0000000..1557001
Binary files /dev/null and b/pkgs/xformers/_flash_attn/models/__pycache__/vit.cpython-310.pyc differ
diff --git a/pkgs/xformers/_flash_attn/models/bert.py b/pkgs/xformers/_flash_attn/models/bert.py
new file mode 100644
index 0000000..710c7e8
--- /dev/null
+++ b/pkgs/xformers/_flash_attn/models/bert.py
@@ -0,0 +1,531 @@
+# Copyright (c) 2022, Tri Dao.
+# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
+# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
+# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
+
+# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
+
+import re
+import logging
+from functools import partial
+
+from collections.abc import Sequence
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from transformers import BertConfig
+from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
+
+from einops import rearrange
+
+from flash_attn.modules.mha import MHA
+from flash_attn.modules.mlp import Mlp, FusedMLP
+from flash_attn.modules.block import Block
+from flash_attn.modules.embedding import BertEmbeddings
+from flash_attn.bert_padding import unpad_input, pad_input
+from flash_attn.bert_padding import index_first_axis, index_first_axis_residual
+from flash_attn.utils.pretrained import state_dict_from_pretrained
+
+try:
+    from flash_attn.ops.fused_dense import FusedDense
+except ImportError:
+    FusedDense = None
+
+try:
+    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+except ImportError:
+    dropout_add_layer_norm, layer_norm = None, None
+
+try:
+    from flash_attn.losses.cross_entropy import CrossEntropyLoss
+except ImportError:
+    CrossEntropyLoss = None
+
+
+logger = logging.getLogger(__name__)
+
+
+def create_mixer_cls(config, cross_attn=False, return_residual=False):
+    use_flash_attn = getattr(config, 'use_flash_attn', False)
+    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
+    rotary_kwargs = {}
+    if config.position_embedding_type == "rotary":
+        rotary_kwargs["rotary_emb_dim"] = getattr(config, "rotary_emb_dim", config.hidden_size)
+        rotary_kwargs["rotary_emb_base"] = getattr(config, "rotary_emb_base", 10000.0)
+        rotary_kwargs["rotary_emb_scale_base"] = getattr(config, "rotary_emb_scale_base", None)
+        rotary_kwargs["rotary_emb_interleaved"] = getattr(config, "rotary_emb_interleaved", False)
+    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
+                        dropout=config.attention_probs_dropout_prob, causal=False,
+                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
+                        return_residual=return_residual, **rotary_kwargs)
+    return mixer_cls
+
+
+def create_mlp_cls(config, layer_idx=None, return_residual=False):
+    inner_dim = config.intermediate_size
+    fused_mlp = getattr(config, 'fused_mlp', False)
+    if fused_mlp:
+        assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_mlp only '
+                                                                'supports approximate gelu')
+    if not fused_mlp:
+        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
+        mlp_cls = partial(Mlp, hidden_features=inner_dim,
+                          activation=partial(F.gelu, approximate=approximate),
+                          return_residual=return_residual)
+    else:
+        if FusedMLP is None:
+            raise ImportError('fused_dense is not installed')
+        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
+        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
+        if isinstance(mlp_checkpoint_lvl, Sequence):
+            assert layer_idx is not None
+            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
+        mlp_cls = partial(FusedMLP, hidden_features=inner_dim,
+                          checkpoint_lvl=mlp_checkpoint_lvl,
return_residual=return_residual) + return mlp_cls + + +def create_block(config, layer_idx=None): + last_layer_subset = getattr(config, 'last_layer_subset', False) + cross_attn=last_layer_subset and layer_idx == config.num_hidden_layers - 1 + # TD [2022-12-19]: For cross attention (last layer), we actually want to return the + # residual x_kv, not residual x. But it's annoying to change the API (and it only affects + # one layer) so we just choose not to return residual in this case. + return_residual = not cross_attn + mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual) + mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual) + norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps) + block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, + prenorm=False, resid_dropout1=config.hidden_dropout_prob, + resid_dropout2=config.hidden_dropout_prob, + fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + return_residual=return_residual) + return block + + +# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748 +def _init_weights(module, initializer_range=0.02): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + if module.padding_idx is not None: + nn.init.zeros_(module.weight[module.padding_idx]) + + +class BertEncoder(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.use_flash_attn = getattr(config, 'use_flash_attn', False) + self.layers = nn.ModuleList([create_block(config, layer_idx=i) + for i in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, key_padding_mask=None, subset_mask=None): + """If subset_mask is not None, we only want output for the subset of the sequence. + This means that we only compute the last layer output for these tokens. 
+ subset_mask: (batch, seqlen), dtype=torch.bool + """ + if key_padding_mask is None or not self.use_flash_attn: + mixer_kwargs = ({'key_padding_mask': key_padding_mask} + if key_padding_mask is not None else None) + for layer in self.layers: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + if subset_mask is not None: + hidden_states = hidden_states[subset_mask] + else: + batch, seqlen = hidden_states.shape[:2] + hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input( + hidden_states, key_padding_mask + ) + mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch} + if subset_mask is None: + for layer in self.layers: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + hidden_states = pad_input(hidden_states, indices, batch, seqlen) + else: + for layer in self.layers[:-1]: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + if key_padding_mask is not None: + subset_idx = torch.nonzero(subset_mask[key_padding_mask], as_tuple=False).flatten() + subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32) + subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, + dtype=torch.torch.int32), (1, 0)) + else: + subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten() + subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32) + subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, + dtype=torch.torch.int32), (1, 0)) + hidden_states_subset, hidden_states = index_first_axis_residual( + hidden_states, subset_idx + ) + # It's ok to set max_seqlen_q to be much larger + mixer_kwargs = {'x_kv': hidden_states, + 'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch, + 'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch} + hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs) + return hidden_states + + +class BertPooler(nn.Module): + + def __init__(self, config): + super().__init__() + fused_bias_fc = getattr(config, 'fused_bias_fc', False) + if fused_bias_fc and FusedDense is None: + raise ImportError('fused_dense is not installed') + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + self.dense = linear_cls(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states, pool=True): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. 
+ first_token_tensor = hidden_states[:, 0] if pool else hidden_states + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + + def __init__(self, config): + super().__init__() + fused_bias_fc = getattr(config, 'fused_bias_fc', False) + if fused_bias_fc and FusedDense is None: + raise ImportError('fused_dense is not installed') + self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) + if self.fused_dropout_add_ln and layer_norm is None: + raise ImportError('dropout_add_layer_norm is not installed') + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + self.dense = linear_cls(config.hidden_size, config.hidden_size) + approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none' + self.transform_act_fn = nn.GELU(approximate=approximate) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + if not self.fused_dropout_add_ln: + hidden_states = self.layer_norm(hidden_states) + else: + hidden_states = layer_norm(hidden_states, self.layer_norm.weight, self.layer_norm.bias, + self.layer_norm.eps) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + + def __init__(self, config): + super().__init__() + fused_bias_fc = getattr(config, 'fused_bias_fc', False) + if fused_bias_fc and FusedDense is None: + raise ImportError('fused_dense is not installed') + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True) + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertPreTrainingHeads(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + self.seq_relationship = nn.Linear(config.hidden_size, 2) + + def forward(self, sequence_output, pooled_output): + prediction_scores = self.predictions(sequence_output) + seq_relationship_score = self.seq_relationship(pooled_output) + return prediction_scores, seq_relationship_score + + +class BertPreTrainedModel(nn.Module): + """ An abstract class to handle weights initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, BertConfig): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + @classmethod + def from_pretrained(cls, model_name, config, *inputs, **kwargs): + """ + Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + + Params: + pretrained_model_name_or_path: either: + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . 
`pytorch_model.bin` a PyTorch dump of a BertForPretraining instance + - a path or url to a pretrained model archive containing: + . `bert_config.json` a configuration file for the model + . `model.chkpt` a TensorFlow checkpoint + *inputs, **kwargs: additional input for the specific Bert class + (ex: num_labels for BertForSequenceClassification) + """ + # Instantiate model. + model = cls(config, *inputs, **kwargs) + load_return = model.load_state_dict(remap_state_dict(state_dict_from_pretrained(model_name), + config), strict=False) + logger.info(load_return) + return model + + +class BertModel(BertPreTrainedModel): + + def __init__(self, config: BertConfig, add_pooling_layer=True): + super().__init__(config) + self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + if config.vocab_size % self.pad_vocab_size_multiple != 0: + config.vocab_size += (self.pad_vocab_size_multiple + - (config.vocab_size % self.pad_vocab_size_multiple)) + self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) + if self.fused_dropout_add_ln and layer_norm is None: + raise ImportError('dropout_add_layer_norm is not installed') + assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast'] + + self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size, + config.max_position_embeddings, config.type_vocab_size, + padding_idx=config.pad_token_id) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.encoder = BertEncoder(config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.apply(partial(_init_weights, initializer_range=config.initializer_range)) + + def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, + masked_tokens_mask=None): + """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining), + we only want the output for the masked tokens. This means that we only compute the last + layer output for these tokens. + masked_tokens_mask: (batch, seqlen), dtype=torch.bool + """ + hidden_states = self.embeddings(input_ids, position_ids=position_ids, + token_type_ids=token_type_ids) + # TD [2022-12:18]: Don't need to force residual in fp32 + # BERT puts embedding LayerNorm before embedding dropout. + if not self.fused_dropout_add_ln: + hidden_states = self.emb_ln(hidden_states) + else: + hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias, + self.emb_ln.eps) + hidden_states = self.emb_drop(hidden_states) + + if masked_tokens_mask is not None: + batch_size, seqlen = input_ids.shape[:2] + # We also need the first column for the CLS token + first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool, + device=input_ids.device) + first_col_mask[:, 0] = True + subset_mask = masked_tokens_mask | first_col_mask + else: + subset_mask = None + + sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask, + subset_mask=subset_mask) + + if masked_tokens_mask is None: + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + else: + # TD [2022-03-01]: the indexing here is very tricky. 
+ if attention_mask is not None: + subset_idx = subset_mask[attention_mask] + pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]] + sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]] + else: + pool_input = sequence_output[first_col_mask[subset_mask]] + sequence_output = sequence_output[masked_tokens_mask[subset_mask]] + pooled_output = (self.pooler(pool_input, pool=False) + if self.pooler is not None else None) + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + ) + + +class BertForPreTraining(BertPreTrainedModel): + + def __init__(self, config: BertConfig): + super().__init__(config) + # If dense_seq_output, we only need to pass the hidden states for the masked out tokens + # (around 15%) to the classifier heads. + self.dense_seq_output = getattr(config, 'dense_seq_output', False) + # If last_layer_subset, we only need the compute the last layer for a subset of tokens + # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction). + self.last_layer_subset = getattr(config, 'last_layer_subset', False) + if self.last_layer_subset: + assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output' + use_xentropy = getattr(config, 'use_xentropy', False) + if use_xentropy and CrossEntropyLoss is None: + raise ImportError('xentropy_cuda is not installed') + loss_cls = (nn.CrossEntropyLoss if not use_xentropy + else partial(CrossEntropyLoss, inplace_backward=True)) + + self.bert = BertModel(config) + self.cls = BertPreTrainingHeads(config) + self.mlm_loss = loss_cls(ignore_index=0) + self.nsp_loss = loss_cls(ignore_index=-1) + + # Initialize weights and apply final processing + self.apply(partial(_init_weights, initializer_range=config.initializer_range)) + self.tie_weights() + + def tie_weights(self): + self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight + + def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, + labels=None, next_sentence_label=None): + """ + If labels are provided, they must be 0 for masked out tokens (as specified in the attention + mask). + Outputs: + if `labels` and `next_sentence_label` are not `None`: + Outputs the total_loss which is the sum of the masked language modeling loss and the next + sentence classification loss. + if `labels` or `next_sentence_label` is `None`: + Outputs a tuple comprising + - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and + - the next sentence classification logits of shape [batch_size, 2]. 
+ + """ + masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None + outputs = self.bert( + input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask.bool() if attention_mask is not None else None, + masked_tokens_mask=masked_tokens_mask + ) + sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output + if self.dense_seq_output and labels is not None: + masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten() + if not self.last_layer_subset: + sequence_output = index_first_axis(rearrange(sequence_output, 'b s d -> (b s) d'), + masked_token_idx) + prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) + + total_loss = None + if labels is not None and next_sentence_label is not None: + if self.dense_seq_output and labels is not None: # prediction_scores are already flattened + masked_lm_loss = self.mlm_loss(prediction_scores, + labels.flatten()[masked_token_idx]) + else: + masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'), + rearrange(labels, '... -> (...)')) + next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'), + rearrange(next_sentence_label, '... -> (...)')) + total_loss = masked_lm_loss.float() + next_sentence_loss.float() + + return BertForPreTrainingOutput( + loss=total_loss, + prediction_logits=prediction_scores, + seq_relationship_logits=seq_relationship_score, + ) + + +def remap_state_dict(state_dict, config): + # LayerNorm + def key_mapping_ln_gamma_beta(key): + key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key) + key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key) + return key + state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items()) + + # Layers + def key_mapping_layers(key): + return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key) + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key) + key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)', + r'bert.encoder.layers.\1.norm1.\2', key) + key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)', + r'bert.encoder.layers.\1.norm2.\2', key) + key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)', + r'cls.predictions.transform.layer_norm.\1', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)', + r'bert.encoder.layers.\1.mlp.fc1.\2', key) + key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)', + r'bert.encoder.layers.\1.mlp.fc2.\2', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + last_layer_subset = getattr(config, 'last_layer_subset', False) + for d in range(config.num_hidden_layers): + Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight') + Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight') + Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight') + bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias') + bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias') + bv = 
state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias') + if not (last_layer_subset and d == config.num_hidden_layers - 1): + state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat( + [Wq, Wk, Wv], dim=0 + ) + state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0) + else: + state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq + state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat( + [Wk, Wv], dim=0 + ) + state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq + state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0) + def key_mapping_attn(key): + return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)', + r'bert.encoder.layers.\1.mixer.out_proj.\2', key) + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + def key_mapping_decoder_bias(key): + return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key) + state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items()) + + # Word embedding + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + if pad_vocab_size_multiple > 1: + word_embeddings = state_dict['bert.embeddings.word_embeddings.weight'] + state_dict['bert.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0]) + ) + decoder_weight = state_dict['cls.predictions.decoder.weight'] + state_dict['cls.predictions.decoder.weight'] = F.pad( + decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0]) + ) + # If the vocab was padded, we want to set the decoder bias for those padded indices to be + # strongly negative (i.e. the decoder shouldn't predict those indices). + # TD [2022-05-09]: I don't think it affects the MLPerf training. + decoder_bias = state_dict['cls.predictions.decoder.bias'] + state_dict['cls.predictions.decoder.bias'] = F.pad( + decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0 + ) + + return state_dict diff --git a/pkgs/xformers/_flash_attn/models/falcon.py b/pkgs/xformers/_flash_attn/models/falcon.py new file mode 100644 index 0000000..86768eb --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/falcon.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from einops import rearrange + +from transformers import GPT2Config, FalconConfig + + +def remap_state_dict_hf_falcon(state_dict, config): + def key_mapping_layers(key): + return re.sub(r'^transformer.h.', 'transformer.layers.', key) + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r'^transformer.word_embeddings.', 'transformer.embeddings.word_embeddings.', key) + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
+ pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, 'tie_word_embeddings'): + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + else: + output_embeddings = state_dict.pop('lm_head.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + state_dict['lm_head.weight'] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + output_embeddings_bias = state_dict.pop('lm_head.bias') + state_dict['lm_head.bias'] = F.pad( + output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', + r'transformer.layers.\1.norm1.', key) + key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', + r'transformer.layers.\1.norm2.', key) + key = re.sub(r'^transformer.layers.(\d+).ln_attn.', r'transformer.layers.\1.norm1.', key) + key = re.sub(r'^transformer.layers.(\d+).ln_mlp.', r'transformer.layers.\1.norm2.', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', + r'transformer.layers.\1.mlp.fc1.', key) + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', + r'transformer.layers.\1.mlp.fc2.', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + def key_mapping_attn(key): + key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.', + r'transformer.layers.\1.mixer.Wqkv.', key) + key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.', + r'transformer.layers.\1.mixer.out_proj.', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + n_head = config.n_head + n_head_kv = getattr(config, "n_head_kv", 1) + headdim = config.hidden_size // n_head + for l in range(config.n_layer): + # The weights are stored in a different layout compared to our implementation + Wqkv = rearrange(state_dict.pop(f'transformer.layers.{l}.mixer.Wqkv.weight'), + "(group ratio headdim) ... -> group ratio headdim ...", + ratio=n_head // n_head_kv + 2, headdim=headdim) + Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...") + Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...") + Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...") + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) + + return state_dict + + +def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config: + # The 40b config uses "n_head_kv" instead of "num_kv_heads" + n_head_kv = getattr(falcon_config, "n_head_kv", + 1 if getattr(falcon_config, "multi_query", False) + else falcon_config.n_head) + # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config. 
+ # So we have to infer it from the number of heads in the key/value block + parallel_block_tied_norm = n_head_kv == 1 + return GPT2Config( + vocab_size=falcon_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=falcon_config.hidden_size, + n_layer=falcon_config.n_layer, + n_head=falcon_config.n_head, + n_inner=falcon_config.hidden_size * 4, + activation_function="gelu", + resid_pdrop=falcon_config.hidden_dropout, + embd_pdrop=0.0, # There doesn't seem to be any embedding dropout + attn_pdrop=falcon_config.attention_dropout, + layer_norm_epsilon=falcon_config.layer_norm_epsilon, + initializer_range=falcon_config.initializer_range, + bos_token_id=falcon_config.bos_token_id, + eos_token_id=falcon_config.eos_token_id, + # These are new arguments not in the original GPT2Config + parallel_block=falcon_config.parallel_attn, + n_head_kv=n_head_kv, + parallel_block_tied_norm=parallel_block_tied_norm, + rotary_emb_fraction=1.0, + rotary_emb_interleaved=False, + tie_word_embeddings=True, + qkv_proj_bias=falcon_config.bias, + out_proj_bias=falcon_config.bias, + mlp_fc1_bias=falcon_config.bias, + mlp_fc2_bias=falcon_config.bias, + lm_head_bias=False, + ) diff --git a/pkgs/xformers/_flash_attn/models/gpt.py b/pkgs/xformers/_flash_attn/models/gpt.py new file mode 100644 index 0000000..1dc3d0c --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/gpt.py @@ -0,0 +1,740 @@ +# Copyright (c) 2023, Tri Dao. + +import logging +import math +import re +from functools import partial + +from collections import namedtuple, OrderedDict +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers import GPT2Config + +from einops import rearrange + +from flash_attn.ops.activations import sqrelu_fwd +from flash_attn.modules.mha import MHA, ParallelMHA +from flash_attn.modules.mlp import Mlp, GatedMlp, ParallelMLP, FusedMLP, ParallelFusedMLP +from flash_attn.modules.block import Block, ParallelBlock +from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings +from flash_attn.utils.distributed import sync_shared_params, all_gather_raw +from flash_attn.utils.pretrained import state_dict_from_pretrained +from flash_attn.utils.generation import GenerationMixin +from flash_attn.models.opt import remap_state_dict_hf_opt +from flash_attn.models.gptj import remap_state_dict_hf_gptj +from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox +from flash_attn.models.falcon import remap_state_dict_hf_falcon + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear +except ImportError: + ColumnParallelLinear = None + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + dropout_add_layer_norm = None + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual +except ImportError: + dropout_add_layer_norm_parallel_residual = None + +try: + from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm +except ImportError: + RMSNorm, dropout_add_rms_norm = None, None + +try: + from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual +except ImportError: + dropout_add_rms_norm_parallel_residual = None + +try: + from flash_attn.ops.triton.mlp import FusedDenseSqreluDense +except ImportError: + FusedDenseSqreluDense = None + + +logger = logging.getLogger(__name__) + + +def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + 
head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads) + softmax_scale = 1.0 if not config.scale_attn_weights else head_dim ** (-0.5) + if config.scale_attn_by_inverse_layer_idx: + assert layer_idx is not None + softmax_scale /= float(layer_idx + 1) + dwconv = getattr(config, 'attn_dwconv', False) + if dwconv: + assert process_group is None, 'TensorParallel MHA does not support dwconv yet' + qkv_proj_bias = getattr(config, 'qkv_proj_bias', True) + out_proj_bias = getattr(config, 'out_proj_bias', True) + rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim) + rotary_emb_base = getattr(config, 'rotary_emb_base', 10000.0) + rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', None) + rotary_emb_interleaved = getattr(config, 'rotary_emb_interleaved', False) + use_flash_attn = getattr(config, 'use_flash_attn', False) + fused_bias_fc = getattr(config, 'fused_bias_fc', False) + if not fused_bias_fc: + assert process_group is None, 'TensorParallel MHA requires fused_bias_fc' + mha_cls = MHA if process_group is None else ParallelMHA + serial_kwargs = ({'fused_bias_fc': fused_bias_fc, 'dwconv': dwconv} + if process_group is None else {}) + parallel_kwargs = ({'process_group': process_group, + 'sequence_parallel': getattr(config, 'sequence_parallel', True)} + if process_group is not None else {}) + num_heads_kv = getattr(config, "n_head_kv", None) + mixer_cls = partial(mha_cls, num_heads=config.num_attention_heads, + num_heads_kv=num_heads_kv, + qkv_proj_bias=qkv_proj_bias, out_proj_bias=out_proj_bias, + dropout=config.attn_pdrop, + softmax_scale=softmax_scale, causal=True, layer_idx=layer_idx, + rotary_emb_dim=rotary_emb_dim, rotary_emb_base=rotary_emb_base, + rotary_emb_scale_base=rotary_emb_scale_base, + rotary_emb_interleaved=rotary_emb_interleaved, + use_flash_attn=use_flash_attn, + **serial_kwargs, **parallel_kwargs, **factory_kwargs) + return mixer_cls + + +def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + mlp_fc1_bias = getattr(config, 'mlp_fc1_bias', True) + mlp_fc2_bias = getattr(config, 'mlp_fc2_bias', True) + fused_mlp = getattr(config, 'fused_mlp', False) + if fused_mlp: + assert config.activation_function in ['gelu_new', 'gelu_fast', 'gelu_approx', 'relu', 'sqrelu'] + fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False) + if fused_dense_sqrelu_dense: + assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only ' + 'supports approximate activation_function sqrelu') + assert not (fused_dense_sqrelu_dense and fused_mlp) + if not fused_mlp and not fused_dense_sqrelu_dense: + assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', 'relu', + 'sqrelu', 'glu', 'swiglu', 'geglu'] + if config.activation_function in ['glu', 'swiglu', 'geglu']: + activation = (F.sigmoid if config.activation_function == 'glu' + else (F.silu if config.activation_function == 'swiglu' + else F.gelu)) + mlp_cls = partial(GatedMlp, hidden_features=config.n_inner, activation=activation, + bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, **factory_kwargs) + else: + if config.activation_function == 'relu': + activation = partial(F.relu, inplace=True) + elif config.activation_function == 'sqrelu': + activation = sqrelu_fwd + else: + approximate = ('tanh' if config.activation_function + in ['gelu_new', 'gelu_fast', 'gelu_approx'] else 'none') + activation=partial(F.gelu, approximate=approximate) + 
mlp_cls = Mlp if process_group is None else ParallelMLP + parallel_kwargs = ({'process_group': process_group, + 'sequence_parallel': getattr(config, 'sequence_parallel', True)} + if process_group is not None else {}) + mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation, + bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, + **parallel_kwargs, **factory_kwargs) + else: + mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0) + # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer + if isinstance(mlp_checkpoint_lvl, Sequence): + assert layer_idx is not None + mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx] + if fused_mlp: + if FusedMLP is None: + raise ImportError('fused_dense is not installed') + activation = ('gelu_approx' if config.activation_function + in ['gelu_new', 'gelu_fast', 'gelu_approx'] else config.activation_function) + mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP + parallel_kwargs = ({'process_group': process_group, + 'sequence_parallel': getattr(config, 'sequence_parallel', True)} + if process_group is not None else {}) + mlp_cls = partial(mlp_cls, hidden_features=config.n_inner, activation=activation, + checkpoint_lvl=mlp_checkpoint_lvl, + bias1=mlp_fc1_bias, bias2=mlp_fc2_bias, + **parallel_kwargs, **factory_kwargs) + elif fused_dense_sqrelu_dense: + assert FusedDenseSqreluDense is not None + mlp_cls = partial(FusedDenseSqreluDense, hidden_features=config.n_inner, + checkpoint_lvl=mlp_checkpoint_lvl, **factory_kwargs) + else: + raise RuntimeError('MLP type not supported') + return mlp_cls + + +def create_block(config, layer_idx=None, process_group=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + sequence_parallel = getattr(config, 'sequence_parallel', True) + mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs) + mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs) + use_rms_norm = getattr(config, 'rms_norm', False) + norm_cls = partial(nn.LayerNorm if not use_rms_norm else RMSNorm, + eps=config.layer_norm_epsilon, **factory_kwargs) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + residual_in_fp32 = getattr(config, 'residual_in_fp32', False) + resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop + prenorm = getattr(config, 'prenorm', True) + parallel_block = getattr(config, 'parallel_block', False) + if not parallel_block: + block = Block( + config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, + prenorm=prenorm, resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, + fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None + ) + else: + assert prenorm + block = ParallelBlock( + config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls, + resid_dropout1=resid_dropout1, resid_dropout2=config.resid_pdrop, + tied_norm=getattr(config, 'parallel_block_tied_norm', False), + fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False), + residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None + ) + block.layer_idx = layer_idx + return block + + +class GPTPreTrainedModel(nn.Module): + """ An abstract class to handle weights 
initialization and + a simple interface for dowloading and loading pretrained models. + """ + def __init__(self, config, *inputs, **kwargs): + super().__init__() + if not isinstance(config, GPT2Config): + raise ValueError( + "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. " + "To create a model from a Google pretrained model use " + "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( + self.__class__.__name__, self.__class__.__name__ + )) + self.config = config + + @classmethod + def from_pretrained(cls, model_name, config, *args, strict=True, device=None, dtype=None, + world_size=1, rank=0, **kwargs): + """ + Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict. + Download and cache the pre-trained model file if needed. + """ + # Instantiate model. + model = cls(config, *args, device=device, dtype=dtype, **kwargs) + # Load state_dict in cpu because we already initialized the model in GPU, and we don't + # want extra stuff taking up more GPU memory + state_dict = state_dict_from_pretrained( + model_name, device='cpu', dtype=dtype + ) + if model_name.startswith('gpt2'): + state_dict = remap_state_dict_hf_gpt2(state_dict, config) + elif model_name.startswith('facebook/opt'): + state_dict = remap_state_dict_hf_opt(state_dict, config) + elif model_name.startswith('EleutherAI/gpt-j-'): + state_dict = remap_state_dict_hf_gptj(state_dict, config) + elif model_name.startswith('EleutherAI/gpt-neox-'): + state_dict = remap_state_dict_hf_gpt_neox(state_dict, config) + elif model_name.startswith('tiiuae/falcon-'): + state_dict = remap_state_dict_hf_falcon(state_dict, config) + else: + raise NotImplementedError(f'Model {model_name} not supported') + if world_size > 1: + state_dict = shard_state_dict_tp(state_dict, config, world_size, rank) + load_return = model.load_state_dict(state_dict, strict=strict) + logger.info(load_return) + return model + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + + +class GPTModel(GPTPreTrainedModel): + + def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): + super().__init__(config) + factory_kwargs = {'device': device, 'dtype': dtype} + self.process_group = process_group + self.sequence_parallel = getattr(config, 'sequence_parallel', True) + assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'gelu_approx', + 'relu', 'sqrelu', 'glu', 'swiglu', 'geglu'] + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) + * pad_vocab_size_multiple) + # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable + self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False) + # These 2 options are for OPT-350m + self.prenorm = getattr(config, 'prenorm', True) + use_rms_norm = getattr(config, 'rms_norm', False) + word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) + # For GPT-J, GPT-NeoX + self.parallel_block = getattr(config, 'parallel_block', False) + + if process_group is None: + self.embeddings = GPT2Embeddings( + config.hidden_size, vocab_size, config.max_position_embeddings, + word_embed_proj_dim=word_embed_proj_dim, **factory_kwargs + ) + else: + self.embeddings = ParallelGPT2Embeddings( + config.hidden_size, vocab_size, config.max_position_embeddings, + process_group=process_group, sequence_parallel=self.sequence_parallel, + **factory_kwargs + ) + + # We change the order of dropout, residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and + # the main branch (output of MLP). The model definition is unchanged, but the mapping of the + # nn.Dropout probabilities are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. + self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group, + **factory_kwargs) + for i in range(config.num_hidden_layers)]) + + self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False) + if self.fused_dropout_add_ln: + if ((not self.parallel_block and dropout_add_layer_norm is None) + or (self.parallel_block and dropout_add_layer_norm_parallel_residual is None)): + raise ImportError('dropout_layer_norm is not installed') + if self.prenorm: + self.drop_f = nn.Dropout(config.resid_pdrop) + norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm + self.ln_f = norm_cls(config.hidden_size, eps=config.layer_norm_epsilon, + **factory_kwargs) + if process_group is not None: + for p in self.ln_f.parameters(): + # Mark the norm parameters as "shared_params" so that we sync their values at init. + p._shared_params = True + # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. 
+ if self.sequence_parallel: + p._sequence_parallel = True + + self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range)) + self.tie_weights() + + def tie_weights(self): + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return {i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers)} + + def forward(self, input_ids, position_ids=None, inference_params=None): + # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen + # dimensions so that we can split on it easily, in case of small batch size. + # Only the attention layers need to know the seqlen. + embedding_kwargs = ({'combine_batch_seqlen_dim': True} + if self.process_group is not None and self.sequence_parallel else {}) + hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) + if self.parallel_block: + hidden_states2 = None + residual = None + mixer_kwargs = ({'seqlen': input_ids.shape[1]} + if self.process_group is not None and self.sequence_parallel else {}) + if inference_params is not None: + mixer_kwargs['inference_params'] = inference_params + for layer in self.layers: + if self.prenorm: + if not self.parallel_block: + hidden_states, residual = layer(hidden_states, residual, + mixer_kwargs=mixer_kwargs) + else: + hidden_states, hidden_states2, residual = layer( + hidden_states, hidden_states2, residual, mixer_kwargs=mixer_kwargs + ) + else: + hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs) + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_f(hidden_states) + if not self.parallel_block: + residual = (dropped + residual) if residual is not None else dropped + else: + dropped2 = self.drop_f(hidden_states2) + residual = ((residual + dropped + dropped2) + if residual is not None else dropped + dropped2) + hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + if not self.parallel_block: + fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm) + else dropout_add_layer_norm) + hidden_states = fused_add_norm_fn( + hidden_states, residual, self.ln_f.weight, self.ln_f.bias, + self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, + residual_in_fp32=self.residual_in_fp32 + ) + else: + fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual + if isinstance(self.ln_f, RMSNorm) + else dropout_add_layer_norm_parallel_residual) + hidden_states, _ = fused_add_norm_fn( + hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias, + None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps, + prenorm=False, residual_in_fp32=self.residual_in_fp32 + ) + return hidden_states + + +class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin): + + def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__(config) + self.process_group = process_group + self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs) + self.tie_word_embeddings = getattr(config, 'tie_word_embeddings', True) + lm_head_bias = getattr(config, 'lm_head_bias', False) + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = 
(math.ceil(config.vocab_size / pad_vocab_size_multiple) + * pad_vocab_size_multiple) + # This option is for OPT-350m + word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None) + embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim + if word_embed_proj_dim is not None: + self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs) + else: + self.project_out = None + if process_group is None: + self.lm_head = nn.Linear(embed_dim, vocab_size, bias=lm_head_bias, **factory_kwargs) + else: + if ColumnParallelLinear is None: + raise ImportError('fused_dense_lib is not installed') + self.lm_head = ColumnParallelLinear( + embed_dim, vocab_size, process_group, bias=lm_head_bias, + sequence_parallel=getattr(config, 'sequence_parallel', True), **factory_kwargs + ) + # Initialize weights and apply final processing + self.apply(partial(_init_weights, n_layer=config.num_hidden_layers, + initializer_range=config.initializer_range)) + self.tie_weights() + + def tie_weights(self): + if self.tie_word_embeddings: + self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.transformer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, + **kwargs) + + def forward(self, input_ids, position_ids=None, inference_params=None, last_token_only=False): + """ + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + last_token_only: whether to return the logit for the last token only, + of shape (batch_size, vocab_size) + """ + hidden_states = self.transformer(input_ids, position_ids=position_ids, + inference_params=inference_params) + if last_token_only: + hidden_states = hidden_states[:, -1] + if self.project_out is not None: + hidden_states = self.project_out(hidden_states) + lm_logits = self.lm_head(hidden_states) + # During inference, we want the full logit for sampling + if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: + lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) + lm_logits = rearrange(lm_logits, '(n b) ... d -> b ... 
(n d)', b=hidden_states.shape[0]) + CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) + return CausalLMOutput(logits=lm_logits) + + def load_state_dict(self, state_dict, strict=True): + # Remapping from our checkpoints that used a different ordering of layers in the block + # Previous: Attn / MLP -> Dropout -> Add -> LN + # Current: Dropout -> Add -> LN -> Attn / MLP + if 'transformer.ln_0.weight' in state_dict: + n_layers = len(self.transformer.layers) + ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight') + ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias') + state_dict['transformer.ln_f.weight'] = ln_weight + state_dict['transformer.ln_f.bias'] = ln_bias + for l in reversed(range(n_layers)): + ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight') + ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias') + state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight + state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias + if l > 0: + ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight') + ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias') + state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight + state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias + ln_weight = state_dict.pop('transformer.ln_0.weight') + ln_bias = state_dict.pop('transformer.ln_0.bias') + state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight + state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias + return super().load_state_dict(state_dict, strict=strict) + + +def shard_state_dict_tp(state_dict, config, world_size, rank): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + def shard_first_dim(state_dict, key): + if key in state_dict: + x = state_dict[key] + dim = x.shape[0] // world_size + state_dict[key] = x[rank * dim:(rank + 1) * dim] + + def shard_last_dim(state_dict, key): + if key in state_dict: + x = state_dict[key] + dim = x.shape[-1] // world_size + state_dict[key] = x[..., rank * dim:(rank + 1) * dim] + + def shard_qkv_headdim(state_dict, key): + if key in state_dict: + n_head = config.n_head + n_head_kv = getattr(config, 'n_head_kv', n_head) + assert n_head % world_size == 0 and n_head_kv % world_size == 0 + if n_head_kv == n_head: + x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3) + dim = x.shape[1] // world_size + state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim], + 'three d ... -> (three d) ...') + else: + n_head_per_rank = n_head // world_size + n_head_kv_per_rank = n_head_kv // world_size + x = rearrange(state_dict[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...', + nheadqkv=n_head + 2 * n_head_kv) + state_dict[key] = rearrange(torch.cat([ + x[rank * n_head_per_rank:(rank + 1) * n_head_per_rank], + x[n_head + rank * n_head_kv_per_rank:n_head + (rank + 1) * n_head_kv_per_rank], + x[n_head + n_head_kv + rank * n_head_kv_per_rank:n_head + n_head_kv + (rank + 1) * n_head_kv_per_rank], + ], dim=0), "nheadqkv headdim ... 
-> (nheadqkv headdim) ...") + + shard_first_dim(state_dict, 'transformer.embeddings.word_embeddings.weight') + if 'lm_head.weight' in state_dict: + shard_first_dim(state_dict, 'lm_head.weight') + if 'transformer.embeddings.position_embeddings.weight' in state_dict: + shard_last_dim(state_dict, 'transformer.embeddings.position_embeddings.weight') + for i in range(config.num_hidden_layers): + shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') + shard_qkv_headdim(state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') + shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight') + if rank != 0: + state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None) + shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight') + shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias') + shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight') + if rank != 0: + state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None) + return state_dict + + +def combine_state_dicts_tp(state_dicts, config): + """Convert the state_dict of a standard GPT model to the state_dict of a GPT model + with tensor parallel. + """ + world_size = len(state_dicts) + keys = state_dicts[0].keys() + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + assert vocab_size % world_size == 0 + assert config.hidden_size % world_size == 0 + inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size + assert inner_dim % world_size == 0 + + # Sometimes the word embeddings are sharded on the 0th dim, sometimes on the 1st dim. + # vocab_size // world_size coordinates are nonzero. + def combine_word_embeddings(state_dicts, state_dict, key): + dim = 0 if state_dicts[0][key].shape[0] == vocab_size // world_size else 1 + state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) + + def combine_dim(state_dicts, state_dict, key, dim=-1): + if key in state_dict: + state_dict[key] = torch.cat([s[key] for s in state_dicts], dim=dim) + + def combine_qkv_headdim(state_dicts, state_dict, key): + n_head = config.n_head + n_head_kv = getattr(config, 'n_head_kv', n_head) + assert n_head % world_size == 0 and n_head_kv % world_size == 0 + n_head_per_rank = n_head // world_size + n_head_kv_per_rank = n_head_kv // world_size + if key in state_dict: + if n_head_kv == n_head: + xs = [rearrange(s[key], '(three d) ... -> three d ...', three=3) for s in state_dicts] + state_dict[key] = rearrange(torch.cat(xs, dim=1), 'three d ... -> (three d) ...') + else: + xs = [rearrange(s[key], '(nheadqkv headdim) ... -> nheadqkv headdim ...', + nheadqkv=n_head + 2 * n_head_kv) for s in state_dicts] + state_dict[key] = rearrange(torch.cat([ + torch.cat([x[:n_head_per_rank] for x in xs], dim=0), + torch.cat([x[n_head_per_rank:n_head_per_rank + n_head_kv_per_rank] for x in xs], dim=0), + torch.cat([x[-n_head_kv_per_rank:] for x in xs], dim=0), + ], dim=0), "nheadqkv headdim ... -> (nheadqkv headdim) ...") + + def combine_gated_mlp(state_dicts, state_dict, key): + if key in state_dict: + xs = [rearrange(s[key], '(two d) ... -> two d ...', two=2) for s in state_dicts] + state_dict[key] = rearrange(torch.cat(xs, dim=1), 'two d ... 
-> (two d) ...') + + state_dict = state_dicts[0].copy() # don't modify state_dict[0] inplace + combine_word_embeddings(state_dicts, state_dict, 'transformer.embeddings.word_embeddings.weight') + if 'lm_head.weight' in state_dict: + combine_word_embeddings(state_dicts, state_dict, 'lm_head.weight') + if 'transformer.embeddings.position_embeddings.weight' in state_dict: + combine_dim(state_dicts, state_dict, 'transformer.embeddings.position_embeddings.weight', -1) + mlp_combine_fn = (combine_gated_mlp if config.activation_function in ['glu', 'swiglu', 'geglu'] + else partial(combine_dim, dim=0)) + for i in range(config.num_hidden_layers): + combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.weight') + combine_qkv_headdim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.Wqkv.bias') + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mixer.out_proj.weight', -1) + mlp_combine_fn(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.weight') + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc1.bias', 0) + combine_dim(state_dicts, state_dict, f'transformer.layers.{i}.mlp.fc2.weight', -1) + return state_dict + + +def remap_state_dict_hf_gpt2(state_dict, config): + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('wte.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key) + key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for d in range(config.num_hidden_layers): + W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight') + state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t() + W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight') + state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t() + def key_mapping_mlp(key): + key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key) + key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for d in range(config.num_hidden_layers): + state_dict.pop(f'h.{d}.attn.bias') # We don't store this bias + Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight') + state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t() + Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight') + state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t() + def key_mapping_attn(key): + key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key) + key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for 
k, v in state_dict.items()) + + return state_dict + + +def remap_state_dict_megatron(state_dict, config): + def key_mapping_transformer(key): + key = re.sub(r'^language_model.encoder.', 'transformer.', key) + key = re.sub(r'^language_model.', 'transformer.', key) + return key + state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items()) + # Word embedding and position embedding + def key_mapping_pos_emb(key): + return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key) + state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embedding.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.final_layernorm.(weight|bias)', r'transformer.ln_f.\1', key) + key = re.sub(r'^transformer.layers.(\d+).input_layernorm.(weight|bias)', + r'transformer.layers.\1.norm1.\2', key) + key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.(weight|bias)', + r'transformer.layers.\1.norm2.\2', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.(weight|bias)', + r'transformer.layers.\1.mlp.fc1.\2', key) + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.(weight|bias)', + r'transformer.layers.\1.mlp.fc2.\2', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + def key_mapping_attn(key): + key = re.sub(r'^transformer.layers.(\d+).self_attention.rotary_emb.inv_freq', + r'transformer.layers.\1.mixer.rotary_emb.inv_freq', key) + key = re.sub(r'^transformer.layers.(\d+).self_attention.query_key_value.(weight|bias)', + r'transformer.layers.\1.mixer.Wqkv.\2', key) + key = re.sub(r'^transformer.layers.(\d+).self_attention.dense.(weight|bias)', + r'transformer.layers.\1.mixer.out_proj.\2', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + # Megatron stores Wqkv as ((nheads 3 headdim), hidden_dim) + # while we store Wqkv as ((3 nheads headdim), hidden_dim) + headdim = config.hidden_size // config.num_attention_heads + for d in range(config.num_hidden_layers): + Wqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.weight') + state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = rearrange( + Wqkv, '(nheads three headdim) ... 
-> (three nheads headdim) ...', + three=3, headdim=headdim + ) + bqkv = state_dict.pop(f'transformer.layers.{d}.mixer.Wqkv.bias') + state_dict[f'transformer.layers.{d}.mixer.Wqkv.bias'] = rearrange( + bqkv, '(nheads three headdim) -> (three nheads headdim)', + three=3, headdim=headdim + ) + + return state_dict diff --git a/pkgs/xformers/_flash_attn/models/gpt_neox.py b/pkgs/xformers/_flash_attn/models/gpt_neox.py new file mode 100644 index 0000000..9f23387 --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/gpt_neox.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from einops import rearrange + +from transformers import GPT2Config, GPTNeoXConfig + + +def remap_state_dict_hf_gpt_neox(state_dict, config): + def key_mapping_layers(key): + return re.sub(r'^gpt_neox.', 'transformer.', key) + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r'^transformer.embed_in.', 'transformer.embeddings.word_embeddings.', key) + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, 'tie_word_embeddings'): + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + else: + output_embeddings = state_dict.pop('embed_out.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + state_dict['lm_head.weight'] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key) + key = re.sub(r'^transformer.layers.(\d+).input_layernorm.', r'transformer.layers.\1.norm1.', key) + key = re.sub(r'^transformer.layers.(\d+).post_attention_layernorm.', r'transformer.layers.\1.norm2.', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_h_to_4h.', r'transformer.layers.\1.mlp.fc1.', key) + key = re.sub(r'^transformer.layers.(\d+).mlp.dense_4h_to_h.', r'transformer.layers.\1.mlp.fc2.', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + # We don't store these biases + state_dict.pop(f'transformer.layers.{l}.attention.bias') + state_dict.pop(f'transformer.layers.{l}.attention.masked_bias') + # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim) + # while we store Wqkv as ((3 nheads headdim), hidden_dim) + headdim = config.hidden_size // config.num_attention_heads + Wqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.weight') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = rearrange( + Wqkv, '(nheads three headdim) ... 
-> (three nheads headdim) ...', + three=3, headdim=headdim + ) + bqkv = state_dict.pop(f'transformer.layers.{l}.attention.query_key_value.bias') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = rearrange( + bqkv, '(nheads three headdim) -> (three nheads headdim)', + three=3, headdim=headdim + ) + def key_mapping_attn(key): + key = re.sub(r'^transformer.layers.(\d+).attention.dense.', + r'transformer.layers.\1.mixer.out_proj.', key) + key = re.sub(r'^transformer.layers.(\d+).attention.rotary_emb.', + r'transformer.layers.\1.mixer.rotary_emb.', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def gpt_neox_config_to_gpt2_config(gpt_neox_config: GPTNeoXConfig) -> GPT2Config: + assert gpt_neox_config.rotary_emb_base == 10000 + return GPT2Config( + vocab_size=gpt_neox_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=gpt_neox_config.hidden_size, + n_layer=gpt_neox_config.num_hidden_layers, + n_head=gpt_neox_config.num_attention_heads, + n_inner=gpt_neox_config.intermediate_size, + activation_function=gpt_neox_config.hidden_act, + resid_pdrop=0.0, # No dropout + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=gpt_neox_config.layer_norm_eps, + initializer_range=gpt_neox_config.initializer_range, + bos_token_id=gpt_neox_config.bos_token_id, + eos_token_id=gpt_neox_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=True, + parallel_block=gpt_neox_config.use_parallel_residual, + parallel_block_tied_norm=False, + rotary_emb_fraction=gpt_neox_config.rotary_pct, + tie_word_embeddings=gpt_neox_config.tie_word_embeddings, + ) diff --git a/pkgs/xformers/_flash_attn/models/gptj.py b/pkgs/xformers/_flash_attn/models/gptj.py new file mode 100644 index 0000000..4ab3063 --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/gptj.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from transformers import GPT2Config, GPTJConfig + + +def remap_state_dict_hf_gptj(state_dict, config): + def key_mapping_layers(key): + return re.sub(r'^transformer.h.', 'transformer.layers.', key) + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r'^transformer.wte.', 'transformer.embeddings.word_embeddings.', key) + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, 'tie_word_embeddings'): + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + else: + output_embeddings = state_dict.pop('lm_head.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
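As a hedged aside (not part of the original diff), the vocab-padding idiom used here and in the other remap_state_dict_* functions works as follows; the sizes are illustrative:

import math
import torch
import torch.nn.functional as F

w = torch.randn(50257, 768)                                         # unpadded (vocab, hidden) weight
pad_multiple = 8                                                    # pad_vocab_size_multiple
vocab_size = math.ceil(w.shape[0] / pad_multiple) * pad_multiple    # -> 50264
w_padded = F.pad(w, (0, 0, 0, vocab_size - w.shape[0]))             # appends zero rows on the vocab dim only
assert w_padded.shape == (50264, 768)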
+ state_dict['lm_head.weight'] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + output_embeddings_bias = state_dict.pop('lm_head.bias') + state_dict['lm_head.bias'] = F.pad( + output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + return re.sub(r'^transformer.layers.(\d+).ln_1.', r'transformer.layers.\1.norm1.', key) + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + key = re.sub(r'^transformer.layers.(\d+).mlp.fc_in.', r'transformer.layers.\1.mlp.fc1.', key) + key = re.sub(r'^transformer.layers.(\d+).mlp.fc_out.', r'transformer.layers.\1.mlp.fc2.', key) + return key + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f'transformer.layers.{l}.attn.q_proj.weight') + Wk = state_dict.pop(f'transformer.layers.{l}.attn.k_proj.weight') + Wv = state_dict.pop(f'transformer.layers.{l}.attn.v_proj.weight') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) + # We don't store these biases + state_dict.pop(f'transformer.layers.{l}.attn.bias') + state_dict.pop(f'transformer.layers.{l}.attn.masked_bias') + def key_mapping_attn(key): + return re.sub(r'^transformer.layers.(\d+).attn.out_proj.', + r'transformer.layers.\1.mixer.out_proj.', key) + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def gptj_config_to_gpt2_config(gptj_config: GPTJConfig) -> GPT2Config: + headdim = gptj_config.n_embd // gptj_config.n_head + return GPT2Config( + vocab_size=gptj_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=gptj_config.n_embd, + n_layer=gptj_config.n_layer, + n_head=gptj_config.n_head, + n_inner=gptj_config.n_inner, + activation_function=gptj_config.activation_function, + resid_pdrop=gptj_config.resid_pdrop, + embd_pdrop=gptj_config.embd_pdrop, + attn_pdrop=gptj_config.attn_pdrop, + layer_norm_epsilon=gptj_config.layer_norm_epsilon, + initializer_range=gptj_config.initializer_range, + bos_token_id=gptj_config.bos_token_id, + eos_token_id=gptj_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=True, + parallel_block=True, + parallel_block_tied_norm=True, + rotary_emb_fraction=gptj_config.rotary_dim / headdim, + rotary_emb_interleaved=True, + tie_word_embeddings=False, + qkv_proj_bias=False, + out_proj_bias=False, + lm_head_bias=True, + ) diff --git a/pkgs/xformers/_flash_attn/models/llama.py b/pkgs/xformers/_flash_attn/models/llama.py new file mode 100644 index 0000000..0a1702a --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/llama.py @@ -0,0 +1,124 @@ +# Copyright (c) 2023, Tri Dao. 
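Before the Llama remapping below, a hedged usage sketch for the GPT-J helpers defined above; the model name is only an example and the snippet assumes the two functions are importable from wherever this file ends up installed:

from transformers import GPTJConfig, GPTJForCausalLM

# with gptj_config_to_gpt2_config and remap_state_dict_hf_gptj (defined above) in scope:
model_name = "EleutherAI/gpt-j-6B"                       # illustrative checkpoint
config = gptj_config_to_gpt2_config(GPTJConfig.from_pretrained(model_name))
hf_state_dict = GPTJForCausalLM.from_pretrained(model_name).state_dict()
state_dict = remap_state_dict_hf_gptj(hf_state_dict, config)
# keys now follow the transformer.layers.{i}.mixer / .mlp naming used by this package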
+ +import math +import json +import re +from pathlib import Path + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from transformers import GPT2Config, LlamaConfig + + +def remap_state_dict_meta_llama(state_dict, config): + def key_mapping_layers(key): + return f'transformer.{key}' if not key.startswith('output.') else key + state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items()) + # Word embedding + def key_mapping_emb(key): + return re.sub(r'^transformer.tok_embeddings.', 'transformer.embeddings.word_embeddings.', key) + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. + pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + if getattr(config, 'tie_word_embeddings'): + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + else: + output_embeddings = state_dict.pop('output.weight') + # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings + # differently. + vocab_size = (math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple) + * pad_vocab_size_multiple) + # It's possible that vocab_size is padded to be a multiple of 8, for example. + state_dict['lm_head.weight'] = F.pad( + output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0]) + ) + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.norm.', r'transformer.ln_f.', key) + key = re.sub(r'^transformer.layers.(\d+).attention_norm.', r'transformer.layers.\1.norm1.', key) + key = re.sub(r'^transformer.layers.(\d+).ffn_norm.', r'transformer.layers.\1.norm2.', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + for l in range(config.n_layer): + w1 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w1.weight') + w3 = state_dict.pop(f'transformer.layers.{l}.feed_forward.w3.weight') + # Our ordering is different + state_dict[f'transformer.layers.{l}.mlp.fc1.weight'] = torch.cat([w3, w1], dim=0) + def key_mapping_mlp(key): + return re.sub(r'^transformer.layers.(\d+).feed_forward.w2.', + r'transformer.layers.\1.mlp.fc2.', key) + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f'transformer.layers.{l}.attention.wq.weight') + Wk = state_dict.pop(f'transformer.layers.{l}.attention.wk.weight') + Wv = state_dict.pop(f'transformer.layers.{l}.attention.wv.weight') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) + # We don't store these + state_dict.pop(f'transformer.layers.{l}.attention.inner_attention.rope.freqs', None) + def key_mapping_attn(key): + return re.sub(r'^transformer.layers.(\d+).attention.wo.', + r'transformer.layers.\1.mixer.out_proj.', key) + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def config_from_checkpoint(checkpoint_path: str, model_name: str) -> LlamaConfig: + """Load a LlamaConfig from a checkpoint path.""" + with 
open(Path(checkpoint_path) / model_name / 'params.json') as f: + params = json.load(f) + config = LlamaConfig(hidden_size=params['dim'], intermediate_size=None, + num_attention_heads=params['n_heads'], + num_hidden_layers=params['n_layers'], + rms_norm_eps=params['norm_eps']) + return config + + +def state_dicts_from_checkpoint(checkpoint_path: str, model_name: str) -> dict: + # Need to sort, otherwise we mess up the ordering and the weights are wrong + return [torch.load(path, map_location='cpu') + for path in sorted((Path(checkpoint_path) / model_name).glob('consolidated.*.pth'))] + + +def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config: + return GPT2Config( + vocab_size=llama_config.vocab_size, + n_positions=0, # No absolute position embedding + n_embd=llama_config.hidden_size, + n_layer=llama_config.num_hidden_layers, + n_head=llama_config.num_attention_heads, + n_inner=llama_config.intermediate_size, + activation_function='swiglu', # Hardcode since HF calls it 'silu' + # Llama doesn't have dropout, idk if it's because they only release the inference code + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0, + layer_norm_epsilon=llama_config.rms_norm_eps, + initializer_range=llama_config.initializer_range, + bos_token_id=llama_config.bos_token_id, + eos_token_id=llama_config.eos_token_id, + # These are new arguments not in the original GPT2Config + pad_token_id=llama_config.pad_token_id, # Idk if this does anything + rms_norm=True, + rotary_emb_fraction=1.0, + rotary_emb_interleaved=True, + tie_word_embeddings=False, + qkv_proj_bias=False, + out_proj_bias=False, + mlp_fc1_bias=False, + mlp_fc2_bias=False, + ) diff --git a/pkgs/xformers/_flash_attn/models/opt.py b/pkgs/xformers/_flash_attn/models/opt.py new file mode 100644 index 0000000..63844ba --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/opt.py @@ -0,0 +1,102 @@ +# Copyright (c) 2023, Tri Dao. + +import math +import re + +from collections import OrderedDict + +import torch +import torch.nn.functional as F + +from transformers import GPT2Config, OPTConfig + + +def remap_state_dict_hf_opt(state_dict, config): + def key_mapping_model(key): + key = re.sub(r'^model.decoder.', 'transformer.', key) + # The OPT-350m model uses '^decoder' instead of '^model.decoder' + key = re.sub(r'^decoder.', 'transformer.', key) + return key + state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items()) + # Word embedding and position embedding + def key_mapping_emb(key): + key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key) + # The OPT-350m model uses has project_in and project_out + key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key) + key = re.sub(r'^transformer.project_out.', 'project_out.', key) + key = re.sub(r'^transformer.embed_positions.', + 'transformer.embeddings.position_embeddings.', key) + return key + state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items()) + # OPT uses the first 2 indices of pos_emb for padding tokens + pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight') + state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:] + word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight') + # It's possible that vocab_size is padded to be a multiple of 8, for example. 
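A hedged note on the pos_embeddings[2:] slice a few lines above: HF's OPT stores its learned position table with a fixed offset of 2, so the table has max_position_embeddings + 2 rows, and dropping the first two rows makes row i correspond to position i. The sizes below are those of OPT-125m and are an assumption:

import torch

pos_table = torch.randn(2048 + 2, 768)   # as stored in the HF checkpoint
pos_emb = pos_table[2:]                  # row i now holds position i
assert pos_emb.shape == (2048, 768)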
+ pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1) + vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad( + word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0]) + ) + state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight'] + + # LayerNorm + def key_mapping_ln(key): + key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key) + # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm' + key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key) + key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.', + r'transformer.layers.\1.norm1.', key) + key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.', + r'transformer.layers.\1.norm2.', key) + return key + state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items()) + + # MLP + def key_mapping_mlp(key): + return re.sub(r'^transformer.layers.(\d+).fc(1|2).', + r'transformer.layers.\1.mlp.fc\2.', key) + state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items()) + + # Attention + for l in range(config.n_layer): + Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight') + Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight') + Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight') + bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias') + bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias') + bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias') + state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0) + state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0) + def key_mapping_attn(key): + return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.', + r'transformer.layers.\1.mixer.out_proj.', key) + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + + return state_dict + + +def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config: + assert opt_config.layerdrop == 0.0 + assert opt_config.layer_norm_elementwise_affine + word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size + else opt_config.word_embed_proj_dim) + return GPT2Config( + vocab_size=opt_config.vocab_size, + n_positions=opt_config.max_position_embeddings, + n_embd=opt_config.hidden_size, + n_layer=opt_config.num_hidden_layers, + n_head=opt_config.num_attention_heads, + n_inner=opt_config.ffn_dim, + activation_function=opt_config.activation_function, + resid_pdrop=opt_config.dropout, + # HF's implementation of OPT doesn't seem to have embedding dropout + embd_pdrop=opt_config.dropout, + attn_pdrop=opt_config.attention_dropout, + initializer_range=opt_config.init_std, + bos_token_id=opt_config.bos_token_id, + eos_token_id=opt_config.eos_token_id, + # These are new arguments not in the original GPT2Config + prenorm=opt_config.do_layer_norm_before, + word_embed_proj_dim=word_embed_proj_dim + ) diff --git a/pkgs/xformers/_flash_attn/models/vit.py b/pkgs/xformers/_flash_attn/models/vit.py new file mode 100644 index 0000000..312f5e4 --- /dev/null +++ b/pkgs/xformers/_flash_attn/models/vit.py @@ -0,0 +1,304 @@ +# Copyright (c) 2022, Tri Dao. 
+# Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +import math +import re +from functools import partial +from copy import deepcopy + +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ + +from torchvision.ops import StochasticDepth + +from einops import rearrange + +from timm.models.helpers import named_apply +from flash_attn.layers.patch_embed import PatchEmbed + +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import Mlp, FusedMLP +from flash_attn.modules.block import Block + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + dropout_add_layer_norm = None + + +def create_mixer_cls(num_heads, qkv_bias, attn_drop, use_flash_attn, fused_bias_fc, + cross_attn=False): + mixer_cls = partial(MHA, num_heads=num_heads, cross_attn=cross_attn, bias=qkv_bias, + dropout=attn_drop, fused_bias_fc=fused_bias_fc, + use_flash_attn=use_flash_attn) + return mixer_cls + + +def create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp): + inner_dim = int(embed_dim * mlp_ratio) + if not fused_mlp: + mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=act_layer()) + else: + mlp_cls = partial(FusedMLP, hidden_features=inner_dim) + return mlp_cls + + +def create_block(embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, + drop_path1, drop_path2, norm_layer, act_layer, use_flash_attn, fused_bias_fc, + fused_mlp, fused_dropout_add_ln, layer_idx=None, n_layer=None, + last_layer_subset=False): + mixer_cls = create_mixer_cls(num_heads, qkv_bias, attn_drop_rate, use_flash_attn, fused_bias_fc, + cross_attn=(last_layer_subset and layer_idx == n_layer - 1)) + mlp_cls = create_mlp_cls(embed_dim, mlp_ratio, act_layer, fused_mlp) + # TD [2022-10-15]: Force residual in fp32 in case of DeepSpeed + block = Block(embed_dim, mixer_cls, mlp_cls, norm_cls=norm_layer, + prenorm=True, resid_dropout1=drop_rate, resid_dropout2=drop_rate, + drop_path1=drop_path1, drop_path2=drop_path2, + fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=True) + return block + + +class VisionTransformer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + global_pool='token', + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4., + qkv_bias=True, + init_values=None, + class_token=True, + no_embed_class=False, + pre_norm=False, + fc_norm=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + weight_init='', + embed_layer=PatchEmbed, + norm_layer=None, + act_layer=None, + use_flash_attn=False, + fused_bias_fc=False, + fused_mlp=False, + fused_dropout_add_ln=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token + 
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool == 'token', 'Only support pooling with CLS token' + assert class_token + assert init_values is None, 'LayerScale is not supported yet' + assert weight_init == '' + assert fc_norm is None + # pre_norm seems redundant, as there's a LayerNorm right at the start of each block, idk + assert not pre_norm + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class + + patch_embed_extra_kwargs = ({'fused_bias_fc': fused_bias_fc} if embed_layer is PatchEmbed + else {}) + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP) + **patch_embed_extra_kwargs + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + # We change the order of dropout, residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and + # the main branch (output of MLP). The model definition is unchanged, but the mapping of the + # nn.Dropout probabilities are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. 
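For clarity, a minimal sketch (not part of the diff) of the reordered, unfused update the comment above describes; `sublayer` stands for the Attn or MLP branch and the names are illustrative:

def prenorm_step(hidden_states, residual, dropout, norm, sublayer):
    # Dropout -> Add -> LN -> Attn/MLP, instead of LN -> Attn/MLP -> Dropout -> Add
    dropped = dropout(hidden_states)
    residual = dropped if residual is None else dropped + residual
    hidden_states = sublayer(norm(residual))
    return hidden_states, residual   # both streams are carried into the next block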
+ self.blocks = nn.ModuleList([create_block( + embed_dim, num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, + drop_path1=dpr[i-1] if i > 0 else 0., drop_path2=dpr[i], + norm_layer=norm_layer, act_layer=act_layer, use_flash_attn=use_flash_attn, + fused_bias_fc=fused_bias_fc, fused_mlp=fused_mlp, + fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, n_layer=depth, + last_layer_subset=(global_pool == 'token') + ) for i in range(depth)]) + + self.dropout = nn.Dropout(p=drop_rate) + self.drop_path = StochasticDepth(p=dpr[-1], mode='row') + self.norm = norm_layer(embed_dim) + + self.fused_dropout_add_ln = fused_dropout_add_ln + if self.fused_dropout_add_ln and dropout_add_layer_norm is None: + raise ImportError('dropout_add_layer_norm is not installed') + + # Classifier Head + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode == '' + trunc_normal_(self.pos_embed, std=.02) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + init_weights_vit_timm(m) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return x + + def forward_features(self, x, all_tokens=True): + """ + If all_tokens==False and self.global_pool == 'token', we only return the features for the + cls token. + """ + x = self.patch_embed(x) + hidden_states = self._pos_embed(x) + residual = None + if self.global_pool != 'token' or all_tokens: + # if True: + for block in self.blocks: + hidden_states, residual = block(hidden_states, residual) + else: + for block in self.blocks[:-1]: + hidden_states, residual = block(hidden_states, residual) + # For the last layer, we only want the 1st token of the output. So we do cross-attention + # where the query is the 1st token and the key/value is the whole sequence. 
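A hedged shape illustration (not part of the diff) of what mixer_subset=slice(0, 1) does in that last block; the sizes are illustrative:

import torch

x = torch.randn(2, 197, 768)   # [batch, seqlen, dim] entering the last block
q_in = x[:, slice(0, 1)]       # query is taken only from the CLS position -> [2, 1, 768]
kv_in = x                      # key/value still attend over every token   -> [2, 197, 768]
# so the block's output, and the classifier head after it, cover only the CLS token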
+ hidden_states, residual = self.blocks[-1](hidden_states, residual, + mixer_subset=slice(0, 1)) + if not self.fused_dropout_add_ln: + residual = self.drop_path(self.dropout(hidden_states)) + residual + hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype)) + else: + if self.drop_path.p == 0 or not self.training: + rowscale = None + else: + rowscale = self.drop_path(torch.ones( + hidden_states.shape[:-1], device=hidden_states.device, + dtype=hidden_states.dtype) + ) + # Set prenorm=False here since we don't need to the residual + hidden_states = dropout_add_layer_norm( + hidden_states, residual, self.norm.weight, self.norm.bias, + self.dropout.p if self.training else 0.0, self.norm.eps, rowscale=rowscale, + prenorm=False, residual_in_fp32=True + ) + return hidden_states + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x, all_tokens=False) + x = self.forward_head(x) + return x + + def load_state_dict(self, state_dict, strict=True): + patch_embed_weight = state_dict['patch_embed.proj.weight'] + if patch_embed_weight.dim() == 4: + # convert from Conv2d to Linear + state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight, + 'o c h w -> o (c h w)') + def key_mapping_attn(key): + key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key) + key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key) + return key + state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items()) + n_layer = len(self.blocks) + # Convert from Wqkv to Wq and Wkv for cross attention (last layer) + if (self.blocks[-1].mixer.cross_attn + and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict): + Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight') + bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias') + state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim] + state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:] + state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim] + state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:] + return super().load_state_dict(state_dict, strict=strict) + + +def init_weights_vit_timm(module: nn.Module, name: str = ''): + """ ViT weight initialization, original timm impl (for reproducibility) """ + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
+ """ + assert not pretrained + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = VisionTransformer(**model_kwargs) + return model diff --git a/pkgs/xformers/_flash_attn/modules/__init__.py b/pkgs/xformers/_flash_attn/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers/_flash_attn/modules/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3062d1a Binary files /dev/null and b/pkgs/xformers/_flash_attn/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/modules/__pycache__/block.cpython-310.pyc b/pkgs/xformers/_flash_attn/modules/__pycache__/block.cpython-310.pyc new file mode 100644 index 0000000..9251f00 Binary files /dev/null and b/pkgs/xformers/_flash_attn/modules/__pycache__/block.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/modules/__pycache__/embedding.cpython-310.pyc b/pkgs/xformers/_flash_attn/modules/__pycache__/embedding.cpython-310.pyc new file mode 100644 index 0000000..b3cd556 Binary files /dev/null and b/pkgs/xformers/_flash_attn/modules/__pycache__/embedding.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/modules/__pycache__/mha.cpython-310.pyc b/pkgs/xformers/_flash_attn/modules/__pycache__/mha.cpython-310.pyc new file mode 100644 index 0000000..9408f28 Binary files /dev/null and b/pkgs/xformers/_flash_attn/modules/__pycache__/mha.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/modules/__pycache__/mlp.cpython-310.pyc b/pkgs/xformers/_flash_attn/modules/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000..25c65d2 Binary files /dev/null and b/pkgs/xformers/_flash_attn/modules/__pycache__/mlp.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/modules/block.py b/pkgs/xformers/_flash_attn/modules/block.py new file mode 100644 index 0000000..af856b6 --- /dev/null +++ b/pkgs/xformers/_flash_attn/modules/block.py @@ -0,0 +1,324 @@ +# Copyright (c) 2022, Tri Dao. + +from typing import Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from torchvision.ops import StochasticDepth + +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import Mlp + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + dropout_add_layer_norm = None + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual +except ImportError: + dropout_add_layer_norm_parallel_residual = None + +try: + from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm +except ImportError: + RMSNorm, dropout_add_rms_norm = None, None + +try: + from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual +except ImportError: + dropout_add_rms_norm_parallel_residual = None + + +class Block(nn.Module): + + def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0., + drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False, + residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False): + """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. 
+ [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. + This is for performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + self.prenorm = prenorm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True' + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode='row') + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode='row') + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm is not None, 'dropout_layer_norm is not installed' + assert dropout_add_rms_norm is not None, 'dropout_layer_norm is not installed' + assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) + and isinstance(self.dropout1, nn.Dropout)) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._shared_params = True + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, + mixer_subset=None, mixer_kwargs=None): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. 
+ """ + fused_add_norm_fn = (dropout_add_rms_norm if RMSNorm and isinstance(self.norm1, RMSNorm) + else dropout_add_layer_norm) + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1(torch.ones( + hidden_states.shape[:-1], device=hidden_states.device, + dtype=hidden_states.dtype) + ) + hidden_states, residual = fused_add_norm_fn( + hidden_states, residual, self.norm1.weight, self.norm1.bias, + self.dropout1.p if self.training else 0.0, self.norm1.eps, + rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32 + ) + if mixer_kwargs is None: + mixer_kwargs = {} + if mixer_subset is not None: + mixer_kwargs['mixer_subset'] = mixer_subset + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + if mixer_subset is not None: + residual = residual[:, mixer_subset] + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2(torch.ones( + hidden_states.shape[:-1], device=hidden_states.device, + dtype=hidden_states.dtype) + ) + hidden_states, residual = fused_add_norm_fn( + hidden_states, residual, self.norm2.weight, self.norm2.bias, + self.dropout2.p if self.training else 0.0, self.norm2.eps, + rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32 + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + else: + assert residual is None + mixer_out = self.mixer( + hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}) + ) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + if not self.fused_dropout_add_ln: + hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out)) + + hidden_states).to(dtype=self.norm1.weight.dtype)) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1(torch.ones( + mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype) + ) + hidden_states = fused_add_norm_fn( + mixer_out, hidden_states, self.norm1.weight, self.norm1.bias, + self.dropout1.p if self.training else 0.0, self.norm1.eps, + rowscale=rowscale1, prenorm=False + ) + if not isinstance(self.mlp, nn.Identity): + mlp_out = self.mlp(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + if not self.fused_dropout_add_ln: + hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out)) + + hidden_states).to(dtype=self.norm2.weight.dtype)) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2(torch.ones( + mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype) + ) + hidden_states = fused_add_norm_fn( + mlp_out, hidden_states, self.norm2.weight, self.norm2.bias, + self.dropout2.p if self.training else 0.0, 
self.norm2.eps, + rowscale=rowscale2, prenorm=False + ) + return hidden_states + + +class ParallelBlock(nn.Module): + """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX, + and PaLM. + """ + + def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, resid_dropout1=0., resid_dropout2=0., + tied_norm=False, fused_dropout_add_ln=False, residual_in_fp32=False, + sequence_parallel=False, mark_shared_params=False): + """ + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA / MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both + the hidden_states (output1 of the MHA / MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + """ + super().__init__() + self.tied_norm = tied_norm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.residual_in_fp32 = residual_in_fp32 + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + self.dropout2 = dropout_cls(resid_dropout2) + if not self.tied_norm: + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert dropout_add_layer_norm_parallel_residual is not None, 'dropout_layer_norm is not installed' + assert dropout_add_rms_norm_parallel_residual is not None, 'dropout_layer_norm is not installed' + assert (isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) + and isinstance(self.dropout1, nn.Dropout)) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. + if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, 'norm2'): + for p in self.norm2.parameters(): + p._shared_params = True + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward(self, hidden_states1: Tensor, hidden_states2: Optional[Tensor] = None, + residual: Optional[Tensor] = None, mixer_kwargs=None): + r"""Pass the input through the encoder layer. + + Args: + hidden_states1: the output of the previous attention (mixer) or embedding layer. + hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1). + residual. 
+ """ + # TODO: Ideally we should only do the allgather / allreduce once for + # the Linear to MLP & Attention + fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual + if isinstance(self.norm1, RMSNorm) + else dropout_add_layer_norm_parallel_residual) + if not self.fused_dropout_add_ln: + dropped1 = self.dropout1(hidden_states1) + # For the very 1st block, we only want 1 dropout, not two different dropouts + if hidden_states2 is not None: + dropped2 = self.dropout2(hidden_states2) + residual = ((residual + dropped1 + dropped2) + if residual is not None else dropped1 + dropped2) + else: + residual = (residual + dropped1) if residual is not None else dropped1 + hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + hidden_states2 = (self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if not self.tied_norm else hidden_states1) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + weight2, bias2 = ((self.norm2.weight, self.norm2.bias) + if not self.tied_norm else (None, None)) + hidden_states1, hidden_states2, residual = fused_add_norm_fn( + hidden_states1, hidden_states2, residual, self.norm1.weight, self.norm1.bias, + weight2, bias2, self.dropout1.p if self.training else 0.0, self.norm1.eps, + prenorm=True, residual_in_fp32=self.residual_in_fp32 + ) + if self.tied_norm: + hidden_states2 = hidden_states1 + if mixer_kwargs is None: + mixer_kwargs = {} + hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs) + hidden_states2 = self.mlp(hidden_states2) + return hidden_states1, hidden_states2, residual diff --git a/pkgs/xformers/_flash_attn/modules/embedding.py b/pkgs/xformers/_flash_attn/modules/embedding.py new file mode 100644 index 0000000..6a5da2d --- /dev/null +++ b/pkgs/xformers/_flash_attn/modules/embedding.py @@ -0,0 +1,183 @@ +# Copyright (c) 2022, Tri Dao. 
+ +import torch +import torch.nn as nn +from torch import Tensor + +from einops import rearrange + +from flash_attn.utils.distributed import reduce_scatter, all_reduce + + +class GPT2Embeddings(nn.Module): + + def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None, + word_embed_proj_dim=None, device=None, dtype=None): + """ + If max_position_embeddings <= 0, there's no position embeddings + If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension + the project up to embed_dim + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if word_embed_proj_dim is None: + self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx, + **factory_kwargs) + self.project_in = None + else: + self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim, + padding_idx=padding_idx, **factory_kwargs) + self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False, + **factory_kwargs) + self.max_position_embeddings = max_position_embeddings + if self.max_position_embeddings > 0: + self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim, + **factory_kwargs) + + def forward(self, input_ids, position_ids=None): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + embeddings = self.word_embeddings(input_ids) + if self.project_in is not None: + embeddings = self.project_in(embeddings) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + return embeddings + + +class BertEmbeddings(nn.Module): + + def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size, + padding_idx=None, device=None, dtype=None): + """ + If max_position_embeddings <= 0, there's no position embeddings + If type_vocab_size <= 0, there's no token type embeddings + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx, + **factory_kwargs) + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + if self.max_position_embeddings > 0: + self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim, + **factory_kwargs) + if self.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, + **factory_kwargs) + + def forward(self, input_ids, position_ids=None, token_type_ids=None): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + token_type_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + embeddings = self.word_embeddings(input_ids) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + if self.type_vocab_size > 0: + if token_type_ids is None: + token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + + +class VocabParallelEmbedding(nn.Embedding): + + def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, 
**kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if num_embeddings % world_size != 0: + raise ValueError(f'num_embeddings ({num_embeddings}) must be divisible by ' + f'world_size ({world_size})') + if world_size > 1 and padding_idx is not None: + raise RuntimeError('ParallelEmbedding does not support padding_idx') + else: + world_size = 1 + super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.process_group is None: + return super().forward(input) + else: + rank = torch.distributed.get_rank(self.process_group) + vocab_size = self.num_embeddings + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + # Create a mask of valid vocab ids (1 means it needs to be masked). + input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index) + input = input - vocab_start_index + input[input_ids_mask] = 0 + embeddings = super().forward(input) + embeddings[input_ids_mask] = 0.0 + return embeddings + + +class ColumnParallelEmbedding(nn.Embedding): + + def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if embedding_dim % world_size != 0: + raise ValueError(f'embedding_dim ({embedding_dim}) must be divisible by ' + f'world_size ({world_size})') + else: + world_size = 1 + super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs) + + +class ParallelGPT2Embeddings(nn.Module): + + def __init__(self, embed_dim, vocab_size, max_position_embeddings, process_group, + padding_idx=None, sequence_parallel=True, device=None, dtype=None): + """ + If max_position_embeddings <= 0, there's no position embeddings + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.word_embeddings = VocabParallelEmbedding( + vocab_size, embed_dim, padding_idx=padding_idx, process_group=process_group, + **factory_kwargs + ) + self.max_position_embeddings = max_position_embeddings + if self.max_position_embeddings > 0: + self.position_embeddings = ColumnParallelEmbedding( + max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs + ) + + def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + world_size = torch.distributed.get_world_size(self.process_group) + embeddings = self.word_embeddings(input_ids) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + if world_size <= 1: + embeddings = embeddings + position_embeddings + else: + partition_dim = self.position_embeddings.embedding_dim + rank = torch.distributed.get_rank(self.process_group) + embeddings[..., rank * partition_dim:(rank + 1) * partition_dim] += position_embeddings + if combine_batch_seqlen_dim: + embeddings = rearrange(embeddings, 'b s d -> (b s) d') + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group) diff --git 
a/pkgs/xformers/_flash_attn/modules/mha.py b/pkgs/xformers/_flash_attn/modules/mha.py new file mode 100644 index 0000000..43bff90 --- /dev/null +++ b/pkgs/xformers/_flash_attn/modules/mha.py @@ -0,0 +1,711 @@ +# Copyright (c) 2022, Tri Dao. + +import math +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat + +try: + from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func + from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func +except ImportError: + flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None + flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None + +try: + from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear +except ImportError: + FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None + +try: + from flash_attn.layers.rotary import RotaryEmbedding +except ImportError: + RotaryEmbedding = None + +try: + import ft_attention +except ImportError: + ft_attention = None + + +class FlashSelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed' + assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed' + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. + If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D). + If cu_seqlens is not None and max_seqlen is not None, then qkv has shape + (total, 3, H, D), where total is the sum of the sequence lengths in the batch. + causal: if passed, will override self.causal + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + Returns: + -------- + out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None, + else (B, S, H, D). + """ + assert qkv.dtype in [torch.float16, torch.bfloat16] + assert qkv.is_cuda + causal = self.causal if causal is None else causal + unpadded = cu_seqlens is not None + if unpadded: + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + return flash_attn_varlen_qkvpacked_func( + qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + else: + return flash_attn_qkvpacked_func(qkv, self.drop.p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal) + + +class FlashCrossAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed' + assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed' + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None, + cu_seqlens_k=None, max_seqlen_k=None): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, Sq, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) + causal: if passed, will override self.causal + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + max_seqlen: int. Maximum sequence length in the batch of q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_k: int. Maximum sequence length in the batch of k and v. + """ + assert q.dtype in [torch.float16, torch.bfloat16] + assert q.is_cuda and kv.is_cuda + causal = self.causal if causal is None else causal + unpadded = cu_seqlens is not None + if unpadded: + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + assert cu_seqlens_k is not None + assert cu_seqlens_k.dtype == torch.int32 + assert max_seqlen_k is not None + assert isinstance(max_seqlen, int) + return flash_attn_varlen_kvpacked_func( + q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k, + self.drop.p if self.training else 0.0, + softmax_scale=self.softmax_scale, causal=causal + ) + else: + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = kv.shape[1] + assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] + return flash_attn_kvpacked_func(q, kv, self.drop.p if self.training else 0.0, + causal=causal, softmax_scale=self.softmax_scale) + + +class SelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, qkv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. 
(B, S) + """ + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + causal = self.causal if causal is None else causal + q, k, v = qkv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, + device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = self.drop(attention) + output = torch.einsum('bhts,bshd->bthd', attention_drop, v) + return output + + +class CrossAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, q, kv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + q: The tensor containing the query. (B, Sq, H, D) + kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. (B, Sk) + """ + batch_size, seqlen_q = q.shape[0], q.shape[1] + causal = self.causal if causal is None else causal + seqlen_k = kv.shape[1] + assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] + if kv.shape[3] != q.shape[2]: # MQA/GQA + kv = repeat(kv, "... hkv d -> ... 
(hkv g) d", g=q.shape[2] // kv.shape[3]) + k, v = kv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, + device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0, + device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = self.drop(attention) + output = torch.einsum('bhts,bshd->bthd', attention_drop, v) + return output + + +class LinearResidual(nn.Linear): + """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input), input + + +def _update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) + """ + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, inference_params.max_sequence_len, 2, + num_heads, head_dim, dtype=kv.dtype, device=kv.device + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) + # where packsize = 4 if fp32, 8 if fp16 or bf16. + # v_cache has shape (b, h, s, headdim) + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + # Copy key and values. + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + return kv + else: + assert inference_params.sequence_len_offset == 0 + # FT kernel requires different layouts for the k_cache and v_cache. + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if kv.dtype == torch.float32 else 8 + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv + k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize', + packsize=packsize).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange( + kv[:, :, 1], 'b s h d -> b h s d' + ) + return kv + + +def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim, + rotary_emb_base, kv=None, rotary_emb_interleaved=False): + """ + qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just + q of shape (batch_size, 1, nheads, head_dim) + kv: (batch_size, 1, 2, nheads_kv, head_dim) + """ + assert inference_params.fused_ft_kernel + assert ft_attention is not None + if kv is None: + q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1) + else: + q = rearrange(qkv, 'b 1 h d -> b h d') + k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1) + batch_start = inference_params.batch_size_offset + batch_end = batch_start + q.shape[0] + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end] + if inference_params.lengths_per_sample is not None else None) + context = ft_attention.single_query_attention( + q, k, v, + k_cache[batch_start:batch_end], + v_cache[batch_start:batch_end], + lengths_per_sample, + None, # rotary_cos_ + None, # rotary_sin_ + None, # nnz_head_idx + inference_params.sequence_len_offset, + rotary_emb_dim, rotary_emb_base, + not rotary_emb_interleaved # neox_rotary_style + ) + return rearrange(context, 'b h d -> b 1 h d') + + +class MHA(nn.Module): + """Multi-head self-attention and cross-attention + """ + + def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False, + qkv_proj_bias=True, out_proj_bias=True, + dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False, + rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None, + rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False, + return_residual=False, checkpointing=False, device=None, dtype=None) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. 
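# --- Illustrative sketch (editorial, not part of the patch): how the num_heads_kv and
# return_residual options described in the docstring above are typically used. The import
# path assumes this vendored module is importable as flash_attn.modules.mha; adjust as needed.
from flash_attn.modules.mha import MHA

def build_gqa_attention(embed_dim=1024, num_heads=16, num_heads_kv=4):
    # num_heads_kv < num_heads enables grouped-query attention: 16 query heads share
    # 4 KV heads, so Wqkv projects to head_dim * (num_heads + 2 * num_heads_kv) features.
    return MHA(embed_dim=embed_dim, num_heads=num_heads, num_heads_kv=num_heads_kv,
               causal=True, use_flash_attn=False, return_residual=True)

# With return_residual=True, forward() returns (out, x) so a post-norm caller can fuse
# the residual add into the backward of the output projection.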
+ """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.cross_attn = cross_attn + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.return_residual = return_residual + self.checkpointing = checkpointing + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv" + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + kv_dim = 2 * self.head_dim * self.num_heads_kv + + if self.rotary_emb_dim > 0: + assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet' + assert RotaryEmbedding is not None, 'rotary_emb is not installed' + self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, device=device) + + if fused_bias_fc and FusedDense is None: + raise ImportError('fused_dense is not installed') + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + linear_resid_cls = (LinearResidual if not fused_bias_fc + else partial(FusedDense, return_residual=True)) + wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + if not self.cross_attn: + self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + else: + self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) + if self.dwconv: + if self.num_heads_kv == self.num_heads: + self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2, + groups=qkv_dim) + else: + self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2, + groups=embed_dim) + self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, + groups=kv_dim) + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) + self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + if not fused_ft_kernel: + return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim, + dtype=dtype, device=device) + else: + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if dtype == torch.float32 else 8 + assert self.head_dim % packsize == 0 + k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize, + max_seqlen, packsize, dtype=dtype, device=device) + v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim, + dtype=dtype, device=device) + return k_cache, v_cache + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) + """ + assert 
not self.dwconv, 'Generation does not support dwconv yet' + assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor' + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None): + """ + qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just + q of shape (batch_size, 1, nheads, head_dim) + kv: (batch_size, 1, 2, nheads_kv, head_dim) + """ + rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0 + return _apply_rotary_single_query_attention( + qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv, + rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + ) + + def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None, + mixer_subset=None, inference_params=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if + cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total + is the is the sum of the sequence lengths in the batch. + x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into x. Only applicable when using + FlashAttention. + max_seqlen: int. Maximum sequence length in the batch. + key_padding_mask: boolean mask, True means to keep, False means to mask out. + (batch, seqlen). Only applicable when not using FlashAttention. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + if cu_seqlens is not None: + assert max_seqlen is not None + assert key_padding_mask is None + assert self.use_flash_attn + assert not self.dwconv + assert self.rotary_emb_dim == 0 + if key_padding_mask is not None: + assert cu_seqlens is None + assert max_seqlen is None + assert not self.use_flash_attn + if inference_params is not None: + assert key_padding_mask is None + assert cu_seqlens is None and max_seqlen is None + assert not self.dwconv + + kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs} + if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs}) + seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset + if not self.cross_attn and self.num_heads_kv == self.num_heads: + assert x_kv is None and mixer_subset is None + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + if self.dwconv: + qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + qkv = rearrange(qkv, '... (three h d) -> ... 
three h d', three=3, d=self.head_dim) + if (inference_params is None or inference_params.sequence_len_offset == 0 + or not inference_params.fused_ft_kernel): + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, + **kwargs) + else: + q = qkv[:, :, 0] + kv = self._update_kv_cache(qkv[:, :, 1:], inference_params) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + context = self._apply_rotary_single_query_attention(qkv, inference_params) + else: + if self.cross_attn: + if not self.return_residual: + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) + else: + if x_kv is not None: + kv, x_kv = self.Wkv(x_kv) + else: + kv, x = self.Wkv(x) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + else: + assert self.num_heads_kv != self.num_heads + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + q = qkv[..., :self.num_heads * self.head_dim] + kv = qkv[..., self.num_heads * self.head_dim:] + q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim) + kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim) + if self.dwconv: + q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + if (inference_params is None or inference_params.sequence_len_offset == 0 + or not inference_params.fused_ft_kernel): + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset) + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, + **kwargs) + else: + kv = self._update_kv_cache(kv, inference_params) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv) + out = self.out_proj(rearrange(context, '... h d -> ... 
(h d)')) + return out if not self.return_residual else (out, x) + + +class ParallelMHA(nn.Module): + """Multi-head self-attention and cross-attention + """ + + def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None, + qkv_proj_bias=True, out_proj_bias=True, + dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, + rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None, + rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False, + sequence_parallel=True, device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.checkpointing = checkpointing + self.process_group = process_group + self.world_size = process_group.size() if process_group is not None else 1 + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + self.num_heads_per_rank = num_heads // self.world_size + self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size + assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv" + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size" + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + kv_dim = 2 * self.head_dim * self.num_heads_kv + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, 'rotary_emb is not installed' + self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, device=device) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError('fused_dense is not installed') + self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group, + bias=qkv_proj_bias, + sequence_parallel=sequence_parallel, **factory_kwargs) + inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention + inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) + self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) + self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, + bias=out_proj_bias, + sequence_parallel=sequence_parallel, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + if not fused_ft_kernel: + return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv_per_rank, + self.head_dim, dtype=dtype, device=device) + else: + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if dtype == torch.float32 else 8 + assert self.head_dim % packsize == 0 + k_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, + self.head_dim // packsize, + max_seqlen, packsize, dtype=dtype, device=device) + v_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, max_seqlen, + self.head_dim, dtype=dtype, device=device) + return k_cache, v_cache + + def _update_kv_cache(self, kv, inference_params): + """kv: 
(batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim) + """ + assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor' + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None): + """ + qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just + q of shape (batch_size, 1, nheads, head_dim) + kv: (batch_size, 1, 2, nheads_kv, head_dim) + """ + rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0 + return _apply_rotary_single_query_attention( + qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv, + rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + ) + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + qkv = self.Wqkv(x) + if seqlen is not None: + qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen) + seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset + if self.num_heads_kv == self.num_heads: + qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim) + if (inference_params is None or inference_params.sequence_len_offset == 0 + or not inference_params.fused_ft_kernel): + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + else: + q = qkv[:, :, 0] + kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. + causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + context = self._apply_rotary_single_query_attention(qkv, inference_params) + else: + q = rearrange(qkv[..., :self.num_heads_per_rank * self.head_dim], + "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(qkv[..., self.num_heads_per_rank * self.head_dim:], + "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim) + if (inference_params is None or inference_params.sequence_len_offset == 0 + or not inference_params.fused_ft_kernel): + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset) + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv, + **kwargs) + else: + kv = self._update_kv_cache(kv, inference_params) + # If we're processing the prompt, causal=None (use self.causal). + # If we're decoding, then causal=False. 
+ causal = None if inference_params.sequence_len_offset == 0 else False + context = self.inner_cross_attn(q, kv, causal=causal) + else: + context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv) + context = rearrange(context, 'b s h d -> b s (h d)') + if seqlen is not None: + context = rearrange(context, 'b s d -> (b s) d') + out = self.out_proj(context) + return out diff --git a/pkgs/xformers/_flash_attn/modules/mlp.py b/pkgs/xformers/_flash_attn/modules/mlp.py new file mode 100644 index 0000000..f0c9acc --- /dev/null +++ b/pkgs/xformers/_flash_attn/modules/mlp.py @@ -0,0 +1,86 @@ +# Copyright (c) 2022, Tri Dao. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear +except ImportError: + ColumnParallelLinear, RowParallelLinear = None, None + +try: + from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP +except ImportError: + FusedMLP, ParallelFusedMLP = None, None + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, + bias1=True, bias2=True, return_residual=False, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y if not self.return_residual else (y, x) + + +class ParallelMLP(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, + process_group: ProcessGroup = None, sequence_parallel=True, + bias1=True, bias2=True, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + assert ColumnParallelLinear is not None, "Need to install fused_dense" + assert RowParallelLinear is not None, "Need to install fused_dense" + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, bias=bias1, + sequence_parallel=sequence_parallel, **factory_kwargs) + self.activation = activation + self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, bias=bias2, + sequence_parallel=sequence_parallel, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y + + +class GatedMlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.sigmoid, + bias1=True, bias2=True, multiple_of=256, return_residual=False, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or int(8 * in_features / 3) + hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias1, **factory_kwargs) + + def 
forward(self, x): + y = self.fc1(x) + if self.activation == F.sigmoid: # Special case for GLU + y = F.glu(y, dim=-1) + else: + y, gate = y.chunk(2, dim=-1) + y = y * self.activation(gate) + y = self.fc2(y) + return y if not self.return_residual else (y, x) diff --git a/pkgs/xformers/_flash_attn/ops/__init__.py b/pkgs/xformers/_flash_attn/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers/_flash_attn/ops/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..c1fe20a Binary files /dev/null and b/pkgs/xformers/_flash_attn/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/ops/__pycache__/activations.cpython-310.pyc b/pkgs/xformers/_flash_attn/ops/__pycache__/activations.cpython-310.pyc new file mode 100644 index 0000000..f50f353 Binary files /dev/null and b/pkgs/xformers/_flash_attn/ops/__pycache__/activations.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-310.pyc b/pkgs/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-310.pyc new file mode 100644 index 0000000..0354c3d Binary files /dev/null and b/pkgs/xformers/_flash_attn/ops/__pycache__/fused_dense.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-310.pyc b/pkgs/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-310.pyc new file mode 100644 index 0000000..2bac52b Binary files /dev/null and b/pkgs/xformers/_flash_attn/ops/__pycache__/layer_norm.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-310.pyc b/pkgs/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-310.pyc new file mode 100644 index 0000000..e43f5aa Binary files /dev/null and b/pkgs/xformers/_flash_attn/ops/__pycache__/rms_norm.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/ops/activations.py b/pkgs/xformers/_flash_attn/ops/activations.py new file mode 100644 index 0000000..3944145 --- /dev/null +++ b/pkgs/xformers/_flash_attn/ops/activations.py @@ -0,0 +1,99 @@ +# Copied from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/model/layers/activations.py +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +# 1/sqrt(2*pi)-> 0.3989423 +# 1/sqrt(2) -> 0.70710678 +# sqrt(2/pi) -> 0.79788456 + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def bias_gelu(y, bias): + x = bias + y + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=y.dtype) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. 
+ torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def bias_gelu_back(g, y, bias): + """Assume that y has shape (B, D) and bias has shape (D) + """ + x = bias + y + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + grad_y = ff * g + return grad_y.to(dtype=y.dtype), grad_y.sum(dim=(0), dtype=bias.dtype) + + +class GeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, bias): + ctx.save_for_backward(input, bias) + return bias_gelu(input, bias) + + @staticmethod + def backward(ctx, grad_output): + input, bias = ctx.saved_tensors + tmp = bias_gelu_back(grad_output, input, bias) + return tmp, tmp + + +bias_gelu_impl = GeLUFunction.apply + +# this function is tanh approximation of gelu +# actual gelu is: +# x * 0.5 * (1.0 + torch.erf(x * 0.70710678)) +@torch.jit.script +def gelu_fwd(x): + return (x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))).to(dtype=x.dtype) + +# gradient of tanh approximation of gelu +# gradient of actual gelu is: +# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x) +@torch.jit.script +def gelu_bwd(g, x): + tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243 + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return (ff * g).to(dtype=x.dtype) + + +class FastGeLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input): + ctx.save_for_backward(input) + return gelu_fwd(input) + + @staticmethod + def backward(ctx, grad_output): + input, = ctx.saved_tensors + tmp = gelu_bwd(grad_output, input) + return tmp + +fast_gelu_impl = FastGeLUFunction.apply + + +@torch.jit.script +def relu_bwd(g, x): + return torch.where(x >= 0, g, 0.0).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_fwd(x): + r = F.relu(x) + return (r * r).to(dtype=x.dtype) + + +@torch.jit.script +def sqrelu_bwd(g, x): + return (2.0 * g * F.relu(x)).to(dtype=x.dtype) diff --git a/pkgs/xformers/_flash_attn/ops/fused_dense.py b/pkgs/xformers/_flash_attn/ops/fused_dense.py new file mode 100644 index 0000000..d69d683 --- /dev/null +++ b/pkgs/xformers/_flash_attn/ops/fused_dense.py @@ -0,0 +1,527 @@ +# Copyright (c) 2023, Tri Dao. +# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py +# We make it work with pytorch amp and with bfloat16. 
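# --- Illustrative sketch (editorial, not part of the patch): the tanh approximation used by
# bias_gelu / gelu_fwd above is the same "approximate='tanh'" GELU that PyTorch exposes, so it
# can be sanity-checked directly (0.79788456 ~ sqrt(2/pi); 0.044715 is the cubic coefficient):
import torch
import torch.nn.functional as F

def tanh_gelu_reference(x: torch.Tensor) -> torch.Tensor:
    # Same formula as gelu_fwd in the patch, written out for comparison.
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))

x = torch.randn(1024)
assert torch.allclose(tanh_gelu_reference(x), F.gelu(x, approximate="tanh"), atol=1e-5)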
+# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py +from typing import Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from torch.cuda.amp import custom_bwd, custom_fwd + +# import fused_dense_cuda # from apex +import fused_dense_lib as fused_dense_cuda + +from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd +from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw +from flash_attn.utils.distributed import reduce_scatter, all_reduce + + +class FusedDenseFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, return_residual=False, process_group=None, + sequence_parallel=True): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. + """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError('fused_dense only supports matrix dims <= 2M') + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + if ctx.return_residual: + grad_input, = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + else: + weight, = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_output, weight.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, weight) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = 
reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2] + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None, None + + +def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, + return_residual: bool = False, process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True): + dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] + or (x.dtype == torch.float32 and torch.is_autocast_enabled())) + if x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and dtype_eligible: + return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group, + sequence_parallel) + else: + assert process_group is None + out = F.linear(x, weight, bias) + return out if not return_residual else (out, x) + + +class FusedDense(nn.Linear): + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + return_residual: bool = False, device=None, dtype=None) -> None: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.return_residual = return_residual + + def forward(self, x, process_group=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul. + """ + return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual, + process_group=process_group) + + +class ColumnParallelLinear(nn.Linear): + + def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, + bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % world_size != 0: + raise ValueError(f'out_features ({out_features}) must be divisible by ' + f'world_size ({world_size})') + super().__init__(in_features, out_features // world_size, bias=bias, + device=device, dtype=dtype) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. + # If not, then the input is already gathered. 
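# --- Illustrative sketch (editorial, not part of the patch): the ColumnParallelLinear /
# RowParallelLinear pair above follows the usual Megatron-style pattern. With sequence
# parallelism, the column-parallel layer first all-gathers the sequence-sharded activations,
# and the row-parallel layer reduce-scatters its partial output. Plain-torch illustration
# (requires an initialized process group; biases omitted):
import torch
import torch.distributed as dist
import torch.nn.functional as F

def column_then_row_parallel(x_shard, w1_shard, w2_shard, group):
    """x_shard: (local_seq, in); w1_shard: (hidden // world, in); w2_shard: (out, hidden // world)."""
    world = dist.get_world_size(group)
    gathered = [torch.empty_like(x_shard) for _ in range(world)]
    dist.all_gather(gathered, x_shard, group=group)   # sequence-parallel all-gather of the input
    x_full = torch.cat(gathered, dim=0)
    h_shard = F.linear(x_full, w1_shard)              # each rank holds a slice of the hidden dim
    partial = F.linear(h_shard, w2_shard)             # partial sums over the sharded hidden dim
    out_shard = partial.new_empty(x_shard.shape[0], w2_shard.shape[0])
    dist.reduce_scatter(out_shard, list(partial.chunk(world, dim=0)), group=group)
    return out_shard                                  # back to the sequence-sharded layout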
+ return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group, + sequence_parallel=self.sequence_parallel) + + +class RowParallelLinear(nn.Linear): + + def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup, + bias: bool = True, sequence_parallel=True, device=None, dtype=None) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % world_size != 0: + raise ValueError(f'in_features ({in_features}) must be divisible by ' + f'world_size ({world_size})') + # Only rank 0 will have bias + super().__init__(in_features // world_size, out_features, bias=bias and rank == 0, + device=device, dtype=dtype) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. + """ + out = fused_dense_func(x, self.weight, self.bias) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) + + +class FusedMLPFunc(torch.autograd.Function): + + @staticmethod + @custom_fwd + def forward(ctx, x, weight1, bias1, weight2, bias2, activation='gelu_approx', save_pre_act=True, + return_residual=False, checkpoint_lvl=0, heuristic=0, process_group=None, + sequence_parallel=True): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather of x before doing the matmul. + If sequence_parallel=False, then the input is already gathered. + + checkpoint_lvl: + 0: no recomputation in the bwd + 1: recompute gelu_out / relu_out in the bwd + 2: recompute pre_act and gelu_out / relu_out in the bwd + """ + assert -1 <= heuristic <= 4 + assert activation in ['gelu_approx', 'relu', 'sqrelu'] + if activation == 'sqrelu': + assert heuristic == -1 + if not save_pre_act: + checkpoint_lvl = 2 + assert checkpoint_lvl in [0, 1, 2] + ctx.return_residual = return_residual + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + ctx.checkpoint_lvl = checkpoint_lvl + ctx.activation = activation + ctx.heuristic = heuristic + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_gpu_dtype() + weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]] + bias1 = bias1.to(dtype=dtype) if bias1 is not None else None + bias2 = bias2.to(dtype=dtype) if bias2 is not None else None + weight1 = weight1.contiguous() + bias1 = bias1.contiguous() if bias1 is not None else None + weight2 = weight2.contiguous() + bias2 = bias2.contiguous() if bias2 is not None else None + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight1.shape, *weight2.shape) > 65535 * 32: + raise RuntimeError('fused_dense only supports matrix dims <= 2M') + if heuristic == -1: + pre_act = F.linear(total_x, weight1, bias1) + 
activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else (sqrelu_fwd if activation == 'sqrelu' else F.relu)) + with torch.jit.fuser('fuser2'): + output1 = activation_fn(pre_act) + # This is before adding bias1 + # pre_act = F.linear(total_x.reshape(batch_dim, n), weight1) + # with torch.jit.fuser('fuser2'): + # output1 = bias_gelu(pre_act, bias1) + else: + is_gelu = activation == 'gelu_approx' + output1, *rest = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, n), weight1, bias1, is_gelu, save_pre_act, heuristic + ) + if save_pre_act: + pre_act = rest[0] + output2 = F.linear(output1, weight2, bias2) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'): + # For RELU the pre_act is very small (just a bit-mask) so we just save it + ctx.save_for_backward(x, weight1, weight2, pre_act, output1) + elif checkpoint_lvl == 1: + ctx.save_for_backward(x, weight1, weight2, pre_act) + elif checkpoint_lvl == 2: + ctx.save_for_backward(x, weight1, weight2, bias1) + output2 = output2.reshape(*batch_shape, output2.shape[-1]) + return output2 if not return_residual else (output2, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + grad_output = grad_output.contiguous() + checkpoint_lvl = ctx.checkpoint_lvl + activation = ctx.activation + activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else (sqrelu_fwd if activation == 'sqrelu' else F.relu)) + if ctx.return_residual: + grad_input, = args + grad_input = grad_input.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + x, weight1, weight2, *rest = ctx.saved_tensors + if process_group is None or not sequence_parallel: + total_x = x + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + if checkpoint_lvl in [0, 1]: + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + if checkpoint_lvl == 0 or (checkpoint_lvl == 1 and activation == 'relu'): + pre_act, output1 = rest + elif checkpoint_lvl == 1: + pre_act, = rest + with torch.jit.fuser('fuser2'): + output1 = activation_fn(pre_act) + elif checkpoint_lvl == 2: + bias1, = rest + if process_group is not None and sequence_parallel: + total_x, _ = all_gather_raw(x, process_group) + if ctx.heuristic == -1: + pre_act = F.linear(total_x, weight1, bias1) + with torch.jit.fuser('fuser2'): + output1 = activation_fn(pre_act) + else: + output1, pre_act = fused_dense_cuda.linear_act_forward( + total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, + activation == 'gelu_approx', True, ctx.heuristic + ) + + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + output1 = output1.reshape(batch_dim, output1.shape[-1]) + pre_act = pre_act.reshape(batch_dim, pre_act.shape[-1]) + if ctx.needs_input_grad[3]: + grad_weight2, grad_bias2 = fused_dense_cuda.linear_bias_wgrad( + output1, grad_output, ctx.needs_input_grad[4] + ) + else: + grad_weight2 = None + grad_bias2 = grad_output if ctx.needs_input_grad[4] else None + if ctx.heuristic == -1: + # grad_pre_act = matmul_dgelu(grad_output, weight2, pre_act) + grad_output1 = F.linear(grad_output, weight2.t()) + activation_grad_fn = (gelu_bwd if activation == 'gelu_approx' + else (sqrelu_bwd if activation == 'sqrelu' else relu_bwd)) + with torch.jit.fuser('fuser2'): + grad_pre_act = activation_grad_fn(grad_output1, pre_act) + else: + # The cublasLt epilogue has to compute both gelu/relu grad and bias grad, 
we can't + # just compute gelu/relu grad + grad_pre_act, grad_bias1 = fused_dense_cuda.bias_act_linear_dgrad_bgrad( + weight2, grad_output, pre_act, activation == 'gelu_approx', ctx.heuristic + ) + if not ctx.needs_input_grad[2]: + grad_bias1 = None + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = F.linear(grad_pre_act, weight1.t()) + else: + grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_pre_act, weight1) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.heuristic == -1: + if ctx.needs_input_grad[1]: + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad( + total_x.reshape(batch_dim, total_x.shape[-1]), grad_pre_act, + ctx.needs_input_grad[2] + ) + else: + grad_weight1 = None + grad_bias1 = grad_pre_act if ctx.needs_input_grad[2] else None + else: + if ctx.needs_input_grad[1]: + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight1 = F.linear(grad_pre_act.t(), + total_x.reshape(batch_dim, total_x.shape[-1]).t()) + else: + grad_weight1 = None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, + None, None, None, None, None, None, None) + + +def fused_mlp_func( + x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None, + bias2: Optional[Tensor] = None, activation: str = 'gelu_approx', + save_pre_act: bool = True, return_residual: bool = False, + checkpoint_lvl: int = 0, heuristic: int = 0, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True +): + assert activation in ['gelu_approx', 'relu', 'sqrelu'] + dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16] + or (x.dtype == torch.float32 and torch.is_autocast_enabled())) + # If we save pre-activation, dimension must be divisible by 128 (relu) or 8 (gelu) + dim_eligible = not save_pre_act or (x.shape[-1] % (128 if activation == 'relu' else 8) == 0) + if (x.is_cuda and weight1.is_cuda and weight2.is_cuda and (bias1 is None or bias1.is_cuda) + and (bias2 is None or bias2.is_cuda) and dtype_eligible and dim_eligible): + return FusedMLPFunc.apply( + x, weight1, bias1, weight2, bias2, activation, save_pre_act, return_residual, + checkpoint_lvl, heuristic, process_group, sequence_parallel + ) + else: + assert process_group is None + pre_act = F.linear(x, weight1, bias1) + activation_fn = (partial(F.gelu, approximate='tanh') if activation == 'gelu_approx' + else partial(F.relu, inplace=True)) + output1 = activation_fn(pre_act) + output2 = F.linear(output1, weight2, bias2) + return output2 if not return_residual else (output2, x) + + +class FusedMLP(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, bias1=True, + bias2=True, activation='gelu_approx', return_residual=False, + checkpoint_lvl=0, heuristic='auto', device=None, dtype=None): + """ + If process_group is not None, we're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. 
+ + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. + For H100, we set heuristic=-1 for both fp16 and bf16 as the fused cuBlasLt implementation + is slower than the unfused version. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ['gelu_approx', 'relu', 'sqrelu'] + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + self.activation = activation + self.return_residual = return_residual + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic if activation != 'sqrelu' else -1 + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x, process_group=None): + dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() + if self.heuristic == 'auto': + if self.activation == 'gelu_approx': + if torch.cuda.get_device_capability('cuda') == (9, 0): + heuristic = -1 + else: + cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) + heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( + x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, + activation=self.activation, save_pre_act=self.training, + return_residual=self.return_residual, checkpoint_lvl=self.checkpoint_lvl, + heuristic=heuristic, process_group=process_group + ) + if self.return_residual: + out, x = out + if process_group is not None: + out = reduce_scatter(out, process_group) + return out if not self.return_residual else (out, x) + + +class ParallelFusedMLP(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, + activation='gelu_approx', process_group: ProcessGroup = None, + bias1=True, bias2=True, sequence_parallel=True, checkpoint_lvl=0, heuristic='auto', + device=None, dtype=None): + """ + process_group is required. We're doing Tensor Parallel with sequence parallelism: + we do an all_gather of x before doing the matmul, gelu, then matmul. + Finally we do a reduce_scatter of the output. + + checkpoint_lvl (increasing lvl means slower but more memory saving): + 0: no recomputation in the bwd + 1: recompute gelu_out in the bwd + 2: recompute pre_act and gelu_out in the bwd + heuristic: + -1: don't fuse gemm + gelu (separate kernel) + 0..4: use this heuristic for the algo section in the fused gemm + gelu + 'auto': heuristic will be picked automatically: + For CUDA >= 11.8, we set heuristic=0 for both fp16 and bf16 for best perf. + For CUDA <= 11.7, we set heuristic=1 for fp16 and heuristic=-1 for bf16. 
+ """ + assert checkpoint_lvl in [0, 1, 2] + assert activation in ['gelu_approx', 'relu', 'sqrelu'] + assert process_group is not None + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features * 4 + self.activation = activation + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.checkpoint_lvl = checkpoint_lvl + self.heuristic = heuristic if activation != 'sqrelu' else -1 + self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group, + bias=bias1, **factory_kwargs) + self.fc2 = RowParallelLinear(hidden_features, out_features, process_group, + bias=bias2, **factory_kwargs) + + def forward(self, x): + dtype = x.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype() + if self.heuristic == 'auto': + if self.activation == 'gelu_approx': + cuda_ver = tuple(map(int, torch.version.cuda.split('.'))) + heuristic = 0 if cuda_ver >= (11, 8) else (1 if dtype == torch.float16 else -1) + else: + heuristic = 0 + else: + heuristic = self.heuristic + out = fused_mlp_func( + x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias, + activation=self.activation, save_pre_act=self.training, + checkpoint_lvl=self.checkpoint_lvl, heuristic=heuristic, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel + ) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) diff --git a/pkgs/xformers/_flash_attn/ops/layer_norm.py b/pkgs/xformers/_flash_attn/ops/layer_norm.py new file mode 100644 index 0000000..c42ea90 --- /dev/null +++ b/pkgs/xformers/_flash_attn/ops/layer_norm.py @@ -0,0 +1,375 @@ +# Copyright (c) 2022, Tri Dao. +# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py + +import torch +from torch.nn import init + +import dropout_layer_norm + + +def maybe_align(x, alignment_in_bytes=16): + """Assume that x already has last dim divisible by alignment_in_bytes + """ + # TD [2023-07-04] I'm not 100% sure that clone will align the memory + # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 + return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() + + +def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p, + epsilon, residual_in_fp32=False, is_rms_norm=False): + """ Assume that arguments are contiguous and aligned to 16 bytes + """ + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, residualmat, gamma, beta, rowscale, colscale, None, None, dropout_p, epsilon, + 1.0, 0, None, residual_in_fp32, is_rms_norm + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, + dropout_p, has_residual, is_rms_norm=False): + """ Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). 
+ x0 must not be None if we have colscale. + """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + if colscale is not None: + assert x0 is not None, 'x0 is required to compute the gradient of colscale' + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, rowscale, colscale, None, None, + dropout_p, 1.0, 0, has_residual, is_rms_norm + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset, + out_subset, dropout_p, epsilon, rowscale_const, + out_numrows, residual_in_fp32=False, is_rms_norm=False): + """ Assume that arguments are contiguous and aligned to 16 bytes + """ + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, residualmat, gamma, beta, None, colscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, None, residual_in_fp32, is_rms_norm + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, + x0_subset, out_subset, dropout_p, rowscale_const, + x0_numrows, has_residual, is_rms_norm=False): + """ Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(-1, hidden_size) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + if colscale is not None: + assert x0 is not None, 'x0 is required to compute the gradient of colscale' + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, dxmat, xmat, x0mat, dmask, mu, rsigma, gamma, None, colscale, x0_subset, out_subset, + dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_parallel_residual_forward( + x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, + epsilon, residual_in_fp32=False, is_rms_norm=False +): + """ Assume that arguments are contiguous and aligned to 16 bytes + """ + hidden_size = gamma0.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( + x0mat, x1mat, residualmat, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, + None, residual_in_fp32, is_rms_norm + ) + # dmask0 and dmask1 are None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma + + +def _dropout_add_layer_norm_parallel_residual_backward( + dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, + dropout_p, has_x1, has_residual, is_rms_norm=False +): + """ Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). 
+ """ + hidden_size = gamma0.numel() + xmat = x.view((-1, hidden_size)) + dz0mat = dz0.view(xmat.shape) + dz1mat = dz1.view(xmat.shape) if dz1 is not None else None + dxmat = dx.view(xmat.shape) if dx is not None else None + dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1, *rest = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( + dz0mat, dz1mat, dxmat, xmat, dmask0, dmask1, mu, rsigma, gamma0, gamma1, + dropout_p, has_x1, has_residual, is_rms_norm + ) + # dresidualmat is None if not has_residual + return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 + + +class DropoutAddLayerNormFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, + residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, + residual_in_fp32, is_rms_norm + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + ctx.save_for_backward(xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + if not return_dmask: + return (zmat.view(x0.shape) if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape))) + else: + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return ((zmat.view(x0.shape), dmask) if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape), dmask)) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, + ctx.is_rms_norm + ) + dx0 = dx0mat.view(x.shape) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, None, dcolscale, None, + None, None, None, None, None) + + +class DropoutAddLayerNormSubsetFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32=False, + prenorm=False, is_rms_norm=False, return_dmask=False): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( + x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32, is_rms_norm + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + x_shape = (-1, *x0.shape[1:]) + ctx.save_for_backward(xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, + x0_subset, out_subset) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.rowscale_const = rowscale_const + ctx.x0_numrows = x0.shape[:-1].numel() + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + z_shape = (-1, *x0.shape[1:]) + if not return_dmask: + return (zmat.view(z_shape) if not prenorm + else (zmat.view(z_shape), xmat.view(x0.shape))) + else: + z = zmat.view(z_shape) + dmask = (dmask.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask) + return ((z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask)) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( + dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, + ctx.rowscale_const, ctx.x0_numrows, has_residual, ctx.is_rms_norm + ) + dx0 = dx0mat.view(-1, *x.shape[1:]) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return (dx0, dresidual, dgamma, dbeta if ctx.has_beta else None, dcolscale, None, None, + None, None, None, None, None, None, None, None) + + +class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, + residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False): + x0 = maybe_align(x0.contiguous(), 16) + x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma0 = maybe_align(gamma0.contiguous(), 16) + beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None + gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None + beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None + z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward( + x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, + residual_in_fp32, is_rms_norm + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_x1 = x1 is not None + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta0 is not None + z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) + if not return_dmask: + return z if not prenorm else (*z, xmat.view(x0.shape)) + else: + dmask0 = (dmask0.view(x0.shape) if dropout_p > 0. + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + dmask1 = (dmask1.view(x0.shape) if dropout_p > 0. and x1 is not None + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device)) + ctx.mark_non_differentiable(dmask0) + ctx.mark_non_differentiable(dmask1) + return (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) + + @staticmethod + def backward(ctx, dz0, dz1, *args): + dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_x1 = ctx.has_x1 + has_residual = ctx.has_residual + dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 = _dropout_add_layer_norm_parallel_residual_backward( + dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, + has_residual, ctx.is_rms_norm + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + return (dx0, dx1, dresidual, dgamma0, dbeta0 if ctx.has_beta else None, dgamma1, + dbeta1 if ctx.has_beta else None, None, None, None, None, None, None) + + +def layer_norm(x, weight, bias, epsilon): + return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) + + +def dropout_add_layer_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, + layerscale=None, prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, + False, return_dropout_mask + ) + + +def dropout_add_layer_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, + x0_subset=None, out_subset=None, rowscale_const=1.0, + out_numrows=0, prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormSubsetFn.apply( + x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32, prenorm, False, return_dropout_mask + ) + + +def dropout_add_layer_norm_parallel_residual( + x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, prenorm=False, + residual_in_fp32=False, return_dropout_mask=False +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. 
+ """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm, + False, return_dropout_mask + ) + + +class DropoutAddLayerNorm(torch.nn.Module): + def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x0, residual=None): + return dropout_add_layer_norm(x0, residual, self.weight, self.bias, + self.p if self.training else 0.0, self.eps, + prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32) diff --git a/pkgs/xformers/_flash_attn/ops/rms_norm.py b/pkgs/xformers/_flash_attn/ops/rms_norm.py new file mode 100644 index 0000000..bf551bc --- /dev/null +++ b/pkgs/xformers/_flash_attn/ops/rms_norm.py @@ -0,0 +1,89 @@ +# Copyright (c) 2022, Tri Dao. +# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py + +import torch +from torch.nn import init + +from flash_attn.ops.layer_norm import DropoutAddLayerNormFn, DropoutAddLayerNormSubsetFn +from flash_attn.ops.layer_norm import DropoutAddLayerNormParallelResidualFn + + +def rms_norm(x, weight, epsilon): + return DropoutAddLayerNormFn.apply(x, None, weight, None, None, None, 0.0, epsilon, False, + False, True) + + +def dropout_add_rms_norm(x0, residual, weight, bias, dropout_p, epsilon, rowscale=None, + layerscale=None, prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, residual, weight, bias, rowscale, layerscale, dropout_p, epsilon, residual_in_fp32, prenorm, + True, return_dropout_mask + ) + + +def dropout_add_rms_norm_subset(x0, residual, weight, bias, dropout_p, epsilon, layerscale=None, + x0_subset=None, out_subset=None, rowscale_const=1.0, + out_numrows=0, prenorm=False, residual_in_fp32=False, + return_dropout_mask=False): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormSubsetFn.apply( + x0, residual, weight, bias, layerscale, x0_subset, out_subset, dropout_p, epsilon, + rowscale_const, out_numrows, residual_in_fp32, prenorm, True, return_dropout_mask + ) + + +def dropout_add_rms_norm_parallel_residual( + x0, x1, residual, weight0, bias0, weight1, bias1, + dropout_p, epsilon, prenorm=False, residual_in_fp32=False, return_dropout_mask=False +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. 
+ """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, x1, residual, weight0, bias0, weight1, bias1, dropout_p, epsilon, residual_in_fp32, prenorm, + True, return_dropout_mask + ) + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.eps = eps + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x): + return rms_norm(x, self.weight, self.eps) + + +class DropoutAddRMSNorm(torch.nn.Module): + def __init__(self, hidden_size, prenorm=False, p=0.0, eps=1e-5, residual_in_fp32=False, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + + def forward(self, x0, residual=None): + return dropout_add_rms_norm(x0, residual, self.weight, None, + self.p if self.training else 0.0, self.eps, + prenorm=self.prenorm, residual_in_fp32=self.residual_in_fp32) diff --git a/pkgs/xformers/_flash_attn/utils/__init__.py b/pkgs/xformers/_flash_attn/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkgs/xformers/_flash_attn/utils/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/_flash_attn/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6f8aa0a Binary files /dev/null and b/pkgs/xformers/_flash_attn/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/utils/__pycache__/benchmark.cpython-310.pyc b/pkgs/xformers/_flash_attn/utils/__pycache__/benchmark.cpython-310.pyc new file mode 100644 index 0000000..8f7b605 Binary files /dev/null and b/pkgs/xformers/_flash_attn/utils/__pycache__/benchmark.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/utils/__pycache__/distributed.cpython-310.pyc b/pkgs/xformers/_flash_attn/utils/__pycache__/distributed.cpython-310.pyc new file mode 100644 index 0000000..0120181 Binary files /dev/null and b/pkgs/xformers/_flash_attn/utils/__pycache__/distributed.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/utils/__pycache__/generation.cpython-310.pyc b/pkgs/xformers/_flash_attn/utils/__pycache__/generation.cpython-310.pyc new file mode 100644 index 0000000..21eafe4 Binary files /dev/null and b/pkgs/xformers/_flash_attn/utils/__pycache__/generation.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/utils/__pycache__/pretrained.cpython-310.pyc b/pkgs/xformers/_flash_attn/utils/__pycache__/pretrained.cpython-310.pyc new file mode 100644 index 0000000..75dad7b Binary files /dev/null and b/pkgs/xformers/_flash_attn/utils/__pycache__/pretrained.cpython-310.pyc differ diff --git a/pkgs/xformers/_flash_attn/utils/benchmark.py b/pkgs/xformers/_flash_attn/utils/benchmark.py new file mode 100644 index 0000000..c74b1a8 --- /dev/null +++ b/pkgs/xformers/_flash_attn/utils/benchmark.py @@ -0,0 +1,146 @@ +# Copyright (c) 2022, Tri Dao. +""" Useful functions for writing test code. 
""" + +import torch +import torch.utils.benchmark as benchmark + + +def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False, + amp_dtype=torch.float16, **kwinputs): + """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """ + if verbose: + print(desc, '- Forward pass') + def fn_amp(*inputs, **kwinputs): + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + for _ in range(repeats): # warmup + fn_amp(*inputs, **kwinputs) + t = benchmark.Timer( + stmt='fn_amp(*inputs, **kwinputs)', + globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, + amp_dtype=torch.float16, **kwinputs): + """ Use Pytorch Benchmark on the backward pass of an arbitrary function. """ + if verbose: + print(desc, '- Backward pass') + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError('Grad shape does not match output shape') + for _ in range(repeats): # warmup + y.backward(grad, retain_graph=True) + t = benchmark.Timer( + stmt='y.backward(grad, retain_graph=True)', + globals={'y': y, 'grad': grad}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, + amp_dtype=torch.float16, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ + if verbose: + print(desc, '- Forward + Backward pass') + def f(grad, *inputs, **kwinputs): + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + y = fn(*inputs, **kwinputs) + if type(y) is tuple: + y = y[0] + if grad is None: + grad = torch.randn_like(y) + else: + if grad.shape != y.shape: + raise RuntimeError('Grad shape does not match output shape') + y.backward(grad, retain_graph=True) + for _ in range(repeats): # warmup + f(grad, *inputs, **kwinputs) + t = benchmark.Timer( + stmt='f(grad, *inputs, **kwinputs)', + globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + if verbose: + print(m) + return t, m + + +def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False, + amp_dtype=torch.float16, **kwinputs): + """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """ + return ( + benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, + amp=amp, amp_dtype=amp_dtype, **kwinputs), + benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, + amp=amp, amp_dtype=amp_dtype, **kwinputs), + benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose, + amp=amp, amp_dtype=amp_dtype, **kwinputs), + ) + + +def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, + amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs): + """ Wrap benchmark functions in Pytorch profiler to see CUDA information. 
""" + if backward: + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + g = torch.randn_like(fn(*inputs, **kwinputs)) + for _ in range(30): # Warm up + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g) + out = fn(*inputs, **kwinputs) + # Backward should be done outside autocast + if backward: + out.backward(g) + activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA] + with torch.profiler.profile( + activities=activities, + record_shapes=True, + # profile_memory=True, + with_stack=True, + ) as prof: + with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp): + if backward: + for x in inputs: + if isinstance(x, torch.Tensor): + x.grad = None + out = fn(*inputs, **kwinputs) + if backward: out.backward(g) + if verbose: + # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50)) + print(prof.key_averages().table(row_limit=50)) + if trace_filename is not None: + prof.export_chrome_trace(trace_filename) + + +def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + fn(*inputs, **kwinputs) + torch.cuda.synchronize() + mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000) + if verbose: + print(f'{desc} max memory: {mem}GB') + torch.cuda.empty_cache() + return mem diff --git a/pkgs/xformers/_flash_attn/utils/distributed.py b/pkgs/xformers/_flash_attn/utils/distributed.py new file mode 100644 index 0000000..09578e3 --- /dev/null +++ b/pkgs/xformers/_flash_attn/utils/distributed.py @@ -0,0 +1,127 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.distributed import ProcessGroup + +# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for +# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent +# version of PyTorch. The following 4 lines are for backward compatibility with +# older PyTorch. 
+if "all_gather_into_tensor" not in dir(torch.distributed): + torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base +if "reduce_scatter_tensor" not in dir(torch.distributed): + torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base + + +# Raw operation, does not support autograd, but does support async +def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + output = torch.empty(world_size * input_.shape[0], *input_.shape[1:], + dtype=input_.dtype, device=input_.device) + handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(), + group=process_group, async_op=async_op) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + world_size = torch.distributed.get_world_size(process_group) + assert input_.shape[0] % world_size == 0 + output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:], + dtype=input_.dtype, device=input_.device) + handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(), + group=process_group, + async_op=async_op) + return output, handle + + +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + +class AllGatherFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_gather_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +all_gather = AllGatherFunc.apply + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = reduce_scatter_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + grad_input, _ = all_gather_raw(grad_output, ctx.process_group) + return grad_input, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply + + +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + +def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup): + # We want to iterate over parameters with _shared_params=True in the same order, + # as different ranks might have different number of parameters (e.g., only rank 0 has bias). 
+ pamams_shared = {name: p for name, p in model.named_parameters() + if getattr(p, '_shared_params', False)} + for _, p in sorted(pamams_shared.items()): + with torch.no_grad(): + # Broadcast needs src to be global rank, not group rank + torch.distributed.broadcast( + p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group + ) + + +# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256 +def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup): + # We want to iterate over parameters with _sequence_parallel=True in the same order, + # as different ranks might have different number of parameters (e.g., only rank 0 has bias). + params_seqparallel = {name: p for name, p in model.named_parameters() + if getattr(p, '_sequence_parallel', False)} + grads = [p.grad for _, p in sorted(params_seqparallel.items())] + if grads: + with torch.no_grad(): + coalesced = torch._utils._flatten_dense_tensors(grads) + torch.distributed.all_reduce(coalesced, group=process_group) + for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) diff --git a/pkgs/xformers/_flash_attn/utils/generation.py b/pkgs/xformers/_flash_attn/utils/generation.py new file mode 100644 index 0000000..adaa0b5 --- /dev/null +++ b/pkgs/xformers/_flash_attn/utils/generation.py @@ -0,0 +1,302 @@ +# Copyright (c) 2023, Tri Dao. +# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/forward_step.py#L31 +from typing import Optional, Union, Sequence, Callable +import gc +import time + +from dataclasses import dataclass, field +from collections import namedtuple + +import torch +from torch import Tensor +from torch.profiler import profile, record_function, ProfilerActivity + +from einops import rearrange + +from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput + + +@dataclass +class InferenceParams: + """Inference parameters that are passed to the main model in order + to efficienly calculate and store the context during inference.""" + max_sequence_len: int + max_batch_size: int + sequence_len_offset: int = 0 + batch_size_offset: int = 0 + key_value_memory_dict: dict = field(default_factory=dict) + fused_ft_kernel: bool = False + lengths_per_sample: Optional[Tensor] = None + + +# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py +# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170 +def modify_logits_for_top_p_filtering(logits, top_p): + """Set the logits for none top-p values to -inf.""" + if top_p <= 0.0: + return + # First sort and calculate cumulative sum of probabilities. + sorted_logits, sorted_indices = torch.sort(logits, descending=False) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - top_p) + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, float('-inf')) + + +def sample(logits, top_k=1, top_p=0.0, temperature=1.0): + """Sample from top-k logits. 
+ Arguments: + logits: Tensor of shape (batch_size, vocab_size) + """ + if top_k == 1: # Short-circuit for greedy decoding + return logits.argmax(dim=-1) + else: + if top_p > 0.0: + assert top_p <= 1.0, 'top-p should be in (0, 1].' + if top_k > 0: + top_k = min(top_k, logits.size(-1)) # Safety check + logits_top, indices = torch.topk(logits, top_k, dim=-1) + logits_top /= temperature + modify_logits_for_top_p_filtering(logits_top, top_p) + return indices[ + torch.arange(indices.shape[0], device=indices.device), + torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1) + ] + else: + logits_top = logits / temperature + modify_logits_for_top_p_filtering(logits_top, top_p) + return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1) + + +def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, + eos_token_id=None, teacher_outputs=None, vocab_size=None, tensor_parallel=1, + fused_ft_kernel=False, cg=False, timing=False): + """Decoding, either greedy or with top-k or top-p sampling. + If top-k = 0, don't limit the number of candidates (pure sampling). + Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first, + then top-p. + We assume that all sequences in the same batch have the same length. + + Arguments: + input_ids: (batch, seq_len) + max_length: int + teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the + logits, the next token is taken from the teacher_outputs. Useful for testing. + Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: + sequences: (batch, max_length) + scores: tuples of (batch, vocab_size) + """ + batch_size, seqlen_og = input_ids.shape + teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0 + if cg: + assert fused_ft_kernel + if not hasattr(model, '_decoding_cache'): + model._decoding_cache = None + model._decoding_cache = update_graph_cache( + model, model._decoding_cache, batch_size, seqlen_og, max_length, + tensor_parallel=tensor_parallel + ) + inference_params = model._decoding_cache.inference_params + inference_params.max_sequence_len = max_length + inference_params.max_batch_size = batch_size + inference_params.sequence_len_offset = 0 + else: + inference_params = InferenceParams(max_sequence_len=max_length, max_batch_size=batch_size, + fused_ft_kernel=fused_ft_kernel) + scores = [] + with torch.inference_mode(): + if timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + start = time.time() + logits = model(input_ids, inference_params=inference_params, last_token_only=True).logits + if vocab_size is not None: + logits = logits[..., :vocab_size] + scores.append(logits if not cg else logits.clone()) + if teacher_outputs is None or teacher_output_len <= seqlen_og: + next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature) + else: + next_token = teacher_outputs[:, seqlen_og] + sequences = [next_token] + inference_params.sequence_len_offset = seqlen_og + while True: + position_ids = torch.full((batch_size, 1), inference_params.sequence_len_offset, + dtype=torch.long, device=input_ids.device) + if not cg: + logits = model(rearrange(next_token, 'b -> b 1'), position_ids=position_ids, + inference_params=inference_params, last_token_only=True).logits + else: + logits = model._decoding_cache.run(rearrange(next_token, 'b -> b 1'), position_ids, + inference_params.sequence_len_offset) + if 
vocab_size is not None: + logits = logits[..., :vocab_size] + scores.append(logits if not cg else logits.clone()) + if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1: + next_token = sample(logits, top_k=top_k, temperature=temperature) + else: + next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1] + sequences.append(next_token) + inference_params.sequence_len_offset += 1 + if eos_token_id is not None and (next_token == eos_token_id).all(): + break + if inference_params.sequence_len_offset >= max_length - 1: + break + if timing: + if tensor_parallel > 1: + torch.distributed.barrier() + torch.cuda.synchronize() + print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms') + output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput + return output_cls( + sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), + scores=tuple(scores) + ) + + +class GenerationMixin: + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + raise NotImplementedError + + def generate(self, input_ids, max_length, top_k=1, top_p=0.0, temperature=1.0, + return_dict_in_generate=False, output_scores=False, **kwargs): + output = decode(input_ids, self, max_length, top_k=top_k, top_p=top_p, + temperature=temperature, **kwargs) + if not output_scores: + output.scores = None + return output if return_dict_in_generate else output.sequences + + +def allocate_inference_cache(max_batch_size, max_seqlen, nheads, headdim, layers: Union[int, Sequence], + device, dtype=torch.float16): + assert dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if dtype == torch.float32 else 8 + assert headdim % packsize == 0 + k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize) + v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim) + if isinstance(layers, int): + layers = range(layers) + return {i: (torch.empty(k_cache_shape, device=device, dtype=dtype), + torch.empty(v_cache_shape, device=device, dtype=dtype)) + for i in layers} + + +def seqlen_to_seqlen_type(seqlen: int) -> int: + """Convert sequence length to a seqlen_type. + This is used to determine which cuda graph to use. 
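+    For example, seqlen 16 maps to type 0, 512 to type 1, and 4096 to type 2.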
+ Arguments: + seqlen: int + """ + return 0 if seqlen < 32 else (1 if seqlen < 2048 else 2) + + +def seqlen_type_to_max_seqlen(seqlen_type: int) -> int: + assert seqlen_type in [0, 1, 2] + return 32 if seqlen_type == 0 else (2048 if seqlen_type == 1 else 2**32) + + +@dataclass +class DecodingCGCache: + max_batch_size: int = 0 + max_seqlen: int = 0 + device = None + dtype = None + callables: dict = field(default_factory=dict) + mempool = None + inference_params: Optional[InferenceParams] = None + run: Optional[Callable] = None + + +@torch.inference_mode() +def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1, + dtype=None, n_warmups=2): + if cache is None: + cache = DecodingCGCache() + param_example = next(iter(model.parameters())) + device = param_example.device + if dtype is None: + dtype = param_example.dtype + if ((device, dtype) != (cache.device, cache.dtype) or batch_size > cache.max_batch_size + or max_seqlen > cache.max_seqlen): # Invalidate the cache + cache.callables = {} + cache.mempool = None + cache.inference_params = None + gc.collect() + cache.device, cache.dtype = device, dtype + cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen + if hasattr(model, 'allocate_inference_cache'): + inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype) + else: + headdim = getattr(model.config, 'head_dim', + model.config.hidden_size // model.config.num_attention_heads) + inf_cache = allocate_inference_cache( + batch_size, max_seqlen, model.config.num_attention_heads // tensor_parallel, headdim, + model.config.num_hidden_layers, device, dtype + ) + lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device) + cache.inference_params = InferenceParams( + max_sequence_len=max_seqlen, max_batch_size=batch_size, + sequence_len_offset=seqlen_og, key_value_memory_dict=inf_cache, fused_ft_kernel=True, + lengths_per_sample=lengths_per_sample + ) + cache.mempool = torch.cuda.graphs.graph_pool_handle() + for s_type in range(seqlen_to_seqlen_type(seqlen_og), seqlen_to_seqlen_type(max_seqlen) + 1): + if (batch_size, s_type) not in cache.callables: + max_seqlen_ = min(max(seqlen_og, seqlen_type_to_max_seqlen(s_type)), max_seqlen) + cache.callables[batch_size, s_type] = capture_graph( + model, cache.inference_params, batch_size, max_seqlen_, mempool=cache.mempool, + n_warmups=n_warmups + ) + + def dispatch(input_ids, position_ids, seqlen): + batch_size = input_ids.shape[0] + return cache.callables[batch_size, seqlen_to_seqlen_type(seqlen)](input_ids, position_ids, seqlen) + + cache.run = dispatch + cache.inference_params.sequence_len_offset = 0 # Reset so it's not confusing + return cache + + +def capture_graph(model, inference_params, batch_size, max_seqlen, mempool=None, n_warmups=2): + device = next(iter(model.parameters())).device + input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device) + position_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device) + sequence_len_offset_og = inference_params.sequence_len_offset + # TD [2023-04-14]: important for correctness of the FT's attention kernel, as seqlen_cpu is + # used to determine the size of smem. Hence seqlen_cpu must be >= lengths_per_sample. 
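+    # Pin both offsets to the capture shape here; the original sequence_len_offset is restored at the end of this function.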
+ inference_params.sequence_len_offset = max_seqlen - 1 + inference_params.lengths_per_sample[:] = max_seqlen - 1 + + # Warmup before capture + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(n_warmups): + logits = model(input_ids, position_ids=position_ids, inference_params=inference_params, + last_token_only=True).logits + s.synchronize() + # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0, + # which requires that graph launch and non-captured launch to not overlap (I think, + # that's how I interpret the documentation). I'm not sure if this is required. + if torch.distributed.is_initialized(): + torch.distributed.barrier() + torch.cuda.current_stream().wait_stream(s) + # Captures the graph + # To allow capture, automatically sets a side stream as the current stream in the context + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=mempool): + logits = model(input_ids, position_ids=position_ids, inference_params=inference_params, + last_token_only=True).logits + + def run(new_input_ids, new_position_ids, seqlen): + inference_params.lengths_per_sample[:] = seqlen + input_ids.copy_(new_input_ids) + position_ids.copy_(new_position_ids) + graph.replay() + return logits + + inference_params.sequence_len_offset = sequence_len_offset_og + return run diff --git a/pkgs/xformers/_flash_attn/utils/pretrained.py b/pkgs/xformers/_flash_attn/utils/pretrained.py new file mode 100644 index 0000000..c5b7459 --- /dev/null +++ b/pkgs/xformers/_flash_attn/utils/pretrained.py @@ -0,0 +1,37 @@ +import torch + +from transformers.utils import WEIGHTS_NAME, WEIGHTS_INDEX_NAME +from transformers.utils import is_remote_url +from transformers.modeling_utils import load_state_dict +from transformers.utils.hub import cached_file, get_checkpoint_shard_files + + +def state_dict_from_pretrained(model_name, device=None, dtype=None): + # If not fp32, then we don't want to load directly to the GPU + mapped_device = 'cpu' if dtype not in [torch.float32, None] else device + is_sharded = False + resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, + _raise_exceptions_for_missing_entries=False) + if resolved_archive_file is None: + resolved_archive_file = cached_file(model_name, WEIGHTS_INDEX_NAME, + _raise_exceptions_for_missing_entries=False) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None: + raise EnvironmentError(f"Model name {model_name} was not found.") + if is_sharded: + # resolved_archive_file becomes a list of files that point to the different + # checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + model_name, resolved_archive_file + ) + state_dict = {} + for sharded_file in resolved_archive_file: + state_dict.update(torch.load(sharded_file, map_location=mapped_device)) + else: + state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device) + # Convert dtype before moving to GPU to save memory + if dtype is not None: + state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()} + state_dict = {k: v.to(device=device) for k, v in state_dict.items()} + return state_dict diff --git a/pkgs/xformers/benchmarks/LRA/__init__.py b/pkgs/xformers/benchmarks/LRA/__init__.py new file mode 100644 index 0000000..6677db0 --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pkgs/xformers/benchmarks/LRA/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..21eb98e Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-310.pyc new file mode 100644 index 0000000..9825e12 Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-310.pyc new file mode 100644 index 0000000..3373d4a Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-310.pyc new file mode 100644 index 0000000..82efa8a Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-310.pyc new file mode 100644 index 0000000..b00696b Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-310.pyc new file mode 100644 index 0000000..79fb632 Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/batch_fetch_results.py b/pkgs/xformers/benchmarks/LRA/batch_fetch_results.py new file mode 100644 index 0000000..88227ac --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/batch_fetch_results.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
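+# Walks <checkpoint_path>/<attention>/*/test_eval_summary.json, collects test_accu_mean per task,
+# and prints a markdown-style summary table across attention mechanisms.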
+ + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +if __name__ == "__main__": + # Get the user requests + parser = argparse.ArgumentParser( + "Collect results from a given batch of distributed results" + ) + parser.add_argument("-ck", "--checkpoint_path", required=True) + args = parser.parse_args() + + logging.getLogger().setLevel(logging.INFO) + + # Go through all the data in the given repo, try to find the end results + root = Path(args.checkpoint_path) + + # - list all the mechanisms being benchmarked + results: Dict[str, Any] = {} + + for attention in filter(lambda x: x.is_dir(), root.iterdir()): + logging.info(f"\nFound results for {attention.stem}") + task_jsons = attention.glob("*/test_eval_summary.json") + results[attention.stem] = {} + + for task in task_jsons: + task_name = task.stem.split("__")[0] + logging.info(f"Logs found for task: {task_name}") + results[attention.stem][task_name] = -1 + found_result = False + + # - collect the individual results + with open(task, "r") as result_file: + dct = json.load(result_file) + if "test_accu_mean" in dct: + found_result = True + results[attention.stem][task_name] = dct["test_accu_mean"] + + logging.info( + f"Final result found for {task_name} at epoch {dct['train_step_idx']}: " + f"{results[attention.stem][task_name]}" + ) + else: + break + + # - report an error if no result was found + if not found_result: + ERR_TAIL = 30 + + logging.warning( + f"No result found for {task_name}, showing the error log in {task.parent}" + ) + err_log = Path(task.parent).glob("*.err") + print("*****************************************************") + with open(next(err_log), "r") as err_file: + for i, line in enumerate(reversed(err_file.readlines())): + print(line, end="") + if i > ERR_TAIL: + break + print("*****************************************************") + + logging.info(f"\nCollected results: {json.dumps(results, indent=2)}") + + # - reduction: compute the average + tasks = set(t for v in results.values() for t in v.keys()) + # -- fill in the possible gaps + for att in results.keys(): + for t in tasks: + if t not in results[att].keys(): + results[att][t] = 0.0 + + # -- add the average value + for att in results.keys(): + results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2) + + # - Format as an array, markdown style + tasks_sort = sorted( + set(t for v in results.values() for t in v.keys()), reverse=True + ) + print( + "{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort) + ) + + for att in results.keys(): + print( + "{0:<20}".format(att) + + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort) + ) diff --git a/pkgs/xformers/benchmarks/LRA/batch_submit.py b/pkgs/xformers/benchmarks/LRA/batch_submit.py new file mode 100644 index 0000000..a3077aa --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/batch_submit.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import os +from pathlib import Path + +from xformers.benchmarks.LRA.run_tasks import Task +from xformers.components.attention import ATTENTION_REGISTRY + + +def get_default_shared_folder() -> str: + checkpoint_paths = ["/checkpoint", "/checkpoints"] + for checkpoint_path in checkpoint_paths: + if Path(checkpoint_path).is_dir(): + return checkpoint_path + + return "." 
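+# Submits one run_with_submitit.py job per (attention, task) pair to the requested partition.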
+ + +if __name__ == "__main__": + default_checkpoint_path = get_default_shared_folder() + + # Get the user requests + parser = argparse.ArgumentParser( + "Benchmark different attention mechanisms on various sequence lengths" + ) + parser.add_argument("-c", "--config_path", required=True) + parser.add_argument("-ck", "--checkpoint_path", required=True) + parser.add_argument( + "-a", "--attentions", nargs="+", default=list(ATTENTION_REGISTRY.keys()) + ) + parser.add_argument("-t", "--tasks", nargs="+", default=[t.value for t in Task]) + parser.add_argument( + "--partition", default="a100", type=str, help="Partition where to submit" + ) + args = parser.parse_args() + + for attention in args.attentions: + for task in args.tasks: + os.system( + "python3 run_with_submitit.py" + + f" --attention {attention} --task {task} --config {args.config_path}" + + f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}" + + f" --partition {args.partition}" + ) diff --git a/pkgs/xformers/benchmarks/LRA/code/__init__.py b/pkgs/xformers/benchmarks/LRA/code/__init__.py new file mode 100644 index 0000000..6677db0 --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/code/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pkgs/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..a768876 Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-310.pyc new file mode 100644 index 0000000..6ee714f Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-310.pyc b/pkgs/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-310.pyc new file mode 100644 index 0000000..7394f6d Binary files /dev/null and b/pkgs/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/LRA/code/dataset.py b/pkgs/xformers/benchmarks/LRA/code/dataset.py new file mode 100644 index 0000000..cf7e003 --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/code/dataset.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# CREDITS: Almost as-is from the Nystromformer repo +# https://github.com/mlpen/Nystromformer + +import logging +import pickle + +import torch +from torch.utils.data.dataset import Dataset + +logging.getLogger().setLevel(logging.INFO) + + +class LRADataset(Dataset): + def __init__(self, file_path, seq_len): + with open(file_path, "rb") as f: + self.examples = pickle.load(f) + + self.seq_len = seq_len + logging.info(f"Loaded {file_path}... 
size={len(self.examples)}") + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + return self.create_inst(self.examples[i], self.seq_len) + + @staticmethod + def create_inst(inst, seq_len): + output = { + "input_ids_0": torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len] + } + output["mask_0"] = (output["input_ids_0"] != 0).float() + + if "input_ids_1" in inst: + output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype=torch.long)[ + :seq_len + ] + output["mask_1"] = (output["input_ids_1"] != 0).float() + output["label"] = torch.tensor(inst["label"], dtype=torch.long) + return output diff --git a/pkgs/xformers/benchmarks/LRA/code/model_wrapper.py b/pkgs/xformers/benchmarks/LRA/code/model_wrapper.py new file mode 100644 index 0000000..5eb3e2c --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/code/model_wrapper.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# CREDITS: adapted from the Nystromformer repo +# https://github.com/mlpen/Nystromformer + +from enum import Enum +from typing import Dict, Union + +import pytorch_lightning as pl +import torch +import torch.nn as nn + +from xformers.components import build_attention +from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig +from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig +from xformers.utils import generate_matching_config + +PLOutput = Dict[str, Union[float, torch.Tensor]] + + +class Pooling(str, Enum): + MEAN = "mean" + CLS = "cls" + + +def pooling(mode: Pooling): + def pool_cls(inp): + return inp[:, 0, :] + + def pool_mean(inp): + return inp.mean(dim=1) + + return {Pooling.MEAN: pool_mean, Pooling.CLS: pool_cls}[mode] + + +def append_cls(inp, mask, vocab_size): + batch_size = inp.size(0) + cls_id = ( + (vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device) + ).long() + cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device) + inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1) + mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1) + return inp, mask + + +def patch_model_config(config, attention_name): + # Rebuild a specific config out of generic + extra params + commons = config["common"] + try: + extra_attention_settings = config["extra_settings"]["attention"][attention_name] + except KeyError: + extra_attention_settings = None + + for bc in config["xformer"]: + bc["dim_model"] = commons["dim_model"] + bc["position_encoding_config"].update(commons) + bc["feedforward_config"].update(commons) + bc["multi_head_config"].update(commons) + bc["multi_head_config"]["attention"].update(commons) + bc["multi_head_config"]["attention"]["name"] = attention_name + bc["multi_head_config"]["attention"]["dim_head"] = ( + commons["dim_model"] / commons["num_heads"] + ) + if extra_attention_settings is not None: + bc["multi_head_config"]["attention"].update(extra_attention_settings) + + bc["multi_head_config"] = generate_matching_config( + bc["multi_head_config"], MultiHeadDispatchConfig + ) + bc["multi_head_config"].attention = build_attention( + bc["multi_head_config"].attention + ) + bc = generate_matching_config(bc, xFormerEncoderConfig) + + return config + + +class SCHead(nn.Module): + def __init__(self, config, dim_embedding, dim_mlp): + super().__init__() + self.pooling = pooling(Pooling(config["pooling_mode"])) + + 
self.mlpblock = nn.Sequential( + nn.Linear(dim_embedding, dim_mlp), + nn.ReLU(), + nn.Linear(dim_mlp, config["common"]["num_classes"]), + ) + + def forward(self, inp: torch.Tensor): + seq_score = self.mlpblock(self.pooling(inp)) + return seq_score + + +class SCHeadDual(nn.Module): + def __init__(self, config, dim_embedding, dim_mlp): + super().__init__() + self.pooling = pooling(Pooling(config["pooling_mode"])) + + self.mlpblock = nn.Sequential( + nn.Linear( + dim_embedding * 4, + dim_mlp, + ), + nn.ReLU(), + nn.Linear(dim_mlp, config["common"]["num_classes"]), + ) + + def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor): + X_0 = self.pooling(inp_0) + X_1 = self.pooling(inp_1) + seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1)) + return seq_score + + +class ModelTrunk(pl.LightningModule): + def __init__(self, config, model_name): + super().__init__() + + config_model = config["model"] + self.config_training = config["training"] + + self.enable_amp = config["training"]["mixed_precision"] + self.pooling_mode = Pooling(config_model["pooling_mode"]) + self.vocab_size = config_model["common"]["vocab_size"] + + # Rebuild a specific config out of generic + extra params + self.config_model = patch_model_config(config_model, model_name) + self.model = xFormer.from_config(xFormerConfig(config_model["xformer"])) + self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"]) + + ff_config = self.config_model["xformer"][0]["feedforward_config"] + self.dim_mlp = ( + self.config_model["common"]["dim_model"] + * ff_config["hidden_layer_multiplier"] + ) + + def training_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: + outputs = self(**batch) + self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore + self.log("train_accu", outputs["accu"], sync_dist=True) + return outputs + + def training_epoch_end(self, outputs): + logs = self.eval_epoch_end(outputs) + self.log("train_accu_mean", logs["accu"], sync_dist=True) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.config_training["learning_rate"], + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=self.config_training["weight_decay"], + ) + + lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer=optimizer, + max_lr=self.config_training["learning_rate"], + pct_start=self.config_training["warmup"] + / self.config_training["num_train_steps"], + anneal_strategy=self.config_training["lr_decay"], + total_steps=self.config_training["num_train_steps"], + ) + + return [optimizer], [lr_scheduler] + + def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput: + outputs = self(**batch) + return outputs + + def eval_epoch_end(self, outputs, prefix: str = "train"): + logs = {} + counts = torch.tensor([x["count"] for x in outputs]).float() + logs["count"] = counts.sum() + for k in ("accu", "loss"): + logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[ + "count" + ] + self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True) + return logs + + def validation_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: + outputs = self.eval_step(batch, batch_idx) + self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore + self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True) + return outputs + + def validation_epoch_end(self, outputs): + self.eval_epoch_end(outputs, prefix="val") + + def 
test_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: + return self.eval_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + self.eval_epoch_end(outputs, prefix="test") + + +class ModelForSC(ModelTrunk): + def __init__(self, config, model_name): + # Setup trunk + super().__init__(config, model_name) + + self.seq_classifer = SCHead( + self.config_model, + dim_embedding=self.config_model["common"]["dim_model"], + dim_mlp=self.dim_mlp, + ) + + def forward( # type: ignore + self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor + ): + + if self.pooling_mode == Pooling.CLS: + input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) + + token_out = self.norm( + self.model(input_ids_0, encoder_input_mask=mask_0) + ) * mask_0.unsqueeze(-1) + + seq_scores = self.seq_classifer(token_out) + + seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) + seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) + outputs = { + "loss": seq_loss.mean(), + "accu": seq_accu.mean(), + "count": label.size(0), + } + + return outputs + + +class ModelForSCDual(ModelTrunk): + def __init__(self, config, model_name): + # Setup trunk + super().__init__(config, model_name) + + self.seq_classifer = SCHeadDual( + self.config_model, + dim_embedding=self.config_model["common"]["dim_model"], + dim_mlp=self.dim_mlp, + ) + + def forward( # type: ignore + self, + input_ids_0: torch.Tensor, + input_ids_1: torch.Tensor, + mask_0: torch.Tensor, + mask_1: torch.Tensor, + label: torch.Tensor, + ): + + mask_0, mask_1 = mask_0.long(), mask_1.long() + + if self.pooling_mode == Pooling.CLS: + input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) + input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size) + + # Concatenate the two inputs into one batch + input_ids = torch.cat([input_ids_0, input_ids_1], dim=0) + masks = torch.cat([mask_0, mask_1], dim=0) + + tokens_out = self.norm( + self.model(input_ids, encoder_input_mask=masks) + ) * masks.unsqueeze(-1) + + seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0)) + + seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) + seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) + outputs = { + "loss": seq_loss.mean(), + "accu": seq_accu.mean(), + "count": label.size(0), + } + + return outputs diff --git a/pkgs/xformers/benchmarks/LRA/run_grid_search.py b/pkgs/xformers/benchmarks/LRA/run_grid_search.py new file mode 100644 index 0000000..a30b9e4 --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/run_grid_search.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
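+# grid_parameters() below expands a dict of candidate values into every combination; e.g.
+# (illustrative) {"lr": [1e-4, 2e-4], "seed": 1234} yields {"lr": 1e-4, "seed": 1234} and
+# {"lr": 2e-4, "seed": 1234} (scalar values are wrapped in a one-element list first).
+# Colon-separated keys such as "training:learning_rate" are resolved into the nested config
+# by rewrite_hyper() in run_tasks.py.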
+ + +import itertools +import os +import uuid +from datetime import date +from pathlib import Path +from typing import Dict, Iterable + +import submitit + +from xformers.benchmarks.LRA.run_with_submitit import ( + Trainer, + get_init_file, + get_shared_folder, + parse_args, +) + + +def grid_parameters(grid: Dict): + """ + Yield all combinations of parameters in the grid (as a dict) + """ + grid_copy = dict(grid) + # Turn single value in an Iterable + for k in grid_copy: + if not isinstance(grid_copy[k], Iterable): + grid_copy[k] = [grid_copy[k]] + for p in itertools.product(*grid_copy.values()): + yield dict(zip(grid.keys(), p)) + + +def grid_search(args): + if args.checkpoint_dir == "": + args.checkpoint_dir = get_shared_folder() / "%j" + + date_curr = date.today().strftime("%m-%d-%Y") + orig_check_dir = os.path.join(args.checkpoint_dir, date_curr) + + # Create the executor + # Note that the folder will depend on the job_id, to easily track experiments + executor = submitit.AutoExecutor( + folder=get_shared_folder() / "%j", slurm_max_num_timeout=30 + ) + num_gpus_per_node = args.ngpus + nodes = args.nodes + args.world_size = args.nodes * args.ngpus + partition = args.partition + + executor.update_parameters( + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=60 * 72, + slurm_signal_delay_s=120, + slurm_partition=partition, + ) + executor.update_parameters(name="lra") + + if args.task == "text": + grid_meta = { + "training:learning_rate": ( + [1e-4, 2e-4, 3e-4, 5e-5], + lambda val: f"lr{val}", + ), + "training:warmup": ([3000, 8000], lambda val: f"warmup{val}"), + "training:seed": ([1234, 32, 1994], lambda val: f"seed{val}"), + "training:weight_decay": ([0.02, 0.05, 0.01], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0, 0.05], lambda val: f"drop{val}"), + } + elif args.task == "retrieval": + grid_meta = { + "training:learning_rate": ([1e-4, 3e-4], lambda val: f"lr{val}"), + "training:warmup": ([2000, 8000], lambda val: f"warmup{val}"), + "training:seed": ([4096, 1234, 3, 15, 5], lambda val: f"seed{val}"), + "training:weight_decay": ([0.01, 0], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0], lambda val: f"drop{val}"), + } + elif args.task == "listops": + grid_meta = { + "training:learning_rate": ( + [1e-4, 2e-4, 3e-4, 5e-5], + lambda val: f"lr{val}", + ), + "training:warmup": ([3000, 2000], lambda val: f"warmup{val}"), + "training:seed": ( + [ + 1234, + ], + lambda val: f"seed{val}", + ), + "training:weight_decay": ([0.02, 0.05, 0, 1], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0], lambda val: f"drop{val}"), + } + else: + grid_meta = { + "training:learning_rate": ([1e-4, 5e-5], lambda val: f"lr{val}"), + "training:warmup": ([8000], lambda val: f"warmup{val}"), + "training:seed": ([1234, 4321, 3], lambda val: f"seed{val}"), + "training:weight_decay": ([0.01], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0.1], lambda val: f"drop{val}"), + } + + grid = {k: v[0] for k, v in grid_meta.items()} + save_key = {k: v[1] for k, v in grid_meta.items()} + + hyper_parameters = list(grid_parameters(grid)) + jobs = [] + + for i, grid_data in enumerate(hyper_parameters): + + args.sweep_parameters = grid_data + run_name = 
f"{args.attention}" + # run_name = "paper_config" + for k, v in grid_data.items(): + run_name += "prenorm-" + save_key[k](v) + args.checkpoint_dir = os.path.join( + orig_check_dir, f"{args.task}", "logs", run_name + ) + Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True) + args.tb_dir = os.path.join(orig_check_dir, f"{args.task}", "tb", run_name) + Path(args.tb_dir).mkdir(parents=True, exist_ok=True) + + # Chronos needs a different job name each time + executor.update_parameters(name=f"lra_{args.task}_{i:02d}_{uuid.uuid4().hex}") + + args.dist_url = get_init_file().as_uri() + args.temp_file = str(get_init_file()) + + trainer = Trainer(args) + job = executor.submit(trainer) + jobs.append(job) + print(f"Run {i:02d} submitted with train cfg: {args}") + print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}") + + +if __name__ == "__main__": + args = parse_args() + grid_search(args) diff --git a/pkgs/xformers/benchmarks/LRA/run_tasks.py b/pkgs/xformers/benchmarks/LRA/run_tasks.py new file mode 100644 index 0000000..e9d1f72 --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/run_tasks.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import json +import logging +import os +from enum import Enum +from pathlib import Path +from typing import Dict, Tuple, cast + +import pytorch_lightning as pl +import torch +import torch.nn as nn +from fvcore.nn import FlopCountAnalysis, flop_count_str +from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.strategies import DDPStrategy +from torch.utils.data import DataLoader + +from xformers.benchmarks.LRA.code.dataset import LRADataset +from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual +from xformers.components.attention import ATTENTION_REGISTRY + + +class Task(str, Enum): + Retrieval = "retrieval" + ListOps = "listops" + Image = "image" + PathfinderBaseline = "pathfinder32-curv_baseline" + PathfinderContour9 = "pathfinder32-curv_contour_length_9" + PathfinderContour14 = "pathfinder32-curv_contour_length_14" + Text = "text" + + +def load_config(path: str) -> Dict: + with open(Path(path).absolute(), "r") as fileio: + config = json.load(fileio) + + # Duplicate the pathfinder configs + config["pathfinder32-curv_baseline"] = config["pathfinder32"] + config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"] + config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"] + return config + + +def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: + task = args.task + attention_name = args.attention + + model = cast( + pl.LightningModule, + ModelForSCDual(config[f"{task}"], attention_name) + if task == Task.Retrieval + else ModelForSC(config[f"{task}"], attention_name), + ) + + logging.info(model) + summary = pl.utilities.model_summary.LayerSummary(model) + logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M") + + with torch.no_grad(): + # Check the flops + seq_len = config[f"{task}"]["model"]["common"]["seq_len"] + x = torch.rand(1, seq_len).long() + mask = torch.rand(1, seq_len).long() + indices = torch.rand(1, seq_len).long() + flops = FlopCountAnalysis(model.model, (x, mask, indices)) + logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops") + 
logging.info(flop_count_str(flops)) + + return model + + +def get_arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--attention", + type=str, + help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \ + A list can be passed to test several mechanisms in sequence", + dest="attention", + required=True, + ) + parser.add_argument( + "--task", + type=Task, + help=f"Task to chose, among {[t.value for t in Task]}.", + dest="task", + required=True, + ) + parser.add_argument( + "--skip_train", + type=bool, + help="Whether to skip training, and test an existing model", + dest="skip_train", + default=False, + ) + parser.add_argument( + "--config", + type=str, + help="Path to the config being used", + dest="config", + default="./config.json", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the checkpoint directory", + dest="checkpoint_dir", + default=f"/checkpoints/{os.getenv('USER')}/xformers", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Path to checkpoint", + ) + parser.add_argument( + "--debug", + help="Make it easier to debug a possible issue", + dest="debug", + default=False, + action="store_true", + ) + parser.add_argument( + "--world_size", + help="Number of GPUs used", + dest="world_size", + type=int, + default=1, + ) + parser.add_argument( + "--sweep_parameters", + help="Rewrite some hyperparameters in the config", + dest="sweep_parameters", + type=dict, + default=None, + ) + return parser + + +def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]: + experiment_name = f"{task}__{attention_name}" + logger = TensorBoardLogger( + save_dir=args.checkpoint_dir, + name="", # remove lightning_logs subdirectory + version=experiment_name, + ) + log_dir = os.path.join(logger._save_dir, experiment_name) + return log_dir, logger + + +def rewrite_hyper(config, rewrites): + def replace(config_dict, k, v): + if len(k.split(":")) == 1: + config_dict[k] = v + return + first_key = k.split(":")[0] + assert first_key in config_dict, first_key + k = k[len(first_key) + 1 :] + replace(config_dict[first_key], k, v) + + for k, v in rewrites.items(): + replace(config, k, v) + return config + + +def build_dataloaders( + args: argparse.Namespace, + config_training: Dict, + num_workers: int = 4, +) -> Dict[str, DataLoader]: + datasets = {} + for component in ("train", "dev", "test"): + datasets[component] = LRADataset( + file_path=f"datasets/{args.task}.{component}.pickle", + seq_len=config_training["seq_len"], + ) + + # Gradient accumulation + accumu_steps = config_training["gradient_accumulation"] + logging.info(f"accumu_steps={accumu_steps}") + + # Batch size + per_gpu_batch_size = ( + config_training["batch_size"] // args.world_size // accumu_steps + ) + logging.warning( + f"Requested batch size: {config_training['batch_size']}. 
Given world\ + size and grad accumulation, per-gpu batch is\ + {per_gpu_batch_size}" + ) + + dataloaders = { + k: DataLoader( + v, + batch_size=per_gpu_batch_size, + shuffle=False, + pin_memory=True, + num_workers=num_workers, + ) + for k, v in datasets.items() + } + return dataloaders + + +def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]: + eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step} + for k, v in trainer.callback_metrics.items(): + eval_summary[k] = v.item() + return eval_summary + + +class BasicProgressBar(TQDMProgressBar): + def get_metrics(self, trainer, model): + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + + +def benchmark(args): + log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}") + args.logger = logger + + config = load_config(args.config) + + config_task = config[f"{args.task}"] + if args.sweep_parameters is not None: + logging.info("Replacing hyperparameters") + rewrite_hyper(config_task, args.sweep_parameters) + + config_training = config_task["training"] + config_training["seq_len"] = config_task["model"]["common"]["seq_len"] + logging.info(f"Learning rate: {config_training['learning_rate']}") + + pl.seed_everything(config_training.get("seed", 0)) + dataloaders = build_dataloaders(args, config_training) + + model = build_model(args, config) + + progress_bar = BasicProgressBar() + checkpoint_callback = ModelCheckpoint( + monitor="val_accu", + mode="max", + dirpath=args.checkpoint_dir, + filename="{epoch}-{val_accu:.2f}", + every_n_train_steps=config_training["eval_frequency"], + ) + + trainer = pl.Trainer( + accelerator="gpu", + strategy=DDPStrategy(find_unused_parameters=args.debug) + if not args.skip_train + else None, + accumulate_grad_batches=config_training["gradient_accumulation"], + callbacks=[progress_bar, checkpoint_callback], + detect_anomaly=args.debug, + deterministic=True, + gpus=args.world_size, + limit_val_batches=config_training["num_eval_steps"], + logger=logger, + max_steps=config_training["num_train_steps"], + num_sanity_val_steps=int(not args.skip_train), + precision=16 if config_training["mixed_precision"] else 32, + val_check_interval=config_training["eval_frequency"] + / float(len(dataloaders["train"])), + ) + + if not args.skip_train: + trainer.fit( + model, + train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["dev"], + ) + ckpt_path = checkpoint_callback.best_model_path + else: + ckpt_path = args.checkpoint_path + + trainer.test( + model, + dataloaders=dataloaders["test"], + ckpt_path=ckpt_path, + ) + eval_summary = get_eval_summary(trainer) + with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f: + logging.info(f"Saving test results at {f.name}") + json.dump(eval_summary, f) + + +if __name__ == "__main__": + parser = get_arg_parser() + args = parser.parse_args() + if args.skip_train and args.checkpoint_path is None: + raise parser.error("Must provide --checkpoint_path if --skip_train=True") + benchmark(args) diff --git a/pkgs/xformers/benchmarks/LRA/run_with_submitit.py b/pkgs/xformers/benchmarks/LRA/run_with_submitit.py new file mode 100644 index 0000000..6fefdd1 --- /dev/null +++ b/pkgs/xformers/benchmarks/LRA/run_with_submitit.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +""" +A script to run multinode training with submitit. 
+Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py +""" + +import argparse +import os +import uuid +from pathlib import Path + +import submitit + +from xformers.benchmarks.LRA.run_tasks import benchmark, get_arg_parser + + +def parse_args(): + parser = argparse.ArgumentParser( + "Submitit for LRA", parents=[get_arg_parser()], add_help=False + ) + parser.add_argument( + "--ngpus", default=1, type=int, help="Number of gpus to request on each node" + ) + parser.add_argument( + "--nodes", default=1, type=int, help="Number of nodes to request" + ) + parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") + + parser.add_argument( + "--partition", default="a100", type=str, help="Partition where to submit" + ) + parser.add_argument( + "--use_volta32", action="store_true", help="Big models? Use this" + ) + parser.add_argument( + "--enforce_host_memory", action="store_true", help="Use if the host OOMs" + ) + + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. priority message", + ) + return parser.parse_args() + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + checkpoint_paths = ["/checkpoint", "/checkpoints"] + for checkpoint_path in checkpoint_paths: + if Path(checkpoint_path).is_dir(): + p = Path(f"{checkpoint_path}/{user}/xformers/submitit") + p.mkdir(exist_ok=True, parents=True) + return p + raise RuntimeError(f"No shared folder available - considering {checkpoint_paths}") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. + os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer: + def __init__(self, args): + self.args = args + + def __call__(self): + self._setup_gpu_args() + benchmark(self.args) + + def checkpoint(self): + self.args.dist_url = get_init_file().as_uri() + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + job_env = submitit.JobEnvironment() + self.args.checkpoint_dir = Path( + str(self.args.checkpoint_dir).replace("%j", str(job_env.job_id)) + ) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(): + args = parse_args() + if args.checkpoint_dir == "": + args.checkpoint_dir = get_shared_folder() / "%j" + Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor( + folder=args.checkpoint_dir, slurm_max_num_timeout=30 + ) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + args.world_size = args.nodes * args.ngpus + + partition = args.partition + + kwargs = { + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "timeout_min": timeout_min, # max is 60 * 72 + # Below are cluster dependent parameters + "slurm_partition": partition, + "slurm_signal_delay_s": 120, + } + + if args.enforce_host_memory: + kwargs["mem_gb"] = (40 * num_gpus_per_node,) + + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" + + if args.comment: + kwargs["slurm_comment"] = args.comment + + executor.update_parameters( + **kwargs, + ) + + 
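+# Assumed submitit behaviour: update_parameters() merges with the kwargs set above,
+# so this second call only sets the job name.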
executor.update_parameters(name="lra") + + args.dist_url = get_init_file().as_uri() + args.temp_file = str(get_init_file()) + + trainer = Trainer(args) + job = executor.submit(trainer) + + print(f"Submitted job_id: {job.job_id}") + print(f"Logs and checkpoints will be saved at: {args.checkpoint_dir}") + with open(Path(f"{args.checkpoint_dir}") / Path("jobs.txt"), "a") as jobfile: + jobfile.write(f"{job.job_id}\n") + + +if __name__ == "__main__": + main() diff --git a/pkgs/xformers/benchmarks/__init__.py b/pkgs/xformers/benchmarks/__init__.py new file mode 100644 index 0000000..6677db0 --- /dev/null +++ b/pkgs/xformers/benchmarks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pkgs/xformers/benchmarks/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..eeae8a7 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-310.pyc new file mode 100644 index 0000000..85e14ac Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_blocksparse_transformers.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_blocksparse_transformers.cpython-310.pyc new file mode 100644 index 0000000..f03e7f2 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_blocksparse_transformers.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_causal_blocksparse.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_causal_blocksparse.cpython-310.pyc new file mode 100644 index 0000000..55c9b1d Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_causal_blocksparse.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_core.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_core.cpython-310.pyc new file mode 100644 index 0000000..48cca08 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_core.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-310.pyc new file mode 100644 index 0000000..10d0153 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-310.pyc new file mode 100644 index 0000000..49d4541 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_mem_eff_attn_decoder.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_mem_eff_attn_decoder.cpython-310.pyc new file mode 100644 index 0000000..6ed16b6 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_mem_eff_attn_decoder.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_mlp.cpython-310.pyc 
b/pkgs/xformers/benchmarks/__pycache__/benchmark_mlp.cpython-310.pyc new file mode 100644 index 0000000..418786c Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_mlp.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-310.pyc new file mode 100644 index 0000000..a23d010 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-310.pyc new file mode 100644 index 0000000..ea45fe4 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-310.pyc new file mode 100644 index 0000000..bf4bcfc Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-310.pyc new file mode 100644 index 0000000..f7ea6c8 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-310.pyc new file mode 100644 index 0000000..d38387b Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_transformer.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_transformer.cpython-310.pyc new file mode 100644 index 0000000..405017e Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_transformer.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_blocksparse.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_blocksparse.cpython-310.pyc new file mode 100644 index 0000000..bfc2596 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_blocksparse.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_dropout.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_dropout.cpython-310.pyc new file mode 100644 index 0000000..8d86e3a Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_dropout.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_fused_linear.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_fused_linear.cpython-310.pyc new file mode 100644 index 0000000..ab1babc Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_fused_linear.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_layernorm.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_layernorm.cpython-310.pyc new file mode 100644 index 0000000..bfcc991 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_layernorm.cpython-310.pyc differ diff --git 
a/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_softmax.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_softmax.cpython-310.pyc new file mode 100644 index 0000000..c856163 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_softmax.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_stride_sum.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_stride_sum.cpython-310.pyc new file mode 100644 index 0000000..77b9e25 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/benchmark_triton_stride_sum.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/__pycache__/utils.cpython-310.pyc b/pkgs/xformers/benchmarks/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..6e2ab85 Binary files /dev/null and b/pkgs/xformers/benchmarks/__pycache__/utils.cpython-310.pyc differ diff --git a/pkgs/xformers/benchmarks/benchmark_attn_decoding.py b/pkgs/xformers/benchmarks/benchmark_attn_decoding.py new file mode 100644 index 0000000..eae7464 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_attn_decoding.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from typing import Any + +import torch +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops as xops + +min_run_time = 0.5 +device = torch.device("cuda") + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = [ + dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=1, K=128) + for i in range(8, 18) +] + [ + dict(B=max(1, 2 ** (16 - i)), Mq=1, Mkv=2**i, Hq=16, Hkv=2, K=128) + for i in range(8, 18) +] + + +def _setup_test( + functions, fw: bool = False, bw: bool = False, cuda_graph: bool = True, **kwargs +): + for k, benchmark_cls in functions.items(): + benchmark_object = benchmark_cls(**kwargs, bw=bw) + label = benchmark_object.label + label += "fw" if fw else "" + label += "bw" if bw else "" + + def run_one(): + if fw: + benchmark_object.fw() + if bw: + benchmark_object.bw() + + if cuda_graph: + run_one() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + run_one() + + def run_one(): + g.replay() + + yield benchmark.Timer( + stmt="fn()", + globals={ + "fn": run_one, + }, + label=label, + description=k, + sub_label=benchmark_object.sub_label, + ) + + +class AttentionDecodingFlashDecoding: + OP: Any = xops.fmha.flash.FwOp + + def __init__( + self, B: int, Mq: int, Mkv: int, Hq: int, Hkv: int, K: int, bw: bool + ) -> None: + dtype = torch.float16 + self.sub_label = f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K}" + self.label = "attn_decoding" + self.shapes = (B, Mq, Mkv, Hq, Hkv, K) + + assert Hkv <= Hq + assert Hq % Hkv == 0 + self.q = torch.randn( + [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw + ) + self.k = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw + ).expand(-1, -1, -1, Hq // Hkv, -1) + self.v = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw + ).expand(-1, -1, -1, Hq // Hkv, -1) + + if Hq == Hkv: + self.q = self.q[:, :, :, 0] + self.k = self.k[:, :, :, 0] + self.v = self.v[:, :, :, 0] + if Hkv == 1: + self.q = self.q[:, :, 0] + self.k = 
self.k[:, :, 0] + self.v = self.v[:, :, 0] + + def fw(self) -> None: + xops.memory_efficient_attention_forward(self.q, self.k, self.v, op=self.OP) + + +class AttentionDecodingSplitKV(AttentionDecodingFlashDecoding): + OP = xops.fmha.triton_splitk.FwOp + + +class AttentionDecodingPyTorchRepeat(AttentionDecodingFlashDecoding): + def fw(self) -> None: + B, Mq, Mkv, Hq, Hkv, K = self.shapes + scale = 1 / K**0.5 + q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2)).softmax(-1) * scale + return attn @ v + + +BENCHMARKS = { + "pytorch": AttentionDecodingPyTorchRepeat, + "flash-decoding": AttentionDecodingFlashDecoding, + "triton_splitK": AttentionDecodingSplitKV, +} + + +try: + import flash_attn + + class AttentionDecodingFlashAttention(AttentionDecodingFlashDecoding): + def fw(self) -> None: + q, k, v = self.q, self.k, self.v + if q.ndim == 5: + B, Mq, H1, H2, K = q.shape + B, Mkv, H1, H2, K = k.shape + q = q.reshape([B, Mq, H1 * H2, K]) + k = k[:, :, :, 0] + v = v[:, :, :, 0] + return flash_attn.flash_attn_func(q, k, v) + + BENCHMARKS[ + f"flash-attention@{flash_attn.__version__}" + ] = AttentionDecodingFlashAttention +except ImportError: + pass + + +def attn_decoding(**kwargs): + yield from _setup_test( + **kwargs, + fw=True, + cuda_graph=True, + functions=BENCHMARKS, + ) + + +benchmark_main_helper(attn_decoding, CASES, min_run_time=min_run_time) diff --git a/pkgs/xformers/benchmarks/benchmark_blocksparse_transformers.py b/pkgs/xformers/benchmarks/benchmark_blocksparse_transformers.py new file mode 100644 index 0000000..f9cb72a --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_blocksparse_transformers.py @@ -0,0 +1,1065 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
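+# Layout convention used throughout (illustrative numbers): masks are boolean tensors of shape
+# (num_heads, seq_length // block_size, seq_length // block_size), where a True entry keeps one
+# block_size x block_size tile; with seq_length=2048 and block_size=64 the layout is
+# (num_heads, 32, 32), and sparsity is reported as the percentage of False tiles.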
+ + +import gc +import math +from collections import namedtuple +from dataclasses import dataclass + +import matplotlib.pyplot as plt +import torch +import triton +from triton.ops.blocksparse import matmul as blocksparse_matmul + +from xformers.benchmarks.utils import pretty_barplot +from xformers.components.attention.attention_patterns import ( + axial_2d_pattern, + causal_1d_pattern, + global_token_pattern, + local_1d_pattern, + local_2d_pattern, +) +from xformers.components.attention.core import SparseCS, _matmul_with_mask + +device = "cuda" +TestCase = namedtuple("TestCase", ["prepare_callable", "mask", "config", "name"]) + +############################################## +# Plotting utilities +############################################## + + +def plot_mask(mask, config, filename): + sparsity = get_sparsity(mask) + batch_size = config.batch_size + num_heads = config.num_heads + seq_len = config.seq_length + + proxy = torch.ones(batch_size, num_heads, seq_len, seq_len, dtype=torch.bool) + proxy = triton.testing.mask_tensor(proxy, mask, config.block_size, False) + proxy = proxy[0][0] + + f = plt.figure() + plt.imshow(proxy.logical_not(), cmap="gray") + plt.suptitle("Sparsity = " + str(sparsity) + "%") + plt.savefig(filename) + plt.close(f) + + +############################################## +# Mask and testing utilities +############################################## + + +def get_mask(MaskGenType, config, config_setter=[]): + mask_config = Configuration() + mask_config.init(config) + + # Get the mask + mask_generator = MaskGenType(mask_config) + for (key, value) in config_setter: + mask_generator.set_config_attr(key, value) + if not mask_generator.is_valid_config(): + return None + + return mask_generator() + + +def densify_mask(mask, config): + num_heads = config.num_heads + seq_length = config.seq_length + block_size = config.block_size + dense_mask = torch.zeros(num_heads, seq_length, seq_length) + for (h, i, j) in zip(*mask.nonzero(as_tuple=True)): + dense_mask[ + h, + i * block_size : (i + 1) * block_size, + j * block_size : (j + 1) * block_size, + ] = mask[h, i, j] + return dense_mask + + +def mask_tensor(a, mask, config): + return triton.testing.mask_tensor(a, mask, config.block_size, 0.0) + + +def sparsify_tensor(a, mask, config): + return triton.testing.sparsify_tensor(a, mask, config.block_size) + + +def get_sparsity(mask): + return round((1.0 - mask.sum().item() / mask.numel()) * 100) + + +############################################## +# Mask Generation +############################################## + + +@dataclass +class Configuration: + batch_size: int = 32 + num_heads: int = 12 + seq_length: int = 2048 + hidden_size: int = 768 # hidden_size = n_heads * projection_hidden_dimension + block_size: int = 64 + + @property + def blocked_seq_length(self): + return int(self.seq_length / self.block_size) + + def init(self, kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def __str__(self): + desc = [ + f"bs={self.batch_size}", + f"h={self.num_heads}", + f"k={self.hidden_size}", + f"seq={self.seq_length}", + f"bl={self.block_size}", + ] + return ",".join(desc) + + +class AttentionMask: + def __init__(self, config=None): + super().__init__() + if config is None: + config = Configuration() + self.config = config + + def is_blocked(self): + return self.config.block_size != 1 + + def is_valid_config(self, keep_blocked=True): + return True + + def expand(self, mask): + if mask.ndim == 2: + return mask.unsqueeze(0).expand(self.config.num_heads, -1, -1) + + def 
gen_mask(self, keep_blocked=True): + raise NotImplementedError("Abstract data class") + + def set_config_attr(self, key, value): + setattr(self.config, key, value) + + def __str__(self): + raise NotImplementedError("Abstract data type") + + def __call__(self): + mask = self.gen_mask() + return mask, self.config, str(self) + + +class RandomAttentionMask(AttentionMask): + """ + This is a Random mask. Useful for performance and memory analysis. + """ + + def __init__(self, config=None): + super(RandomAttentionMask, self).__init__(config) + self.set_config_attr("mask_prob", 0.5) + + def gen_mask(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + + mask = torch.rand(seq_length, seq_length) > self.config.mask_prob + return self.expand(mask) + + def __str__(self): + return "random" + + +class LowerTriangularAttentionMask(AttentionMask): + """ + This is a lower triangular mask. This is common in decoder only models. + This should reduce the computation and memory to roughly half as close to + half of the mask is zero. + + The mask stays same for each head and each input. + + Nit pick (TODO) - While blocking, we need to ensure that the blocks along + the diagonals are themselves lower triangular blocks. But, for performance + measurement, this is ok to ignore as we treat the whole block as useful + values. + """ + + def __init__(self, config=None): + super(LowerTriangularAttentionMask, self).__init__(config) + + def gen_mask(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + return self.expand(causal_1d_pattern(seq_length)) + + def __str__(self): + return "lower_triangular" + + +class BigBirdAttentionMask(AttentionMask): + """ + BigBird mask are composed of three types of masks - random, global and window. + For more details, refer to https://arxiv.org/pdf/2007.14062.pdf + + One point to note is that mask is per head here. So, mask is 3D tensor. + (num_heads, seq_length, seq_length). 
+ """ + + def __init__(self, config=None): + super(BigBirdAttentionMask, self).__init__(config) + self.mask_per_head = True + self.set_config_attr("num_global_tokens", 2 * self.config.block_size) + self.set_config_attr("num_random_tokens", 3 * self.config.block_size) + self.set_config_attr("num_window_tokens", 3 * self.config.block_size) + + def gen_global_mask(self, seq_length): + # Global tokens are tokens that attend to all tokens and to whom all tokens attend to in the sequence + num_global_blocks = self.config.num_global_tokens // self.config.block_size + mask_indices = torch.randint(0, seq_length - 1, size=(num_global_blocks,)) + mask_indices = torch.unique(mask_indices) + query_mask = torch.zeros(seq_length).to(dtype=torch.bool) + query_mask.scatter_(0, mask_indices, True) + return global_token_pattern(query_mask) + + def gen_random_mask(self, seq_length): + # Each query token attends over r random number of tokens + num_random_blocks = self.config.num_random_tokens // self.config.block_size + mask_indices = torch.randint( + 0, seq_length - 1, size=(seq_length, num_random_blocks) + ) + random_mask = torch.zeros(seq_length, seq_length).to(dtype=torch.bool) + random_mask.scatter_(1, mask_indices, True) + return random_mask + + def gen_window_mask(self, seq_length): + num_window_blocks = self.config.num_window_tokens // self.config.block_size + if num_window_blocks % 2 == 0: + num_window_blocks += 1 + return local_1d_pattern(seq_length, num_window_blocks) + + def gen_mask(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + assert keep_blocked, "Not implemented, call to_dense later to get full tensor" + + if self.mask_per_head: + head_masks = [] + + for _ in range(self.config.num_heads): + global_mask = self.gen_global_mask(seq_length) + random_mask = self.gen_random_mask(seq_length) + window_mask = self.gen_window_mask(seq_length) + + mask = global_mask + random_mask + window_mask + head_masks.append(mask) + + mask = torch.stack(head_masks) + else: + global_mask = self.gen_global_mask(seq_length) + random_mask = self.gen_random_mask(seq_length) + window_mask = self.gen_window_mask(seq_length) + + mask = global_mask + random_mask + window_mask + mask = self.expand(mask) + return mask + + def __str__(self): + return "bigbird" + + +class AxialAttentionMask(AttentionMask): + """ + BigBird mask are composed of three types of masks - random, global and window. + For more details, refer to https://arxiv.org/pdf/2007.14062.pdf + + One point to note is that mask is per head here. So, mask is 3D tensor. + (num_heads, seq_length, seq_length). + """ + + def __init__(self, config=None): + super(AxialAttentionMask, self).__init__(config) + if config is None: + self.set_config_attr("seq_length", 1024) + + def is_valid_config(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + H = int(math.sqrt(seq_length)) + if H * H == seq_length: + return True + return False + + def gen_mask(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + H = int(math.sqrt(seq_length)) + assert H * H == seq_length, f"H={H}, seq_length={seq_length}" + return self.expand(axial_2d_pattern(H, H)) + + def __str__(self): + return "axial" + + +class LocalAttentionMask(AttentionMask): + """ + BigBird mask are composed of three types of masks - random, global and window. 
+ For more details, refer to https://arxiv.org/pdf/2007.14062.pdf + + One point to note is that mask is per head here. So, mask is 3D tensor. + (num_heads, seq_length, seq_length). + """ + + def __init__(self, config=None): + super(LocalAttentionMask, self).__init__(config) + self.set_config_attr("num_local_blocks", 3) + if config is None: + self.set_config_attr("seq_length", 1024) + + def is_valid_config(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + H = int(math.sqrt(seq_length)) + if H * H == seq_length: + return True + return False + + def gen_mask(self, keep_blocked=True): + seq_length = self.config.seq_length + if keep_blocked: + seq_length = self.config.blocked_seq_length + H = int(math.sqrt(seq_length)) + assert H * H == seq_length, f"H={H}, seq_length={seq_length}" + return self.expand(local_2d_pattern(H, H, self.config.num_local_blocks)) + + def __str__(self): + return "local" + + +############################################## +# Class to organize the experiments +############################################## + + +class Experiment: + def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik): + self.mode = mode + self.dtype = dtype + self.do_accuracy_check = do_accuracy_check + self.profile_sputnik = profile_sputnik + + def reset_results(self): + self.results = {} + self.results["flops"] = {} + self.results["time"] = {} + self.results["memory"] = {} + self.results["speedup"] = {} + self.results["memory_savings"] = {} + + def do_mem(sel, fn): + # bookeeping + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + # actually run the function + fn() + fn() + torch.cuda.synchronize() + return torch.cuda.max_memory_allocated() // 2**20 + + def gen_config(self): + raise NotImplementedError("Not setup") + + def plot(self, sparsity, pattern_name): + raise NotImplementedError("Not setup") + + def run(self): + raise NotImplementedError("Not setup") + + def add_kv(self, d, d_key, d_value, testcase): + d_value = max(0, d_value) + if d_key not in d: + d[d_key] = {} + d[d_key][testcase.name] = d_value + + def bench_all( + self, a, b, tests, mask_config, sparsity, baseline_name, op_flops, dict_key + ): + + if self.do_accuracy_check: + self.check_all(tests, a, b) + + for testcase in tests: + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + try: + fn = testcase.prepare_callable(a, b, testcase.mask, testcase.config) + ms = triton.testing.do_bench(fn)[0] + flops = op_flops / ms * 1e3 # TFlop per second + mem = self.do_mem(fn) + except Exception: + # raise + ms = -1 + flops = -1 + mem = -1 + + # Write into results + # dict_key = f"sp={sparsity}%,{mask_config}" + self.add_kv(self.results["time"], dict_key, ms, testcase) + self.add_kv(self.results["flops"], dict_key, flops, testcase) + self.add_kv(self.results["memory"], dict_key, mem, testcase) + + speedup = self.results["time"][dict_key][baseline_name] / ms + memory_savings = self.results["memory"][dict_key][baseline_name] / mem + self.add_kv(self.results["speedup"], dict_key, speedup, testcase) + self.add_kv(self.results["flops"], dict_key, flops, testcase) + self.add_kv( + self.results["memory_savings"], dict_key, memory_savings, testcase + ) + + desc = f"sparsity={sparsity}, ops={op_flops}, time={ms}, tflops={flops}, mem={mem}" + print(f"{testcase.name} --> {mask_config}, {desc}") + + def get_inputs(self, config, device="cuda"): + # if mode = sddmm, a, b = query, key + # if mode = spmm, a, b = attn, value + if 
self.mode == "sddmm": + return [ + torch.randn( + config.batch_size, + config.num_heads, + config.seq_length, + config.hidden_size // config.num_heads, + device=device, + dtype=self.dtype, + ) + for _ in range(2) + ] + else: + assert self.mode == "spmm" + attn = torch.randn( + config.batch_size, + config.num_heads, + config.seq_length, + config.seq_length, + device=device, + dtype=self.dtype, + ) + value = torch.randn( + config.batch_size, + config.num_heads, + config.seq_length, + config.hidden_size // config.num_heads, + device=device, + dtype=self.dtype, + ) + return [attn, value] + + def torch_matmul_callable(self, a, b, mask, config): + input_a = mask_tensor(a, mask, config) if self.mode == "spmm" else a + input_b = b.transpose(-1, -2) if self.mode == "sddmm" else b + + def torch_fn(): + return torch.matmul(input_a, input_b) + + return torch_fn + + def get_triton_fn(self, mask, config, mode="sddmm"): + if mode == "sddmm": + return blocksparse_matmul( + layout=mask, + block=config.block_size, + mode="sdd", + device="cuda", + trans_a=False, + trans_b=True, + ) + else: + assert mode == "spmm" + return blocksparse_matmul( + layout=mask, + block=config.block_size, + mode="dsd", + device="cuda", + trans_a=False, + trans_b=False, + ) + + def triton_callable(self, a, b, mask, config): + triton_kernel = self.get_triton_fn(mask, config, self.mode) + input_a = sparsify_tensor(a, mask, config) if self.mode == "spmm" else a + input_b = b + + def triton_fn(): + return triton_kernel(input_a, input_b) + + return triton_fn + + def prepare_sputnik_inputs(self, query, key, config, mask): + # - sparse / sputnik + mask_cs = torch.ones( + [config.batch_size, config.num_heads, config.seq_length, config.seq_length], + dtype=torch.bool, + device="cuda", + ) + mask_cs = triton.testing.mask_tensor( + mask_cs, mask, config.block_size, value=False + ) + # Sputnik kernels only handle fp32 + query_cs = query.flatten(start_dim=0, end_dim=1).to(torch.float32) + key_cs = key.flatten(start_dim=0, end_dim=1).to(torch.float32) + + query_cs = query_cs.contiguous() + key_cs = key_cs.transpose(-2, -1) + + sparse_mask_cs = SparseCS( + mask_cs.flatten(start_dim=0, end_dim=1).contiguous(), + device=torch.device("cuda"), + ) + return query_cs, key_cs, sparse_mask_cs + + def sputnik_callable(self, a, b, mask, config): + assert self.mode == "sddmm" + a_cs, b_cs, sparse_mask_cs = self.prepare_sputnik_inputs(a, b, config, mask) + + def sputnik_fn(): + return _matmul_with_mask(a_cs, b_cs, sparse_mask_cs) + + return sputnik_fn + + def get_op_flops(self, mask, config): + # Measure total compute ops + op_flops = ( + 2 # FMA + * config.batch_size # batched matmul + * (config.hidden_size // config.num_heads) # Reduce dimension + * float(mask.sum()) + * config.block_size + * config.block_size # Effective seq length * seq_length + * 1e-12 # TFlops + ) + return op_flops + + def check_all(self, tests, a, b): + ref_test = tests[0] + ref_out = ref_test.prepare_callable(a, b, ref_test.mask, ref_test.config)() + + res_test = tests[1] + res_out = res_test.prepare_callable(a, b, res_test.mask, res_test.config)() + + self.check_accuracy(ref_out, res_out, ref_test.mask, ref_test.config) + + def check_accuracy(self, ref_full, res_bsr, mask, config): + if self.mode == "sddmm": + # Get the dense representation of the bsr tensor + # Use triton sparse * dense multiplication to get the dense tensor back + sparse_dot_dsd = blocksparse_matmul( + layout=mask, + block=config.block_size, + mode="dsd", + device="cuda", + trans_a=False, + trans_b=False, + ) + 
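+# Multiplying the block-sparse ("sdd") result by an identity matrix through a
+# dense = sparse @ dense ("dsd") matmul densifies it, so it can be compared against
+# the masked dense reference computed below.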
identity = torch.eye( + config.seq_length, config.seq_length, device=device, dtype=self.dtype + ) + + identity = identity.expand(config.batch_size, config.num_heads, -1, -1) + res = sparse_dot_dsd(res_bsr, identity) + + # Get the res where values are masked. Expand the blocked mask + # ref = triton.testing.mask_tensor(ref_full, mask, config.block_size) + full_mask = densify_mask(mask, config) + ref = ref_full * full_mask.to(dtype=self.dtype, device=device) + try: + assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3) + except RuntimeError: + pass + except AssertionError: + raise + else: + assert self.mode == "spmm" + # Both are dense outputs + try: + assert torch.allclose(ref_full, res_bsr, atol=1e-3, rtol=1e-3) + except RuntimeError: + pass + except AssertionError: + import pdb + + pdb.set_trace() + raise + + +class DifferentPatternExperiment(Experiment): + """ + In this experiment, we check if sparsity pattern (like bigbird, lower triangular + etc) changes the performance of different kernels. The idea is to check if + changing sparsity pattern, while keeping total sparsity ratio same, leads to any + perforamnce differences. + + We will perform two experiments + + 1) LowerTraingularMask vs RandomMask - Both have ~50% sparsity. + 2) BigBird Mask vs RandomMask - Both have same sparsity. + """ + + def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False): + super(DifferentPatternExperiment, self).__init__( + mode, dtype, do_accuracy_check, profile_sputnik + ) + + def gen_config(self): + batch_sizes = [32] + heads = [16] + seq_lengths = [1024, 2048] + block_sizes = [64] + hidden_sizes = [1024, 4096, 8192] + for batch in batch_sizes: + for hidden_size in hidden_sizes: + for head in heads: + for seq in seq_lengths: + for block in block_sizes: + entry = { + "batch_size": batch, + "num_heads": head, + "seq_length": seq, + "block_size": block, + "hidden_size": hidden_size, + } + yield entry + + def plot(self, sparsity, config, pattern_name): + desc = [ + f"bs={config.batch_size}", + f"nheads={config.num_heads}", + f"block={config.block_size}", + f"dtype={self.dtype}", + ] + title_suffix = ",".join(desc) + pretty_barplot( + self.results["speedup"], + title=f"{self.mode} - Pattern experiment ({sparsity}%) - speedup\n" + + title_suffix, + filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_time.svg", + dash_key="pytorch", + units="Speedup normalized to torch_matmul", + ) + + pretty_barplot( + self.results["flops"], + title=f"{self.mode} - Pattern experiment ({sparsity}%) - throughput\n" + + title_suffix, + filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_flops.svg", + dash_key="pytorch", + units="TFlops/s", + ) + + pretty_barplot( + self.results["memory_savings"], + title=f"{self.mode} - Pattern experiment ({sparsity}%) - memory savings\n" + + title_suffix, + filename=f"same_sparsity_{self.mode}_{self.dtype}_{pattern_name}_memory.svg", + dash_key="pytorch", + units="Memory savings normalized to torch_matmul", + ) + + def run(self): + for MaskGenType in [LowerTriangularAttentionMask, BigBirdAttentionMask]: + self.reset_results() + for config in self.gen_config(): + # Get pattern mask + pattern_mask, pattern_config, pattern_name = get_mask( + MaskGenType, config + ) + sparsity = get_sparsity(pattern_mask) + mask_prob = sparsity / 100 + + # Get random mask + random_mask, random_config, _ = get_mask( + RandomAttentionMask, + config, + [("mask_prob", mask_prob)], + ) + print(f"{pattern_name} sparsity", get_sparsity(pattern_mask)) + print("Random sparsity", 
get_sparsity(random_mask)) + + # Create input tensors + a, b = self.get_inputs(random_config) + + tests = [] + baseline_name = "torch-matmul" + tests.append( + TestCase( + self.torch_matmul_callable, + random_mask, + random_config, + f"{baseline_name}", + ) + ) + tests.append( + TestCase( + self.triton_callable, + random_mask, + random_config, + "triton-random", + ) + ) + tests.append( + TestCase( + self.triton_callable, + pattern_mask, + pattern_config, + f"triton-{pattern_name}", + ) + ) + if self.profile_sputnik and self.mode == "sddmm": + tests.append( + TestCase( + self.sputnik_callable, + random_mask, + random_config, + "sputnik-random", + ) + ) + tests.append( + TestCase( + self.sputnik_callable, + pattern_mask, + pattern_config, + f"sputnik-{pattern_name}", + ) + ) + + dict_key = f"hidden={random_config.hidden_size},seq_len={random_config.seq_length}" + self.bench_all( + a, + b, + tests, + random_config, + sparsity, + baseline_name, + self.get_op_flops(random_mask, random_config), + dict_key, + ) + + ideal_testcase = TestCase(None, None, None, "Ideal") + ideal_speedup = round(100 / (100 - sparsity), 1) + self.add_kv( + self.results["speedup"], dict_key, ideal_speedup, ideal_testcase + ) + self.add_kv( + self.results["memory_savings"], + dict_key, + ideal_speedup, + ideal_testcase, + ) + self.plot(sparsity, random_config, pattern_name) + + +class VarySparsityExperiment(Experiment): + """ + In this experiment, we check how sparsity ration affects the performance. + """ + + def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False): + super(VarySparsityExperiment, self).__init__( + mode, dtype, do_accuracy_check, profile_sputnik + ) + + def gen_config(self): + batch_sizes = [32] + heads = [16] + seq_lengths = [2048] + hidden_sizes = [1024, 8192] + block_sizes = [64] + for batch in batch_sizes: + for seq in seq_lengths: + for head in heads: + for block in block_sizes: + for hidden_size in hidden_sizes: + entry = { + "batch_size": batch, + "num_heads": head, + "seq_length": seq, + "block_size": block, + "hidden_size": hidden_size, + } + yield entry + + def plot(self, sparsity, config, pattern_name): + desc = [ + f"bs={config.batch_size}", + f"nheads={config.num_heads}", + f"block={config.block_size}", + f"dtype={self.dtype}", + f"seq_len={config.seq_length}", + ] + title_suffix = ",".join(desc) + pretty_barplot( + self.results["speedup"], + title=f"{self.mode} - SparsityRatio experiment speedup\n" + title_suffix, + filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_time.svg", + dash_key="pytorch", + units="Speedup normalized to torch_matmul", + ) + + pretty_barplot( + self.results["flops"], + title=f"{self.mode} - SparsityRatio experiment throughput\n" + title_suffix, + filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_flops.svg", + dash_key="pytorch", + units="TFlops/s", + ) + + pretty_barplot( + self.results["memory_savings"], + title=f"{self.mode} - SparsityRatio experiment memory savings\n" + + title_suffix, + filename=f"vary_sparsity_{self.mode}_{self.dtype}_{pattern_name}_memory.svg", + dash_key="pytorch", + units="Memory savings normalized to torch_matmul", + ) + + def run(self): + self.reset_results() + random_config = None + for config in self.gen_config(): + for x in range(10, 100, 10): + mask_prob = x / 100.0 + # Get random mask + random_mask, random_config, _ = get_mask( + RandomAttentionMask, + config, + [ + ("mask_prob", mask_prob), + ], + ) + sparsity = get_sparsity(random_mask) + print("Random sparsity", 
get_sparsity(random_mask)) + + # Create input tensors + a, b = self.get_inputs(random_config) + + tests = [] + baseline_name = "torch-matmul" + tests.append( + TestCase( + self.torch_matmul_callable, + random_mask, + random_config, + f"{baseline_name}", + ) + ) + tests.append( + TestCase( + self.triton_callable, + random_mask, + random_config, + "triton-random", + ) + ) + if self.profile_sputnik and self.mode == "sddmm": + tests.append( + TestCase( + self.sputnik_callable, + random_mask, + random_config, + "sputnik-random", + ) + ) + dict_key = f"sp={mask_prob},hidden={random_config.hidden_size}" + self.bench_all( + a, + b, + tests, + random_config, + sparsity, + baseline_name, + self.get_op_flops(random_mask, random_config), + dict_key, + ) + + ideal_testcase = TestCase(None, None, None, "Ideal") + ideal_speedup = round(100 / (100 - mask_prob * 100), 1) + self.add_kv( + self.results["speedup"], dict_key, ideal_speedup, ideal_testcase + ) + self.add_kv( + self.results["memory_savings"], + dict_key, + ideal_speedup, + ideal_testcase, + ) + self.plot(None, random_config, "random") + + +class BlockSizeExperiment(Experiment): + """ + In this experiment, we analyze how increasing the block size affects + performance. We will take the lower triangular pattern. As we increase the + batch size, the blocks near the diagonal have to do more unnecessary + computation (the effective sparsity starts decreasing). + """ + + def __init__(self, mode, dtype, do_accuracy_check, profile_sputnik=False): + super(BlockSizeExperiment, self).__init__( + mode, dtype, do_accuracy_check, profile_sputnik + ) + + def gen_config(self): + batch_sizes = [32] + heads = [16] + seq_lengths = [2048] + block_sizes = [32, 64, 128, 256] + hidden_sizes = [1024, 8192] + for batch in batch_sizes: + for seq in seq_lengths: + for hidden_size in hidden_sizes: + for block in block_sizes: + for head in heads: + entry = { + "batch_size": batch, + "num_heads": head, + "seq_length": seq, + "block_size": block, + "hidden_size": hidden_size, + } + yield entry + + def plot(self, sparsity, config, pattern_name): + pretty_barplot( + self.results["speedup"], + title=f"{self.mode} - BlockSize experiment speedup\n" + f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}", + filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_time.svg", + dash_key="pytorch", + units="Speedup normalized to torch matmul", + ) + + pretty_barplot( + self.results["flops"], + title=f"{self.mode} - BlockSize experiment throughput\n" + f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}", + filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_flops.svg", + dash_key="pytorch", + units="TFlops/s", + ) + + pretty_barplot( + self.results["memory_savings"], + title=f"{self.mode} - BlockSize experiment memory savings\n" + f"bs={config.batch_size}, nheads={config.num_heads}, seq_len={config.seq_length}, dtype={self.dtype}", + filename=f"vary_block_size_{self.mode}_{self.dtype}_{pattern_name}_memory.svg", + dash_key="pytorch", + units="Memory savings normalized to torch matmul", + ) + + def get_op_flops(self, mask, config): + # Op flops here refer to the original non blocked attention mask, where + # no unnecessary elements are unmasked. We can compute this by computing + # the total flops of batch matmul and then multiply by (n+1)/2n. 
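+        # Worked out: a lower-triangular n x n mask keeps n * (n + 1) / 2 of the
+        # n^2 entries, so the useful fraction of the dense batched matmul is
+        #     (n * (n + 1) / 2) / n^2 = (n + 1) / (2 * n),
+        # which approaches 1/2 for large n and matches the ~2x ideal speedup
+        # derived in run() below.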
+ num_masked_elems = (config.seq_length + 1) / (2.0 * config.seq_length) + op_flops = ( + 2 # FMA + * config.batch_size # batched matmul + * config.num_heads + * (config.hidden_size // config.num_heads) # Reduce dimension + * config.seq_length + * config.seq_length + * num_masked_elems + * 1e-12 # TFlops + ) + return op_flops + + def run(self): + self.reset_results() + lt_config = None + for config in self.gen_config(): + lt_mask, lt_config, lt_name = get_mask( + LowerTriangularAttentionMask, + config, + ) + sparsity = get_sparsity(lt_mask) + print("Effective sparsity", sparsity) + if lt_config.seq_length == 2048: + plot_mask(lt_mask, lt_config, f"lt_mask_{lt_config.block_size}.svg") + + # Create input tensors + a, b = self.get_inputs(lt_config) + + tests = [] + baseline_name = "torch-matmul" + tests.append( + TestCase( + self.torch_matmul_callable, lt_mask, lt_config, f"{baseline_name}" + ) + ) + tests.append( + TestCase(self.triton_callable, lt_mask, lt_config, "triton-random") + ) + if self.profile_sputnik and self.mode == "sddmm": + tests.append( + TestCase( + self.sputnik_callable, lt_mask, lt_config, "sputnik-random" + ) + ) + dict_key = f"hidden={lt_config.hidden_size}, block={lt_config.block_size}" + self.bench_all( + a, + b, + tests, + lt_config, + sparsity, + baseline_name, + self.get_op_flops(lt_mask, lt_config), + dict_key, + ) + + ideal_testcase = TestCase(None, None, None, "Ideal") + seq_len = lt_config.seq_length + total_elems = seq_len * seq_len + nnz = seq_len * (seq_len + 1) / 2 + ideal_speedup = (1.0 * total_elems) / nnz + + self.add_kv( + self.results["speedup"], dict_key, ideal_speedup, ideal_testcase + ) + self.add_kv( + self.results["memory_savings"], + dict_key, + ideal_speedup, + ideal_testcase, + ) + self.plot(None, lt_config, lt_name) + + +if __name__ == "__main__": + for MaskGen in [ + RandomAttentionMask, + LowerTriangularAttentionMask, + BigBirdAttentionMask, + AxialAttentionMask, + LocalAttentionMask, + ]: + mask_gen = MaskGen() + mask, config, name = mask_gen() + plot_mask(mask, config, f"{name}.svg") + + for mode in ["sddmm", "spmm"]: + DifferentPatternExperiment(mode, torch.float16, True).run() + VarySparsityExperiment(mode, torch.float16, True).run() + BlockSizeExperiment(mode, torch.float16, True).run() diff --git a/pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py b/pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py new file mode 100644 index 0000000..4bd2063 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_causal_blocksparse.py @@ -0,0 +1,137 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
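+# Compares causal attention computed through the blocksparse path
+# (AttentionMask.make_causal plus an explicit block_size) against the same
+# scaled_dot_product_attention call driven by a dense -inf upper-triangular
+# mask, reporting runtime and peak GPU memory per shape/head-count/dtype.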
+ + +import os +from typing import Any, Dict + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components.attention.attention_mask import AttentionMask +from xformers.components.attention.core import scaled_dot_product_attention + +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + +SHAPES = [ + (8, 128, 2096), + (8, 1024, 256), + (12, 512, 1024), + (128, 128, 512), + (8, 2048, 4096), + (16, 1024, 5120), + (512, 128, 2560), +] + +BLOCK_SIZES = [128] +N_HEADS = [8, 32] + + +def bench_blocksparse_compare(backward: bool): + device = torch.device("cuda") + bw = "+bw" if backward else "" + use_amp = True + _use_cuda = True + + for dtype in [torch.float16, torch.float32]: + datatype = "fp16" if dtype == torch.float16 else "fp32" + results: Dict[str, Any] = {} + results_mem: Dict[str, Any] = {} + for BS in BLOCK_SIZES: + for heads in N_HEADS: + for B, M, K in SHAPES: + q = torch.randn( + (B, heads, M, K // heads), + requires_grad=backward, + device=device, + dtype=dtype, + ) + k = q + v = q + + # Mask with causal flag + m_att_mask = AttentionMask.make_causal( + M, M, device=device, dtype=dtype + ) + # Custom causal tensor mask + m_custom = torch.triu( + torch.ones(M, M, device=device, dtype=dtype) * float("-inf"), + diagonal=1, + ) + + def blocksparse_attention(): + with torch.cuda.amp.autocast(enabled=use_amp): + y = scaled_dot_product_attention( + q=q, k=k, v=v, att_mask=m_att_mask, block_size=BS + ) + if backward: + torch.norm(y).backward() + return y + + def sdp_attention(): + with torch.cuda.amp.autocast(enabled=use_amp): + y = scaled_dot_product_attention( + q=q, k=k, v=v, att_mask=m_custom, block_size=BS + ) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase(blocksparse_attention, f"blocksparse - fw{bw}"), + TestCase(sdp_attention, f"standard sdp - fw{bw}"), + ]: + if _use_cuda: + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + torch.cuda.synchronize() + time = triton.testing.do_bench(testcase.function)[0] + + if _use_cuda: + torch.cuda.synchronize() + max_memory = torch.cuda.max_memory_allocated() / 2**20 + else: + max_memory = -1 + + key = f"B={B},M={M},K={K},NH={heads}" + + if key not in results_mem: + results_mem[key] = {} + results_mem[key][testcase.name] = f"{max_memory:.1f}" + + if key not in results: + results[key] = {} + results[key][testcase.name] = f"{time:.2f}" + + pretty_print( + results, + title=f"\n --- Type: {datatype} Block Size: {BS} --- ", + units="runtime in ms", + ) + pretty_print( + results_mem, + title=f"\n --- Type: {datatype} Block Size: {BS} --- ", + units="peak memory usage in MB", + ) + + pretty_plot( + results, + title=f"Causal Blocksparse Runtime FW{bw.upper()} {datatype} Blocksize:{BS}", + units="runtime in ms", + dash_key="torch", + legend_loc="upper left", + ) + pretty_plot( + results_mem, + title=f"Causal Blocksparse Memory FW{bw.upper()} {datatype} Blocksize:{BS}", + units="peak memory usage in MB", + dash_key="torch", + legend_loc="upper left", + ) + + +for bw in [False, True]: + bench_blocksparse_compare(bw) diff --git a/pkgs/xformers/benchmarks/benchmark_core.py b/pkgs/xformers/benchmarks/benchmark_core.py new file mode 100644 index 0000000..97cdefa --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_core.py @@ -0,0 +1,258 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
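+# Micro-benchmarks for the sparse attention building blocks in
+# xformers.components.attention.core: SDDMM, masked matmul, softmax and bmm,
+# comparing PyTorch (COO) sparse tensors against the sputnik-backed SparseCS
+# representation over several shapes and sparsity levels.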
+ +import itertools + +import torch +from torch.utils import benchmark + +from xformers.components.attention.core import ( + SparseCS, + _create_random_sparsity, + _matmul_with_mask, + _softmax, + bmm, +) + +MIN_RUN_TIME = 1 +SHAPES = [[8, 8], [256, 1024], [128, 256]] +SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99] + + +def bench_sddmm(): + min_run_time = MIN_RUN_TIME + SPARSITIES = [0.95, 0.98, 0.99, 0.995, 0.999] + + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, K, device=device) + b = torch.rand(B, M, K, device=device) + + for backend, prob in itertools.product( + ["coo_pytorch", "csr_sputnik", "csr_ge"], SPARSITIES + ): + mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), prob) + aa = a + bb = b + if "csr" in backend: + mask = SparseCS(mask, device) + aa = a + bb = b + row_indices = mask.row_indices + row_offsets = mask.row_offsets + column_indices = mask.column_indices + if "_ge" in backend: + fn = torch.ops.xformers.csr_sddmm + else: + fn = torch.ops.xformers.sddmm_sputnik + fn_str = "fn(a, b, row_indices, row_offsets, column_indices)" + else: + mask = mask.to_sparse().to(device) + _, row_offsets, column_indices = mask.indices().int().unbind() + row_offsets = row_offsets.contiguous() + column_indices = column_indices.contiguous() + row_indices = row_offsets + + bb = b.transpose(-2, -1) + fn = _matmul_with_mask + fn_str = "fn(a, b, mask)" + + results.append( + benchmark.Timer( + stmt=fn_str, + globals={ + "a": aa, + "b": bb, + "mask": mask, + "row_indices": row_indices, + "row_offsets": row_offsets, + "column_indices": column_indices, + "fn": fn, + }, + label="sddmm", + sub_label=f"sparsity {backend}: {prob:0.4f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def bench_matmul_with_mask(): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, K, device=device) + b = torch.rand(B, K, M, device=device) + mask = torch.rand(B, M, M, device=device) > prob + + results.extend( + [ + benchmark.Timer( + stmt="_matmul_with_mask(a, b, mask)", + globals={ + "a": a, + "b": b, + "mask": None, + "_matmul_with_mask": _matmul_with_mask, + }, + label="matmul_with_mask", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + benchmark.Timer( + stmt="_matmul_with_mask(a, b, mask)", + globals={ + "a": a, + "b": b, + "mask": mask, + "_matmul_with_mask": _matmul_with_mask, + }, + label="matmul_with_mask", + sub_label="dense with masking", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for sputnik, prob in itertools.product([False, True], SPARSITIES): + mask = _create_random_sparsity( + torch.ones(B, M, M, dtype=torch.bool, device=device), prob + ) + aa = a + bb = b + if sputnik: + mask = SparseCS(mask, device) + aa = a + bb = b.transpose(-2, -1).contiguous().transpose(-2, -1) + else: + mask = mask.to_sparse() + results.append( + benchmark.Timer( + stmt="_matmul_with_mask(a, b, mask)", + globals={ + "a": aa, + "b": bb, + "mask": mask, + "_matmul_with_mask": _matmul_with_mask, + }, + label="matmul_with_mask", + sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def 
bench_softmax(): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + + results.extend( + [ + benchmark.Timer( + stmt="_softmax(a)", + globals={ + "a": a, + "_softmax": _softmax, + }, + label="softmax", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for sputnik, prob in itertools.product([False, True], SPARSITIES): + a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob) + if sputnik: + a = SparseCS(a, device) + else: + a = a.to_sparse() + results.append( + benchmark.Timer( + stmt="_softmax(a)", + globals={ + "a": a, + "_softmax": _softmax, + }, + label="softmax", + sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def bench_bmm(): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + b = torch.rand(B, M, K, device=device) + + results.extend( + [ + benchmark.Timer( + stmt="bmm(a, b)", + globals={ + "a": a, + "b": b, + "bmm": bmm, + }, + label="bmm", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for sputnik, prob in itertools.product([False, True], SPARSITIES): + a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob) + bb = b + if sputnik: + a = SparseCS(a, device) + bb = b + else: + a = a.to_sparse() + results.append( + benchmark.Timer( + stmt="bmm(a, b)", + globals={ + "a": a, + "b": bb, + "bmm": bmm, + }, + label="bmm", + sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +bench_sddmm() +bench_matmul_with_mask() +bench_softmax() +bench_bmm() diff --git a/pkgs/xformers/benchmarks/benchmark_indexing.py b/pkgs/xformers/benchmarks/benchmark_indexing.py new file mode 100644 index 0000000..cc23901 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_indexing.py @@ -0,0 +1,241 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
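+# Benchmarks the fused indexing ops xformers.ops.scaled_index_add and
+# xformers.ops.index_select_cat against eager PyTorch baselines
+# (index_add and a per-source index/cat), forward-only and forward+backward.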
+ + +import itertools +import random + +import torch +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops as xops + +min_run_time = 0.5 +device = torch.device("cuda") + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES_IADD = list( + product_dict( + shape=[ + (int(48 * 0.6), 48, 1, 257 * 1536), + (int(48 * 0.6), 48, 257, 1536), + ], + scaling=[False, True], + dtype=[torch.half], + ) +) + list( + product_dict( + shape=[ + # Format: [B_src, B_inp, M, D] + (int(192 * 0.6), 192, 50, 1536), + (int(48 * 257 * 0.6), 257 * 48, 1, 1536), + (int(192 * 50 * 0.6), 192 * 50, 1, 1536), + (int(16 * 257 * 0.6), 48 * 257, 1, 1536), + ], + scaling=[False], + dtype=[torch.half], + ) +) + +CASES_ISELECT = list( + product_dict( + batches=[((48, 257), (50, 192))], + D=[1536], + keep_ratio=[0.6], + dtype=[torch.half], + ) +) + +DTYPE2STR = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float32: "f32", +} + + +def _setup_test(functions, fw: bool = False, bw: bool = False, **kwargs): + for k, benchmark_cls in functions.items(): + benchmark_object = benchmark_cls(**kwargs, bw=bw) + label = benchmark_object.label + label += "fw" if fw else "" + label += "bw" if bw else "" + + def run_one(): + if fw: + benchmark_object.fw() + if bw: + benchmark_object.bw() + + yield benchmark.Timer( + stmt="fn()", + globals={ + "fn": run_one, + }, + label=label, + description=k, + sub_label=benchmark_object.sub_label, + ) + + +class ScaledIndexAddBenchmark: + def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None: + B_src, B_out, M, D = shape + torch.manual_seed(B_out + B_src) + dtype_str = DTYPE2STR.get(dtype, dtype) + self.sub_label = f"{dtype_str} B_src={B_src}, B_out={B_out}, M={M}, D={D} s={'Y' if scaling else 'N'}" + self.label = "scaled_index_add" + self.alpha = 0.73 + + self.inp = torch.randn( + [B_out, M, D], device="cuda", dtype=dtype, requires_grad=bw + ) + self.src = torch.randn( + [B_src, M, D], device="cuda", dtype=dtype, requires_grad=bw + ) + self.scaling = ( + torch.randn([D], device="cuda", dtype=dtype, requires_grad=bw) + if scaling + else None + ) + self.index = torch.tensor( + [i for i in range(self.src.shape[0])], dtype=torch.int64, device="cuda" + ) + self.grad = torch.randn([B_out, M, D], device="cuda", dtype=dtype) + self.out = torch.Tensor() + + def fw(self) -> None: + self.out = xops.scaled_index_add( + input=self.inp.clone(), + index=self.index, + source=self.src, + scaling=self.scaling, + alpha=self.alpha, + ) + + def bw(self): + self.inp.grad = None + self.src.grad = None + if self.scaling is not None: + self.scaling.grad = None + self.out.backward(self.grad, retain_graph=True) + + +class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark): + def fw(self) -> None: + src_scaled = self.src + if self.scaling is not None: + src_scaled * self.scaling.unsqueeze(0).unsqueeze(0) + self.out = self.inp.index_add( + dim=0, + source=src_scaled, + index=self.index, + alpha=self.alpha, + ) + + +def scaled_index_add_fw(**kwargs): + yield from _setup_test( + **kwargs, + fw=True, + functions={ + "xformers": ScaledIndexAddBenchmark, + "pytorch": ScaledIndexAddBenchmarkBaseline, + }, + ) + + +def scaled_index_add_fwbw(**kwargs): + yield from _setup_test( + **kwargs, + fw=True, + bw=True, + functions={ + "xformers": ScaledIndexAddBenchmark, + "pytorch": ScaledIndexAddBenchmarkBaseline, + }, + ) + + +class IndexSelectBenchmark: + def 
__init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None: + dtype_str = DTYPE2STR.get(dtype, dtype) + self.sub_label = f"{dtype_str} D={D} batches={batches} keep={keep_ratio}" + self.label = "index_select" + srcs = [torch.randn([B, seqlen * D]) for (B, seqlen) in batches] + src = torch.cat([s.view([-1, D]) for s in srcs], dim=0).cuda().to(dtype) + src.requires_grad_(True) + + indices = [] + sources = [] + elements_i = 0 + for source_i in srcs: + index = [i for i in range(source_i.shape[0])] + random.Random(source_i.shape[0]).shuffle(index) + indices.append( + torch.tensor( + index[: int(keep_ratio * source_i.shape[0])], + dtype=torch.int64, + device="cuda", + ) + ) + sources.append( + src[ + elements_i : elements_i + source_i.shape[0] * source_i.shape[1] // D + ].reshape(source_i.shape) + ) + elements_i += source_i.shape[0] * source_i.shape[1] // D + self.indices, self.sources, self.src = indices, sources, src + self.out = torch.Tensor() + + def fw(self) -> None: + self.out = xops.index_select_cat(self.sources, self.indices) + + def bw(self): + self.src.grad = None + self.out.backward(self.out, retain_graph=True) + + +class IndexSelectBenchmarkBaseline(IndexSelectBenchmark): + def fw(self) -> None: + self.out = torch.cat( + [s[i].flatten() for s, i in zip(self.sources, self.indices)], dim=0 + ) + + +def index_select_fw(**kwargs): + yield from _setup_test( + **kwargs, + fw=True, + functions={ + "xformers": IndexSelectBenchmark, + "pytorch": IndexSelectBenchmarkBaseline, + }, + ) + + +def index_select_fwbw(**kwargs): + yield from _setup_test( + **kwargs, + fw=True, + bw=True, + functions={ + "xformers": IndexSelectBenchmark, + "pytorch": IndexSelectBenchmarkBaseline, + }, + ) + + +benchmark_main_helper(scaled_index_add_fw, CASES_IADD, min_run_time=min_run_time) +benchmark_main_helper(scaled_index_add_fwbw, CASES_IADD, min_run_time=min_run_time) +benchmark_main_helper(index_select_fw, CASES_ISELECT, min_run_time=min_run_time) +benchmark_main_helper(index_select_fwbw, CASES_ISELECT, min_run_time=min_run_time) diff --git a/pkgs/xformers/benchmarks/benchmark_mem_eff_attention.py b/pkgs/xformers/benchmarks/benchmark_mem_eff_attention.py new file mode 100644 index 0000000..0cb08ce --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_mem_eff_attention.py @@ -0,0 +1,316 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
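+# Benchmarks xformers.ops.memory_efficient_attention (cutlass and flash
+# forward/backward ops) against an eager softmax(QK^T / sqrt(d)) @ V reference
+# for ViT-, diffusion- and LLM-like shapes, with optional dropout and attention
+# biases. Minimal usage sketch (shapes are illustrative, BMHK layout):
+#
+#     q = k = v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.half)
+#     out = xformers.ops.memory_efficient_attention(q, k, v)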
+ + +import itertools +import random +from functools import partial + +import torch +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops +import xformers.ops.fmha as fmha + +torch.backends.cuda.matmul.allow_tf32 = False + + +def create_attn_bias( + bias_type, + batch_size: int, + num_heads: int, + q_len: int, + kv_len: int, + device, + dtype, + bias_requires_grad: bool = False, +): + NoneType = type(None) + if bias_type is NoneType: + return None + if bias_type is torch.Tensor: + attn_bias = torch.randn((1, 1, q_len, kv_len), device=device, dtype=dtype) + return attn_bias.expand(batch_size, num_heads, q_len, kv_len) + if bias_type is xformers.ops.LowerTriangularMask: + return bias_type() + assert False, f"Unsupported bias type: {bias_type}" + + +def ref_attention_bmk(q, k, v, attn_bias=None, p=0.0): + if isinstance(attn_bias, xformers.ops.AttentionMask): + attn_bias = ( + attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1])) + .to(q) + .squeeze() + ) + q = q * (1.0 / q.shape[-1] ** 0.5) + if attn_bias is None: + attn = q @ k.transpose(-2, -1) + else: + # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v + # but faster, and is what is used in PyTorch now + attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1)) + attn = attn.softmax(-1) + if p > 0: + attn = torch.nn.functional.dropout(attn, p=p) + return attn @ v + + +def ref_attention(q, k, v, attn_bias, p=0.0): + assert q.ndim == 4 + B, M, H, K = q.shape + + def T(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + + if isinstance(attn_bias, torch.Tensor): + attn_bias = attn_bias.reshape(B * H, M, M) + out = ref_attention_bmk(T(q), T(k), T(v), attn_bias, p) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)) + + +min_run_time = 0.5 +device = torch.device("cuda") + +NUM_THREADS = [1] if device.type == "cuda" else [1, 40] +SHAPES = [ + # ViT + (384, 197, 1, 88), + (384, 197, 1, 80), + (384, 197, 1, 64), + (1024, 197, 1, 88), + (1024, 197, 1, 80), + (1024, 197, 1, 64), + # ViT-Huge + (32 * 16, 197, 1, 80), + (32, 197, 16, 80), + (32, 197, 16, 64), + (32, 197, 16, 128), + # ViT-Giant + (16 * 16, 197, 1, 88), + (16, 197, 16, 88), + (16, 197, 16, 64), + (16, 197, 16, 128), + # FB models + (1024, 82, 8, 64), + (150, 256, 16, 64), + (64, 256, 12, 64), + # Stable diffusion (https://github.com/huggingface/diffusers/pull/532) + (1, 4096, 16, 40), # 512x512 + (1, 16384, 16, 40), # 1024x1024 + (1, 4096, 16, 80), + (1, 16384, 16, 80), + # + bs4 + (4, 4096, 16, 40), + (4, 16384, 16, 40), + (4, 4096, 16, 80), + (4, 16384, 16, 80), + # ParlAI model + (256, 4096, 16, 64), + # Zetta B M H K + (8, 2048, 20, 128), + # LLaMa 70b - mp=8/16 + *sorted(itertools.product([1, 2], [2048, 4096, 8192], [4, 8], [128])), + *sorted( + itertools.product([16], [128, 512, 1024], [16], [16, 32, 64, 128, 160, 256]) + ), +] + +OPS = [ + (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp), + (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp), + # TODO: Triton is not stable: it can trigger Illegal Memory Accesses + # and its performance varies a lot between runs. 
+ # (xformers.ops.fmha.triton.FwOp, xformers.ops.fmha.triton.BwOp), +] + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + shape=SHAPES, + num_threads=NUM_THREADS, + dropout_p=[0.0], + attn_bias_cfg=[(type(None), False)], + dtype=[torch.half], + ) +) + +# Add more cases with some variations +for c in CASES.copy(): + c = c.copy() + c.update( + random.Random(str(c["shape"])).choice( + [ + {"dropout_p": 0.3}, + {"attn_bias_cfg": (torch.Tensor, False)}, + {"attn_bias_cfg": (torch.Tensor, True)}, + {"attn_bias_cfg": (xformers.ops.LowerTriangularMask, False)}, + {"dtype": torch.bfloat16}, + {"dtype": torch.float}, + ] + ) + ) + CASES.append(c) + + +def create_tensors(shape, dtype, requires_grad=False): + B, M, H, K = shape + qkv = torch.rand( + [B, M, 3, H, K], device=device, dtype=dtype, requires_grad=requires_grad + ) + q, k, v = xformers.ops.unbind(qkv, 2) + return qkv, q, k, v + + +def mem_eff_attention_fw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype): + B, M, H, K = shape + _, q, k, v = create_tensors(shape, dtype) + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg + if attn_bias_requires_grad: + return + bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=H, + q_len=M, + kv_len=M, + device=device, + dtype=dtype, + bias_requires_grad=attn_bias_requires_grad, + ) + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + + dtype_str = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float: "f32", + }[dtype] + sub_label = ( + f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, " + f"BiasT={attn_bias_type.__name__}" + ) + + has_run = False + for fw_op, bw_op in OPS: + if not fw_op.supports(inp): + continue + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias, p)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": inp.attn_bias, + "p": dropout_p, + "fn": partial( + xformers.ops.memory_efficient_attention, op=(fw_op, bw_op) + ), + }, + label=f"attention (attn_bias={attn_bias_type})", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + has_run = True + + if not has_run: + return + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias, p)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": inp.attn_bias, + "p": dropout_p, + "fn": ref_attention, + }, + label=f"attention (attn_bias={attn_bias_type})", + description="eager", + sub_label=sub_label, + num_threads=num_threads, + ) + + +def mem_eff_attention_bw(shape, num_threads: int, attn_bias_cfg, dropout_p, dtype): + B, M, H, K = shape + qkv, q, k, v = create_tensors(shape, dtype, requires_grad=True) + + attn_bias_type, attn_bias_requires_grad = attn_bias_cfg + bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=H, + q_len=M, + kv_len=M, + device=device, + dtype=dtype, + bias_requires_grad=attn_bias_requires_grad, + ) + inp = fmha.Inputs(query=q, key=k, value=v, attn_bias=bias, p=dropout_p) + + dtype_str = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float: "f32", + }[dtype] + sub_label = ( + f"{dtype_str} {B}-{M}-{H}-{K}, p={dropout_p}, " + f"BiasT={attn_bias_type.__name__}, BiasGrad={attn_bias_requires_grad}" + ) + + has_run = False + for fw_op, bw_op in OPS: + if not fw_op.supports(inp) or not bw_op.supports(inp): + continue + has_run = True + out = xformers.ops.memory_efficient_attention( + inp.query, inp.key, inp.value, inp.attn_bias, inp.p, op=(fw_op, bw_op) + ) + 
grad_benchmark = torch.ones_like(q) + + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad_benchmark, + }, + label=f"attention backward (attn_bias={attn_bias_type})", + description=bw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + del out + + if not has_run: + return + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": ref_attention(q, k, v, inp.attn_bias, dropout_p), + "grad": grad_benchmark, + }, + label=f"attention backward (attn_bias={attn_bias_type})", + description="vanilla", + sub_label=sub_label, + num_threads=num_threads, + ) + + +benchmark_main_helper(mem_eff_attention_fw, CASES, min_run_time=min_run_time) +benchmark_main_helper(mem_eff_attention_bw, CASES, min_run_time=min_run_time) diff --git a/pkgs/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py b/pkgs/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py new file mode 100644 index 0000000..7f1b4ce --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_mem_eff_attn_decoder.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from functools import partial + +import torch +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops +import xformers.ops.fmha as fmha + +torch.backends.cuda.matmul.allow_tf32 = False + +# Run with +# python xformers/benchmarks/benchmark_mem_eff_attn_decoder.py --omit-baselines --quiet +# The baselines for these benchmarks are really slow because there is +# so much padding in the inputs, so there is no point running them. 
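+# ref_attention / ref_attention_bmk below are the eager baselines: they fold the
+# head dimension into the batch, compute softmax(QK^T / sqrt(d)) @ V, and are
+# only exercised when RUN_BASELINES is flipped to True inside
+# mem_eff_attention_decoder.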
+ + +def ref_attention_bmk(q, k, v, attn_bias=None): + if isinstance(attn_bias, xformers.ops.AttentionMask): + attn_bias = ( + attn_bias.materialize((q.shape[0], 1, q.shape[1], k.shape[1])) + .to(q) + .squeeze() + ) + q = q * (1.0 / q.shape[-1] ** 0.5) + if attn_bias is None: + attn = q @ k.transpose(-2, -1) + else: + # equivalent to (q @ k.transpose(-2, -1) + m).softmax(-1) @ v + # but faster, and is what is used in PyTorch now + attn = torch.baddbmm(attn_bias, q, k.transpose(-2, -1)) + attn = attn.softmax(-1) + return attn @ v + + +def ref_attention(q, k, v, attn_bias): + assert q.ndim == 4 + + def T(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + + out = ref_attention_bmk(T(q), T(k), T(v), attn_bias) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)) + + +min_run_time = 0.5 +device = torch.device("cuda") + +NUM_THREADS = [1] if device.type == "cuda" else [1, 40] + +OPS = [ + xformers.ops.fmha.cutlass.FwOp, + xformers.ops.fmha.decoder.FwOp, +] + +KV_SHAPES = [ + # list of n_keys, padding_length, batchsize + (2, 64, 3), + (32, 1024, 500), + (1000, 1024, 2), + (8000, 8192, 1), + (240, 256, 32), + (2048, 2 * 1024, 4), + (4096 * 2, 8 * 1024, 1), +] + +N_HEADS = [8, 16, 64] + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + kv_shape=KV_SHAPES, + n_heads=N_HEADS, + num_threads=NUM_THREADS, + multiquery=[True, False], + ) +) + + +def mem_eff_attention_decoder( + kv_shape, n_heads: int, num_threads: int, multiquery: bool +): + n_keys, padding, B = kv_shape + torch.manual_seed(42) + k_seqlen = torch.randint(1, n_keys + 1, (B,)).tolist() + K = 128 + + q = torch.rand(1, B, n_heads, K, device=device, dtype=torch.bfloat16) + if multiquery: + k = torch.rand( + 1, B * padding, 1, K, device=device, dtype=torch.bfloat16 + ).expand(1, B * padding, n_heads, K) + v = torch.rand( + 1, B * padding, 1, K, device=device, dtype=torch.bfloat16 + ).expand(1, B * padding, n_heads, K) + else: + k = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16) + v = torch.rand(1, B * padding, n_heads, K, device=device, dtype=torch.bfloat16) + + bias = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens( + q_seqlen=[1] * B, + kv_seqlen=k_seqlen, + kv_padding=padding, + ) + + sub_label = f"{B}batch-{k_seqlen[0]}keys-{n_heads}heads" + if multiquery: + sub_label += "-mq" + + has_run = False + for fw_op in OPS: + inp = fmha.Inputs(q, k, v, attn_bias=bias) + if not fw_op.supports(inp): + continue + + fn = partial(xformers.ops.memory_efficient_attention_forward, op=fw_op) + + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": bias, + "fn": fn, + }, + label="attention", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + fn(q, k, v, bias) + yield benchmark.Timer( + stmt="graph.replay()", + globals={ + "graph": graph, + }, + label="cuda graphed attention", + description=fw_op.NAME, + sub_label=sub_label, + num_threads=num_threads, + ) + + has_run = True + + if not has_run: + return + + RUN_BASELINES = False + if RUN_BASELINES: + yield benchmark.Timer( + stmt="fn(q, k, v, attn_bias)", + globals={ + "q": q, + "k": k, + "v": v, + "attn_bias": bias, + "fn": ref_attention, + }, + 
label="attention", + description="eager", + sub_label=sub_label, + num_threads=num_threads, + ) + + +benchmark_main_helper(mem_eff_attention_decoder, CASES, min_run_time=min_run_time) diff --git a/pkgs/xformers/benchmarks/benchmark_mlp.py b/pkgs/xformers/benchmarks/benchmark_mlp.py new file mode 100644 index 0000000..a32846e --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_mlp.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +from typing import Any, Dict + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components import Activation +from xformers.components.feedforward import MLP, FusedMLP + +SHAPES = [ + (8, 256, 512), + (8, 512, 1024), + (4, 1024, 1024), + (2, 2048, 2048), + (1, 2048, 4096), + (1, 1024, 12288), +] + +HIDDEN_LAYER_MULTIPLIER = [4] + + +def bench_MLP(backward: bool, bias: bool, dropout: float, activation: Activation): + device = torch.device("cuda") + bw = "+bw" if backward else "" + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for B, M, K in SHAPES: + for hlm in HIDDEN_LAYER_MULTIPLIER: + fused_mlp = FusedMLP( + dim_model=K, + dropout=dropout, + activation=activation, + hidden_layer_multiplier=hlm, + bias=bias, + ).to(device=device, dtype=dtype) + + standard_mlp = MLP( + dim_model=K, + dropout=dropout, + activation=activation, + hidden_layer_multiplier=hlm, + bias=bias, + ).to(device=device, dtype=dtype) + + a = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + + def mlp_standard(): + y = standard_mlp(a) + if backward: + torch.norm(y).backward() + return y + + def mlp_fused(): + y = fused_mlp(a) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase( + mlp_standard, + "standard - {} - {} bias - {} drop - fw{}".format( + activation, + "no" if not bias else "", + dropout, + "+bw" if backward else "", + ), + ), + TestCase( + mlp_fused, + "fused - {} - {} bias - {} drop - fw{}".format( + activation, + "no" if not bias else "", + dropout, + "+bw" if backward else "", + ), + ), + ]: + time = triton.testing.do_bench(testcase.function)[0] + key = f"{B} x {M} x {K} - {hlm}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{time:.2f}" + + pretty_print( + results, + title=f"\n --- Type: {dtype} --- ", + units="runtime in ms, lower is better. 
BMK - mul: ", + ) + pretty_plot( + results, + title=f"MLP-{activation}-FW{bw}-{dtype}", + units="runtime in ms, lower is better", + dash_key="torch", + ) + + +if __name__ == "__main__": + # Get the user requests + parser = argparse.ArgumentParser("Benchmark MLP") + parser.add_argument("-act", "--activations", nargs="+", default=[Activation.GeLU]) + parser.add_argument("-bias", "--bias", nargs="+", default=[False, True]) + parser.add_argument("-dropout", "--dropout", nargs="+", default=[0.0, 0.1]) + args = parser.parse_args() + + for bw in [False, True]: + for bias in args.bias: + for dropout in args.dropout: + for activation in args.activations: + bench_MLP( + backward=bw, + bias=bias, + dropout=float(dropout), + activation=activation, + ) diff --git a/pkgs/xformers/benchmarks/benchmark_multi_head_dispatch.py b/pkgs/xformers/benchmarks/benchmark_multi_head_dispatch.py new file mode 100644 index 0000000..2345cf2 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_multi_head_dispatch.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, Dict + +import torch +import torch.nn as nn +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components import MultiHeadDispatch +from xformers.components.attention import ScaledDotProduct + +SHAPES = [ + (8, 384, 128), + (8, 784, 512), + (4, 1024, 768), + (4, 2048, 1024), + (2, 2048, 2048), + (2, 2048, 4096), + (2, 4096, 4096), + (1, 2048, 12288), +] + +N_HEADS = [4] + + +def bench_multihead_dispatch(backward: bool, self_attention: bool): + device = torch.device("cuda") + bw = "+bw" if backward else "" + sa = " (self_attn)" if self_attention else "" + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for B, M, K in SHAPES: + for heads in N_HEADS: + xf_multi_head = MultiHeadDispatch( + dim_model=K, + residual_dropout=0.0, + num_heads=heads, + attention=ScaledDotProduct(), + bias=(True, True, True, True), + ).to(device=device, dtype=dtype) + torch_multi_head = nn.MultiheadAttention( + embed_dim=K, num_heads=heads, batch_first=True + ).to(device=device, dtype=dtype) + + q = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + + if self_attention: + k = q + v = q + else: + k = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + v = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + + def torch_mha(): + y, _ = torch_multi_head(query=q, key=k, value=v) + if backward: + torch.norm(y).backward() + return y + + def xformers_mha(): + y = xf_multi_head(query=q, key=k, value=v) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase(torch_mha, f"torch - fw{bw}{sa}"), + TestCase(xformers_mha, f"xf - fw{bw}{sa}"), + ]: + time = triton.testing.do_bench(testcase.function)[0] + key = f"B={B}, M={M}, K={K}, N_HEADS={heads}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{time:.2f}" + + pretty_print( + results, + title=f"\n --- Type: {dtype} --- ", + units="runtime in ms, lower is better", + ) + pretty_plot( + results, + title=f"MHA-FW{bw}-{dtype}", + units="runtime in ms, lower is better", + dash_key="torch", + ) + + +for bw in [False, True]: + for self_attention in [False, True]: + bench_multihead_dispatch(bw, self_attention) diff 
--git a/pkgs/xformers/benchmarks/benchmark_nystrom_utils.py b/pkgs/xformers/benchmarks/benchmark_nystrom_utils.py new file mode 100644 index 0000000..6f4b38c --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_nystrom_utils.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable + +import torch +from torch.utils import benchmark + +from xformers.components.attention.utils import iterative_pinv + +MIN_RUN_TIME = 1 +SHAPES = [[8, 8], [256, 1024], [128, 256]] +SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99] + + +def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + a = torch.softmax(a, dim=-1) + + results.extend( + [ + benchmark.Timer( + stmt=f"{inverse_fn.__name__}(a)", + globals={ + "a": a, + f"{inverse_fn.__name__}": inverse_fn, + }, + label=f"{inverse_fn.__name__}", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for prob in SPARSITIES: + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + a = a.to_sparse() + results.append( + benchmark.Timer( + stmt=f"{inverse_fn.__name__}(a)", + globals={ + "a": a, + f"{inverse_fn.__name__}": inverse_fn, + }, + label=f"{inverse_fn.__name__}", + sub_label=f"sparsity: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def iterative_pinv_analysis( + identity_tolerance: float = 1e-1, + pinv_tolerance: float = 5e-1, + max_iters: int = 30, + plot: bool = True, +): + + for i in range(1, 10): + B, M = 1, 2**i + a = torch.rand(B, M, M) + a = torch.softmax(a, dim=-1) + + for n_iter in range(1, max_iters + 1): + result = iterative_pinv(a, n_iter=n_iter) + expected = torch.linalg.pinv(a) + + result_identity = torch.matmul(a, result) + identity = torch.eye(M) + + # Default is frobenius norm. + identity_error = torch.linalg.norm(identity - result_identity, dim=(-2, -1)) + inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1)) + + if (identity_error < identity_tolerance).all() or n_iter == max_iters: + print( + f"Size {M}, n_iters {n_iter}: \n\t \ + Final Error from Identity: {identity_error.item()} \n\t \ + Final Error from linalg.pinv {inverse_error.item()}" + ) + break + + +if __name__ == "__main__": + iterative_pinv_analysis() + bench_inverse(iterative_pinv) + bench_inverse(torch.linalg.pinv) diff --git a/pkgs/xformers/benchmarks/benchmark_revnet.py b/pkgs/xformers/benchmarks/benchmark_revnet.py new file mode 100644 index 0000000..8561481 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_revnet.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
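+# Compares a plain residual stack (y = y + f(y); y = y + g(y), repeated `depth`
+# times) against xformers.components.reversible.ReversibleSequence of the same
+# depth, in fp16 and fp32, forward-only and forward+backward.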
+ + +from typing import Any, Dict + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components.reversible import ReversibleSequence + +SHAPES = [(16384, 32), (2048, 256), (128, 4096)] + +DEPTH = [4, 32, 256] + + +def bench_revnet(backward: bool): + device = torch.device("cuda") + bw = "+bw" if backward else "" + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for B, K in SHAPES: + for depth in DEPTH: + f = torch.nn.Linear(K, K).to(device=device, dtype=dtype) + g = torch.nn.Linear(K, K).to(device=device, dtype=dtype) + revseq = ReversibleSequence( + torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth) + ) + revseq = revseq.to(device=device, dtype=dtype) + + a = torch.rand( + 1, B, K, device=device, dtype=dtype, requires_grad=backward + ) + b = torch.rand( + 1, B, K * 2, device=device, dtype=dtype, requires_grad=backward + ) + + def normal_step(): + y = a + for _ in range(depth): + y = y + f(y) + y = y + g(y) + if backward: + torch.norm(y).backward() + return y + + def reversible_step(): + y = revseq(b) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase(normal_step, f"residual - fw{bw}"), + TestCase(reversible_step, f"reversible - fw{bw}"), + ]: + time = triton.testing.do_bench(testcase.function)[0] + key = f"Batch={B}, Features={K}, Depth={depth}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{time:.2f}" + + pretty_print( + results, + title=f"\n --- Type: {dtype} --- ", + units="runtime in ms, lower is better", + ) + pretty_plot( + results, + title=f"RevNet-FW{bw}-{dtype}", + units="runtime in ms, lower is better", + dash_key="torch", + ) + + +for bw in [False, True]: + bench_revnet(bw) diff --git a/pkgs/xformers/benchmarks/benchmark_sddmm.py b/pkgs/xformers/benchmarks/benchmark_sddmm.py new file mode 100644 index 0000000..693e4a6 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_sddmm.py @@ -0,0 +1,117 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
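+# Compares SDDMM backends (sputnik CSR, the csr_ge / coo_ge kernels, and the
+# bare CSR-to-COO conversion) on Swin-Transformer, ViT and synthetic shapes at
+# several sparsity levels.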
+ +import itertools + +import torch +from torch.utils import benchmark + +from xformers.components.attention._sputnik_sparse import _csr_to_coo +from xformers.components.attention.core import SparseCS, _create_random_sparsity + +MIN_RUN_TIME = 0.2 + + +def _get_fn(backend): + if backend == "csr_ge": + fn = torch.ops.xformers.csr_sddmm + elif backend == "csr_sputnik": + fn = torch.ops.xformers.sddmm_sputnik + elif backend == "coo_ge": + + def fn(a, b, row_indices, row_offsets, column_indices): + row_coo, _ = _csr_to_coo( + a.shape[-2], b.shape[-2], row_offsets, column_indices + ) + return torch.ops.xformers.coo_sddmm( + a, b, row_indices, row_coo, column_indices + ) + + elif backend == "csr_to_coo": + + def fn(a, b, row_indices, row_offsets, column_indices): + row_coo, _ = _csr_to_coo( + a.shape[-2], b.shape[-2], row_offsets, column_indices + ) + return row_coo + + return fn + + +def bench_sddmm(configs): + min_run_time = MIN_RUN_TIME + + device = torch.device("cuda") + results = [] + + for (B, M, K), prob in configs: + a = torch.rand(B, M, K, device=device) + b = torch.rand(B, M, K, device=device) + + mask = _create_random_sparsity( + torch.ones(1, M, M, dtype=torch.bool), prob, divisible_by=16 + ) + aa = a + bb = b + mask = SparseCS(mask, device) + row_indices = mask.row_indices + row_offsets = mask.row_offsets + column_indices = mask.column_indices + + for backend in ["csr_sputnik", "csr_ge", "coo_ge", "csr_to_coo"]: + + fn_str = "fn(a, b, row_indices, row_offsets, column_indices)" + fn = _get_fn(backend) + + results.append( + benchmark.Timer( + stmt=fn_str, + globals={ + "a": aa, + "b": bb, + "mask": mask, + "row_indices": row_indices, + "row_offsets": row_offsets, + "column_indices": column_indices, + "fn": fn, + }, + label="sddmm", + sub_label=f"B={B:>4d}, M={M:>4d}, K={K:>3d}, prob={prob:0.4f}", + description=backend, + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + return results + + +# batch size 32, for different layers +SWIN_T_SIZES = [(96, 3136, 32), (192, 784, 32), (384, 196, 32), (768, 49, 32)] +swin_t_config = list(zip(SWIN_T_SIZES, (0.9844, 0.9375, 0.75, 0.0))) + +# some random values +BASIC_SIZES = [(32, 1024, 32), (32, 1024, 128), (8, 4096, 32), (8, 4096, 128)] +SPARSITIES = [0.90, 0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.999] +basic_config = list(itertools.product(BASIC_SIZES, SPARSITIES)) + +# batch size 32 here +vit_sizes = [ + (192, 785, 64), # deit_small_patch8_224 + (192, 197, 64), # deit_small_patch16_224 + (384, 785, 64), # deit_base_patch8_224 + (384, 197, 64), # deit_base_patch16_224 +] +SPARSITIES = [0.70, 0.80, 0.85, 0.90, 0.93, 0.95, 0.97] +vit_config = list(itertools.product(vit_sizes, SPARSITIES)) + +results = [] + +print("Swin Transformer") +results += bench_sddmm(swin_t_config) +print("ViT") +results += bench_sddmm(vit_config) +print("Basic cases") +results += bench_sddmm(basic_config) diff --git a/pkgs/xformers/benchmarks/benchmark_swiglu.py b/pkgs/xformers/benchmarks/benchmark_swiglu.py new file mode 100644 index 0000000..ffa413a --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_swiglu.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
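+# Benchmarks the fused SwiGLU op (SwiGLUPackedFusedOp by default) against the
+# eager SwiGLUEagerOp on ViT-Giant / GPT-3 / Chinchilla-like MLP shapes, with
+# and without bias, in bf16, fp16 and fp16 autocast, forward and backward.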
+ + +import itertools +from contextlib import nullcontext +from functools import partial +from typing import Any + +import torch +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops.swiglu_op as xsw + +min_run_time = 0.5 +device = torch.device("cuda") + +SHAPES = [ + # Format: [inp.shape[0], inp.shape[1], hidden.shape[1]] + # ViT-Giant + (9456, 1536, 2736), + (4440, 1536, 2736), + (4728, 1536, 2736), + # Some smaller shapes as well + (4728, 1536, 1024), + # GPT-3 (small) + (32768, 2048, 5632), + # Chinchilla + (32768, 8192, 22016), +] + + +# OP = xsw._SwiGLUDecomposedOp +# OP = xsw.SwiGLUFusedOp +OP = xsw.SwiGLUPackedFusedOp + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + shape=SHAPES, + dtype=[torch.bfloat16, torch.half, "autocast_half"], + bias=[True, False], + ) +) + +DTYPE2STR = { + torch.bfloat16: "b16 ", + torch.half: "f16 ", + "autocast_half": "f16.ac", +} + + +def benchmark_swiglu(shape, dtype, bias: bool): + if dtype == "autocast_half": + inp_dtype, model_dtype, autocast = torch.float, torch.float, True + else: + inp_dtype, model_dtype, autocast = dtype, dtype, False + + x = torch.randn(shape[:2], device=device, dtype=inp_dtype) + module = ( + xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias) + .to(device) + .to(model_dtype) + ) + + dtype_str = DTYPE2STR.get(dtype, dtype) + bstr = "bias" if bias else "nobi" + sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}" + + params = module._ordered_params() + + PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else "" + yield benchmark.Timer( + stmt=f"{PREFIX}fn(x, *args)", + globals={ + "x": x, + "args": params, + "fn": partial(xsw.swiglu, op=OP), + }, + label="swiglu_fw", + description=OP.NAME, + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt=f"{PREFIX}fn(x, *args)", + globals={ + "x": x, + "args": params, + "fn": partial(xsw.swiglu, op=xsw.SwiGLUEagerOp), + }, + label="swiglu_fw", + description="eager", + sub_label=sub_label, + ) + + +def benchmark_swiglu_bw(shape, dtype, bias: bool): + if dtype == "autocast_half": + inp_dtype, model_dtype = torch.float, torch.float + cm: Any = partial(torch.cuda.amp.autocast, enabled=True, dtype=torch.float16) + else: + inp_dtype, model_dtype = dtype, dtype + cm = nullcontext + + x = torch.randn(shape[:2], device=device, dtype=inp_dtype) + x.requires_grad_() + module = ( + xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias) + .to(device) + .to(model_dtype) + ) + + dtype_str = DTYPE2STR.get(dtype, dtype) + bstr = "bias" if bias else "nobi" + sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}" + + params = module._ordered_params() + with cm(): + out = xsw.swiglu(x, *params, op=OP) + grad = torch.zeros_like(out) + + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad, + }, + label="swiglu_bw", + description=OP.NAME, + sub_label=sub_label, + ) + del out + + with cm(): + out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp) + + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad, + }, + label="swiglu_bw", + description="eager", + sub_label=sub_label, + ) + + +benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time) +benchmark_main_helper(benchmark_swiglu_bw, CASES, 
min_run_time=min_run_time) diff --git a/pkgs/xformers/benchmarks/benchmark_transformer.py b/pkgs/xformers/benchmarks/benchmark_transformer.py new file mode 100644 index 0000000..a8c077b --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_transformer.py @@ -0,0 +1,155 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from functools import partial, reduce + +import timm +import torch +import torch.nn as nn +from timm.models.layers import Mlp as TimmMlp +from timm.models.vision_transformer import Attention as TimmAttention +from timm.models.vision_transformer import Block as TimmBlock +from torch.utils import benchmark +from utils import benchmark_main_helper + +import xformers.ops as xops + + +def replace_module(module: nn.Module, replace_class, factory): + if isinstance(module, replace_class): + return factory(module) + module_output = module + for name, child in module.named_children(): + module_output.add_module(name, replace_module(child, replace_class, factory)) + del module + return module_output + + +class TimmMemEffAttention(nn.Module): + def __init__(self, attn: TimmAttention, op=None): + super().__init__() + self.op = None + self.num_heads = attn.num_heads + self.scale = attn.scale + + self.qkv = attn.qkv + self.attn_drop = attn.attn_drop + self.proj = attn.proj + self.proj_drop = attn.proj_drop + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + q, k, v = xops.unbind(qkv, dim=2) + + x = xops.memory_efficient_attention(q, k, v, op=self.op).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class TimmSwiGLU(nn.Module): + def __init__(self, mlp: TimmMlp, op=None) -> None: + super().__init__() + self.fc1 = mlp.fc1 + self.swiglu = xops.SwiGLU( + in_features=mlp.fc1.in_features, + hidden_features=mlp.fc1.out_features, + bias=True, + ) + self.op = op + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.swiglu(x) + + +def mod_memeff_attn(model: nn.Module, op=None) -> nn.Module: + return replace_module(model, TimmAttention, partial(TimmMemEffAttention, op=op)) + + +def mod_mlp_to_swiglu(model: nn.Module, op=None) -> nn.Module: + def _mlp_to_swiglu(block: TimmBlock): + block.mlp = TimmSwiGLU(block.mlp, op=op) + return block + + return replace_module(model, TimmBlock, _mlp_to_swiglu) + + +mod_mlp_to_eagr_swiglu = partial(mod_mlp_to_swiglu, op=xops.SwiGLUEagerOp) +mod_mlp_to_fast_swiglu = partial(mod_mlp_to_swiglu, op=None) + + +def compose(*fns): + def compose2(f, g): + return lambda *a, **kw: f(g(*a, **kw)) + + return reduce(compose2, fns) + + +MODELS = [ + # model_name, model_factory, input_shape + ("ViT-B/16", timm.models.vit_base_patch16_224, [512, 3, 224, 224]), + ("ViT-B/8", timm.models.vit_base_patch8_224, [64, 3, 224, 224]), + ("ViT-L/16", timm.models.vit_large_patch16_224, [128, 3, 224, 224]), + ("ViT-g/14", timm.models.vit_giant_patch14_224, [32, 3, 224, 224]), +] + +MODIFIERS = [ + ["mlp", lambda x: x], + ["mlp+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)], + ["swiglu", mod_mlp_to_eagr_swiglu], + ["swiglu+fast_swiglu", mod_mlp_to_fast_swiglu], + ["swiglu+fast_swiglu+memeff", compose(mod_mlp_to_fast_swiglu, mod_memeff_attn)], +] + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, 
instance)) + + +CASES = list( + product_dict( + model_info=MODELS, + dtype=[torch.half], + ) +) + + +def benchmark_transformer(model_info, dtype): + device = "cuda" + + model_name, model_factory, input_shape = model_info + + inp = torch.randn(input_shape, dtype=dtype, device=device) + + for mod_name, mod_apply in MODIFIERS: + model: nn.Module = model_factory() + model = mod_apply(model).to(device).to(dtype) + + # Make sure we don't have errors + out = model(inp) + grad = out.clone() + out.backward(grad) + + yield benchmark.Timer( + stmt="model(inp).backward(grad)", + globals={ + "model": model, + "inp": inp, + "grad": grad, + }, + label="fw+bw", + description=mod_name, + sub_label=model_name, + ) + + +benchmark_main_helper(benchmark_transformer, CASES) diff --git a/pkgs/xformers/benchmarks/benchmark_triton_blocksparse.py b/pkgs/xformers/benchmarks/benchmark_triton_blocksparse.py new file mode 100644 index 0000000..6c9e4c5 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_triton_blocksparse.py @@ -0,0 +1,150 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# Benchmark the blocksparse operations: +# matrix multiply and softmax + +# Matmul can be of three types: +# - Dense x Dense (COO) -> Sparse +# - Sparse x Dense -> Dense +# - Dense x Sparse -> Dense + +from typing import Any, Dict + +import torch +import triton +from triton.ops.blocksparse import matmul as blocksparse_matmul + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components.attention.core import SparseCS, _matmul_with_mask + + +def bench_matmul(dtype: torch.dtype, shapes): + results: Dict[str, Any] = {} + Z, H = 1, 1 + + for M, N, K in shapes: + + modes = [(mode, block) for mode in ["sdd", "dsd"] for block in [16, 32, 64]] + + for mode, block in modes: + # create inputs + a = torch.randn((Z, H, M, K), dtype=dtype, device="cuda") + b = torch.randn((Z, H, K, N), dtype=dtype, device="cuda") + shape = { + "sdd": (M, N), + "dsd": (a.shape[2], a.shape[3]), + "dds": (b.shape[2], b.shape[3]), + }[mode] + + # Pre-sparsify everything + _layout = torch.eye(shape[0] // block, shape[1] // block, dtype=torch.long) + + # - blocksparse + layout = _layout.unsqueeze(0).expand(H, -1, -1) + a_triton = ( + triton.testing.sparsify_tensor(a, layout, block) if mode == "dsd" else a + ) + b_triton = ( + triton.testing.sparsify_tensor(b, layout, block) if mode == "dds" else b + ) + bsmm = blocksparse_matmul( + layout=layout, + block=block, + mode=mode, + device=torch.device("cuda"), + trans_a=False, + trans_b=False, + ) + + # - dense + ta = triton.testing.mask_tensor(a, layout, block) if mode == "dsd" else a + tb = triton.testing.mask_tensor(b, layout, block) if mode == "dds" else b + + # - sparse / sputnik + mask = torch.ones_like(a, dtype=torch.float, device="cuda") + mask = triton.testing.mask_tensor(mask, layout, block, value=0.0) + a_cs = a.flatten(start_dim=0, end_dim=1).to( + torch.float32 + ) # Sputnik kernels only handle fp32 + b_cs = b.flatten(start_dim=0, end_dim=1).to(torch.float32) + a_cs = a_cs.contiguous() + b_cs = b_cs.transpose(-2, -1).contiguous() + + if mode == "sdd": + b_cs = b_cs.transpose(-2, -1) + + # pyre-fixme[16]: TODO(T101400990): Pyre did not recognize the + # `SparseCS` import. 
+ sparse_cs_mask = SparseCS( + mask.flatten(start_dim=0, end_dim=1).contiguous(), + device=torch.device("cuda"), + ) + + # The raw compute steps + op_flops = { + "sdd": 2 * Z * K * float(layout.sum()) * block * block, + "dsd": 2 * Z * N * float(layout.sum()) * block * block, + "dds": 2 * Z * M * float(layout.sum()) * block * block, + }[ + mode + ] * 1e-12 # TFlops + + def torch_step(): + return torch.matmul(ta, tb) + + def triton_step(): + return bsmm(a_triton, b_triton) + + def sparse_step(): + if mode == "sdd": + return _matmul_with_mask(a_cs, b_cs, sparse_cs_mask) + else: + return sparse_cs_mask.spmm(b_cs) + + # Run and measure, report perf in terms of TFlops + for testcase in [ + TestCase( + torch_step, + f"pytorch - {mode} - {block}: ", + ), + TestCase( + sparse_step, + f"sparse - {mode} - {block}: ", + ), + TestCase( + triton_step, + f"triton - {mode} - {block}: ", + ), + ]: + ms = triton.testing.do_bench(lambda: testcase.function())[0] + key = f"M={M}, N={N}, K={K}" + + if key not in results: + results[key] = {} + + num_flops = op_flops / ms * 1e3 # Get to TFlop per second + results[key][testcase.name] = f"{num_flops:.1f}" + print(f"{key} - {testcase.name} - {num_flops:.2f}TFlops") + + pretty_print( + results, + title="\n ------------- Type: {} -------------".format(dtype), + units="TFlops/s", + ) + + pretty_plot( + results, + title=f"Sparse/Blocksparse throughput - {dtype}", + filename=f"blocksparse_{dtype}.png", + dash_key="pytorch", + units="TFlops/s", + ) + + +shapes = [(k, k, k) for k in [128, 512, 1024, 2048, 4096]] +bench_matmul(torch.float16, shapes) +bench_matmul(torch.float32, shapes) diff --git a/pkgs/xformers/benchmarks/benchmark_triton_dropout.py b/pkgs/xformers/benchmarks/benchmark_triton_dropout.py new file mode 100644 index 0000000..e7e3bf7 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_triton_dropout.py @@ -0,0 +1,115 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
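+# Benchmarks the fused Triton dropout (+ optional bias and activation) kernel
+# against the equivalent PyTorch ops, reporting memory bandwidth in GB/s.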
+ + +from typing import Any, Dict, Optional + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components import Activation, build_activation +from xformers.triton import FusedDropoutBias + +SHAPES = [ + (8, 256, 512), + (8, 512, 1024), + (4, 1024, 1024), + (2, 2048, 2048), + (1, 2048, 12288), + (2, 4096, 4096), +] + +P = 0.1 + + +def to_gbs_fw(a, ms, bias): + # Read and write the full array + total = 2 * a.numel() * a.element_size() + + if bias: + # Read the bias, ideally only once + total += a.shape[-1] * a.element_size() + + return total * 1e-9 / (ms * 1e-3) + + +def bench_dropout(bias: bool, backward: bool, activation: Optional[Activation]): + device = torch.device("cuda") + + for dtype in [ + torch.float16, + torch.float32, + ]: + results: Dict[str, Any] = {} + + for B, M, K in SHAPES: + a = torch.rand( + (B, M, K), device=device, dtype=dtype, requires_grad=backward + ) + b = torch.rand(K, device=device, dtype=dtype, requires_grad=backward) + torch_act = build_activation(activation) + triton_dropout = FusedDropoutBias( + P, bias_shape=K if bias else None, activation=activation + ) + + def torch_step(x): + x_ = x + b if bias else x + y = torch.nn.functional.dropout(x_, P) + if activation: + y = torch_act(y) + + if backward: + y.grad = None + torch.norm(y).backward() + return y + + def triton_step(x): + y = triton_dropout(x) + if backward: + y.grad = None + torch.norm(y).backward() + return y + + for testcase in [ + TestCase( + torch_step, + "pytorch - bias: {} - fw{} - act: {}".format( + bias, "+bw" if backward else "", activation + ), + ), + TestCase( + triton_step, + "triton - bias: {} - fw{} - act: {}".format( + bias, "+bw" if backward else "", activation + ), + ), + ]: + time = triton.testing.do_bench( + lambda: testcase.function(a), grad_to_none=[a, b] + )[0] + key = f"B={B}, M={M}, K={K}" + if key not in results: + results[key] = {} + + # Record BW + bandwidth = to_gbs_fw(a, time, bias) + results[key][testcase.name] = f"{bandwidth:.1f}" + + pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s") + pretty_plot( + results, + title="Dropout-Bias-{}-FW{}-{}-Act: {}".format( + bias, "+BW" if backward else "", dtype, activation + ), + units="GB/s", + dash_key="pytorch", + ) + + +for activation in [Activation.GeLU, None, Activation.SquaredReLU]: + for bw in [True, False]: + for bias in [True, False]: + bench_dropout(bias, bw, activation) diff --git a/pkgs/xformers/benchmarks/benchmark_triton_fused_linear.py b/pkgs/xformers/benchmarks/benchmark_triton_fused_linear.py new file mode 100644 index 0000000..1b06602 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_triton_fused_linear.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
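+# Benchmarks the fused Triton linear (+ activation) layer against a PyTorch
+# nn.Linear followed by the same activation, reporting throughput in TFlops/s.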
+ + +from typing import Any, Dict, List, Optional + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components import Activation, build_activation +from xformers.triton.fused_linear_layer import FusedLinear + +SHAPES = [ + (8, 512, 256), # Batch x Seq x Embedding + (8, 512, 512), + (4, 512, 1024), + (2, 512, 2048), + (2, 512, 4096), + (2, 512, 8192), +] + +# Switch PyTorch to TF32 accumulations, Triton does that also +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def get_metrics_transform( + activation: Optional[Activation], + a: torch.Tensor, + w: torch.Tensor, + b: Optional[torch.Tensor], + backward: bool, +): + # all operations will involve a * weight. + flop = a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1) + + # optional activation on top + if activation is not None: + flop += a.numel() + + # optionally * 2 (before the bias) if backward + if backward: + flop *= 2 + + # backward will also output a gradient with respect to the bias + # which consolidates on all the activation gradient + flop += a.shape[0] * a.shape[1] * w.shape[1] + + # backward will also ouput another gradient with respect to the weight, + # which is another matmul, in between the grad_out and the inputs this time + flop += a.shape[0] * a.shape[1] * w.shape[1] * (2 * a.shape[2] - 1) + + # optional bias on top + if b is not None: + flop += b.numel() + + def metric_conversion(ms): + # Returns TFlops/second + return flop * 1e-12 / (ms * 1e-3) + + return metric_conversion + + +def bench_linear(activations: List[Optional[Activation]]): + device = torch.device("cuda") + + for dtype in [ + torch.float32, + torch.float16, + ]: + for backward in [True, False]: + + for activation in activations: + results: Dict[str, Any] = {} + + for bias in [False, True]: + for B, M, K in SHAPES: + a = torch.rand( + B, M, K, device=device, dtype=dtype, requires_grad=backward + ) + + # Pytorch linear layer + activation + torch_linear = torch.nn.Linear(K, 4 * K, bias=bias).to( + dtype=dtype, device=device + ) + torch_activation = build_activation(activation) + + # Fused layer equivalent + fused_linear = FusedLinear( + K, 4 * K, bias=bias, activation=activation + ).to(dtype=dtype, device=device) + + def torch_step(x): + y = torch_activation(torch_linear(x)) + if backward: + torch.norm(y).backward() + return y + + def triton_step(x): + y = fused_linear(x) + + if backward: + torch.norm(y).backward() + return y + + metrics_transform = get_metrics_transform( + activation, + a, + torch_linear.weight, + torch_linear.bias, + backward, + ) + + for testcase in [ + TestCase( + torch_step, + "pytorch - {} - {} bias - fw{}".format( + activation, + "no" if not bias else "", + "+bw" if backward else "", + ), + ), + TestCase( + triton_step, + "triton - {} - {} bias - fw{}".format( + activation, + "no" if not bias else "", + "+bw" if backward else "", + ), + ), + ]: + time = triton.testing.do_bench( + lambda: testcase.function(a) + )[0] + key = f"B={B}, M={M}, K={K}" + if key not in results: + results[key] = {} + + metric = metrics_transform(time) + results[key][testcase.name] = f"{metric:.1f}" + + pretty_print( + results, + title="\n --- Type: {} ---".format(dtype), + units="TFlops/s", + ) + + _type = "_fp16" if dtype == torch.float16 else "_fp32" + title = "FusedLinear" + _type + "_FW" + if backward: + title += "_BW" + title += "_" + activation.value if activation else "_none" + pretty_plot(results, title, "TFlops/s", 
dash_key="pytorch") + + +activations = [ac for ac in Activation] + [None] # type: ignore +bench_linear(activations) # type: ignore diff --git a/pkgs/xformers/benchmarks/benchmark_triton_layernorm.py b/pkgs/xformers/benchmarks/benchmark_triton_layernorm.py new file mode 100644 index 0000000..7e991fd --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_triton_layernorm.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, Dict + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.triton import FusedLayerNorm + +SHAPES = [ + (8, 256, 512), + (8, 512, 1024), + (4, 1024, 1024), + (2, 2048, 2048), + (2, 4096, 4096), + (1, 2048, 12288), +] + + +def to_gbs_fw(a, ms): + # Read and write the full array + return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) + + +def bench_layernorm(backward: bool): + device = torch.device("cuda") + + for dtype in [ + torch.float16, + torch.bfloat16, + torch.float32, + ]: + results: Dict[str, Any] = {} + + for B, M, K in SHAPES: + a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=backward) + + # Pytorch layer norn + torch_layernorm = torch.nn.LayerNorm([K]).to(dtype=dtype, device=device) + + # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize the + # `FusedLinearNorm` import. + # Fused layernorm equivalent + fused_layernorm = FusedLayerNorm([K]).to(dtype=dtype, device=device) + + def torch_step(x): + y = torch_layernorm(x) + if backward: + torch.norm(y).backward() + return y + + def triton_step(x): + y = fused_layernorm(x) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase( + torch_step, + "pytorch - fw{}".format("+bw" if backward else ""), + ), + TestCase( + triton_step, + "triton - fw{}".format("+bw" if backward else ""), + ), + ]: + time = triton.testing.do_bench(lambda: testcase.function(a))[0] + key = f"B={B}, M={M}, K={K}" + if key not in results: + results[key] = {} + + # Record BW + bandwidth = to_gbs_fw(a, time) + results[key][testcase.name] = f"{bandwidth:.1f}" + + pretty_print(results, title="\n --- Type: {} --- ".format(dtype), units="GB/s") + pretty_plot( + results, + title="LayerNorm-FW{}-{}".format("+BW" if backward else "", dtype), + units="GB/s", + dash_key="pytorch", + ) + + +for bw in [False, True]: + bench_layernorm(bw) diff --git a/pkgs/xformers/benchmarks/benchmark_triton_softmax.py b/pkgs/xformers/benchmarks/benchmark_triton_softmax.py new file mode 100644 index 0000000..25b2630 --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_triton_softmax.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
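+# Benchmarks the Triton softmax / log-softmax kernels (including the causal
+# variants) against torch.softmax / torch.log_softmax, in GB/s of bandwidth.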
+ +import torch + +from xformers.benchmarks.utils import TestCase, bench_functions +from xformers.triton.softmax import log_softmax as triton_log_softmax +from xformers.triton.softmax import softmax as triton_softmax + +SHAPES = [ + (8, 384, 128), + (8, 784, 512), + (4, 1024, 768), + (4, 2048, 1024), + (2, 2048, 2048), + (2, 2048, 4096), + (2, 4096, 4096), + (1, 2048, 12288), +] + + +def pytorch_fw_bw(x): + y = torch.norm(torch.softmax(x, dim=-1)) + y.backward() + + +def triton_causal_fw(x): + _ = triton_softmax(x, causal=True) + + +def triton_fw_bw(x): + y = torch.norm(triton_softmax(x)) + y.backward() + + +def triton_causal_fw_bw(x): + y = torch.norm(triton_softmax(x, causal=True)) + y.backward() + + +def pytorch_log_fw_bw(x): + y = torch.norm(torch.log_softmax(x, dim=-1)) + y.backward() + + +def triton_log_fw_bw(x): + y = torch.norm(triton_log_softmax(x)) + y.backward() + + +# Test FW +def to_gbs_fw(a, ms): + # Read and write the full array + return (2 * a.numel() * a.element_size() * 1e-9) / (ms * 1e-3) + + +def to_gbs_fwbw(a, ms): + # same as above, but we do it twice (FW and then gradient) + return 2 * to_gbs_fw(a, ms) + + +bench_functions( + [ + TestCase(lambda x: torch.softmax(x, dim=-1), "pytorch - fw"), + TestCase(triton_softmax, "triton - fw"), + TestCase(triton_causal_fw, "triton - causal - fw"), + TestCase(lambda x: torch.log_softmax(x, dim=-1), "pytorch - log - fw"), + TestCase(triton_log_softmax, "triton - log - fw"), + ], + SHAPES, + to_gbs_fw, + "GB/s", + "Softmax_Bandwidth_FW_", +) + +# Test FW+BW +bench_functions( + [ + TestCase(pytorch_fw_bw, "pytorch - fw+bw"), + TestCase(triton_fw_bw, "triton - fw+bw"), + TestCase(triton_causal_fw_bw, "triton - causal - fw+bw"), + TestCase(pytorch_log_fw_bw, "pytorch - log - fw+bw"), + TestCase(triton_log_fw_bw, "triton - log - fw+bw"), + ], + SHAPES, + to_gbs_fwbw, + "GB/s", + "Softmax_Bandwidth_FW_BW_", +) diff --git a/pkgs/xformers/benchmarks/benchmark_triton_stride_sum.py b/pkgs/xformers/benchmarks/benchmark_triton_stride_sum.py new file mode 100644 index 0000000..6fb887e --- /dev/null +++ b/pkgs/xformers/benchmarks/benchmark_triton_stride_sum.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
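+# Benchmarks the strided-sum Triton kernel (sum_2d_dim_0) against
+# torch.sum(dim=0), reporting memory bandwidth in GB/s.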
+ +from typing import Any, Dict, List + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.triton.sum_strided import sum_2d_dim_0 + +SHAPES = [ + (128, 128), + (384, 128), + (784, 512), + (1024, 768), + (2048, 1024), + (4096, 4096), +] + + +def to_gbs(a, ms): + # Read the full array, write the non-reduced dimension + return ((a.numel() + a.shape[1]) * a.element_size() * 1e-9) / (ms * 1e-3) + + +def bench_functions( + test_cases: List[TestCase], shapes, metric_transform, unit, title="" +): + device = torch.device("cuda") + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for M, N in shapes: + a = torch.rand(M, N, device=device, dtype=dtype, requires_grad=True) + + for testcase in test_cases: + time = triton.testing.do_bench(lambda: testcase.function(a))[0] + + metric = metric_transform(a, time) + + key = f"M={M}, N={N}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{metric:.1f}" + + _type = " fp16" if dtype == torch.float16 else " fp32" + + pretty_print( + results, + title=" ------------- Type: {} ------------- ".format(_type), + units=unit, + ) + + pretty_plot(results, title + _type, unit, dash_key="pytorch") + + +bench_functions( + [ + TestCase(lambda x: torch.sum(x, dim=0), "pytorch"), + TestCase(sum_2d_dim_0, "triton"), + ], + SHAPES, + to_gbs, + "GB/s", + "Strided_sum", +) diff --git a/pkgs/xformers/benchmarks/utils.py b/pkgs/xformers/benchmarks/utils.py new file mode 100644 index 0000000..a3d10d6 --- /dev/null +++ b/pkgs/xformers/benchmarks/utils.py @@ -0,0 +1,660 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import contextlib +import copy +import csv +import glob +import logging +import math +import os +import tempfile +from collections import defaultdict, namedtuple +from dataclasses import replace +from typing import Any, Dict, Generator, List, Set, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +import tqdm +from torch.utils import benchmark + +sns.set() + +TestCase = namedtuple("TestCase", ["function", "name"]) + + +_triton_is_available = torch.cuda.is_available() +if _triton_is_available: + try: + import triton + except ImportError as e: + logging.warning(f"Triton is not available: {e}.\nbench_functions") + _triton_is_available = False + + +def pretty_print(results, title, units): + """Printout the contents of a dict as a human-readable and Markdown compatible array""" + print(title) + header = " Units: {:<45}".format(units) + print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) + + offset = len(header) + print( + "|-{}|".format("-" * offset) + + "".join("{}|".format("-" * 20) for _ in results.keys()) + ) + + workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()} + for v in results.values(): + for k in v.keys(): + workloads[k].append(v[k]) + + for k, w in workloads.items(): + print( + "| {0:<{offset}}|".format(k, offset=offset) + + "".join("{:<20}|".format(v) for v in w) + ) + + print("") + + +def pretty_plot( + results, title, units: str, filename=None, dash_key="", legend_loc="lower right" +): + """Graph out the contents of a dict. 
+    Dash key means that if the result label has this key, then it will be displayed with a dash
+    """
+
+    if not filename:
+        filename = title + ".png"
+
+    # Sanitize the filename
+    filename = (
+        filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
+    )
+
+    # Gather all the results in "columns"
+    workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
+    for v in results.values():
+        for k in v.keys():
+            workloads[k].append(float(v[k]))
+
+    # Make sure that the plot is big enough
+    f = plt.figure()
+    f.set_figwidth(6)
+    f.set_figheight(6)
+
+    # Display the collections
+    for k, v in workloads.items():
+        if dash_key and dash_key in k:
+            plt.plot(list(results.keys()), v, "--")
+        else:
+            plt.plot(list(results.keys()), v)
+
+    plt.title(title)
+    plt.legend(list(workloads.keys()), loc=legend_loc)
+    plt.ylabel(units)
+    plt.xticks(rotation=45)
+
+    plt.savefig(filename, bbox_inches="tight")
+    plt.close(f)
+
+
+if _triton_is_available:
+
+    def bench_functions(
+        test_cases: List[TestCase], shapes, metric_transform, unit, title=""
+    ):
+        device = torch.device("cuda")
+
+        for dtype in [torch.bfloat16, torch.float16, torch.float32]:
+            results: Dict[str, Any] = {}
+
+            for B, M, K in shapes:
+                a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True)
+
+                for testcase in test_cases:
+                    time = triton.testing.do_bench(lambda: testcase.function(a))[0]
+
+                    metric = metric_transform(a, time)
+
+                    key = f"B={B}, M={M}, K={K}"
+                    if key not in results:
+                        results[key] = {}
+
+                    results[key][testcase.name] = f"{metric:.1f}"
+
+            pretty_print(
+                results,
+                title=" ------------- Type: {} ------------- ".format(dtype),
+                units=unit,
+            )
+            pretty_plot(results, title + str(dtype), unit, dash_key="pytorch")
+
+
+def pretty_barplot(results, title, units: str, filename=None, dash_key=""):
+    """Graph out the contents of a dict.
+    Dash key means that if the result label has this key, then it will be displayed with a dash
+    """
+
+    if not filename:
+        filename = title + ".png"
+
+    # Sanitize the filename
+    filename = (
+        filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "")
+    )
+
+    xlabels = list(results.keys())
+    # Gather all the results in "columns"
+    workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()}
+    for v in results.values():
+        for k in v.keys():
+            workloads[k].append(float(v[k]))
+
+    options = list(workloads.keys())
+    group_len = len(options)
+    for key in workloads.keys():
+        num_groups = len(workloads[key])
+        break
+    group_width = group_len + 1
+
+    # Make sure that the plot is big enough
+    f = plt.figure()
+    f.set_figwidth(6)
+    f.set_figheight(6)
+
+    for idx in range(group_len):
+        option = options[idx]
+        values = workloads[option]
+        xloc = np.arange(1 + idx, group_width * num_groups, group_width)
+        plt.bar(xloc, values, width=1, edgecolor="black")
+
+    plt.title(title)
+    plt.legend(list(workloads.keys()), loc="upper right")
+    plt.ylabel(units)
+
+    ax = plt.gca()
+    xticks_loc = np.arange(
+        1 + (group_len - 1) / 2.0, group_width * num_groups, group_width
+    )
+    ax.set_xticks(xticks_loc, xlabels)
+    plt.xticks(rotation=45)
+
+    plt.setp(ax.xaxis.get_majorticklabels(), ha="right")
+    ax.set_axisbelow(True)
+    ax.yaxis.grid(color="gray", linestyle="dashed")
+    ax.xaxis.grid(color="gray", linestyle="dashed")
+
+    plt.savefig(filename, bbox_inches="tight")
+    plt.close(f)
+
+
+def rmf(filename: str) -> None:
+    """Remove a file like rm -f."""
+    try:
+        os.remove(filename)
+    except FileNotFoundError:
+        pass
+
+
+@contextlib.contextmanager
+def temp_files_ctx(num: int) -> Generator:
+    """A context to get tempfiles and ensure they are cleaned up."""
+    files = [tempfile.mkstemp()[1] for _ in range(num)]
+
+    yield tuple(files)
+
+    # temp files could have been removed, so we use rmf.
+ for name in files: + rmf(name) + + +META_ALGORITHM = "algorithm" +BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"] + + +# Serialize/unserialize to CSV +# We could use pkl, but resort to CSV for readability +def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]: + parts = os.path.basename(filename).split(".") + env = "" + description = "" + if len(parts) == 3: + env = parts[1] + description = parts[0] + + data = [] + with open(filename, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + if description != "" and row["description"] not in BASELINE_DESCRIPTIONS: + row["description"] = description + task_spec = benchmark.utils.common.TaskSpec( + stmt="", + setup="", + global_setup="", + label=row["label"], + sub_label=row["sub_label"], + description=row["description"], + env=env, + num_threads=int(row["num_threads"]), + ) + measurement = benchmark.utils.common.Measurement( + number_per_run=1, + raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)], + task_spec=task_spec, + ) + measurement.mem_use = float(row["mem_use_mb"]) # type: ignore + data.append( + ( + { + META_ALGORITHM: row["algorithm"] + if row["algorithm"] != "" + else None, + }, + measurement, + ) + ) + return data + + +def _benchmark_results_to_csv( + filename: str, results: List[Tuple[Dict[str, Any], Any]] +) -> None: + data = [ + { + "sub_label": r.task_spec.sub_label, + "label": r.task_spec.label, + "num_threads": r.task_spec.num_threads, + "algorithm": metadata.get(META_ALGORITHM, ""), + "description": r.task_spec.description + if r.task_spec.description in BASELINE_DESCRIPTIONS + else "", + "runtime_us": int(1000 * 1000 * r.mean), + "mem_use_mb": r.mem_use, + } + for metadata, r in results + ] + with open(filename, "w+", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=list(data[0].keys())) + writer.writeheader() + for d in data: + writer.writerow(d) + + +def _finalize_results(results: List[Tuple[Dict[str, Any], Any]]) -> List[Any]: + """ + Returns a `benchmark.Compare` object, except that if we have runs + with different algorithms, we also add the algorithm name + in the column titles + """ + all_algorithms: Set[str] = set() + all_description: Set[str] = set() + for metadata, r in results: + algo = metadata.get(META_ALGORITHM, None) + if algo is not None: + all_algorithms.add(algo) + all_description.add(r.task_spec.description) + display_algo = len(all_algorithms) > 1 + display_descr = len(all_description) > 1 + + display_results = [] + for metadata, r in results: + algo = metadata.get(META_ALGORITHM, None) + if algo is None: + display_results.append(r) + else: + r = copy.copy(r) + description = "" + if display_descr: + description = r.task_spec.description + if display_algo: + if display_descr: + description += "[" + description += algo + if display_descr: + description += "]" + r.task_spec = replace(r.task_spec, description=description) + display_results.append(r) + return display_results + + +def _render_bar_plot(results: List[Any], store_results_folder: str) -> None: + if not results: + return + runtime: Dict[str, Dict[str, float]] = defaultdict(dict) + memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict) + all_descriptions: List[str] = [] + for r in results: + # Hacky: use a list to preserve order + if r.task_spec.description not in all_descriptions: + if r.task_spec.description in BASELINE_DESCRIPTIONS: + all_descriptions.insert(0, r.task_spec.description) + else: + all_descriptions.append(r.task_spec.description) + 
runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean + memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use + all_data_mem: List[Any] = [] + all_data_run: List[Any] = [] + for key, runtime_values in runtime.items(): + memory_values = memory_usage[key] + denom = memory_values.get(all_descriptions[0], math.inf) + if denom == 0: + all_data_mem.append([key] + [0] * len(all_descriptions)) + else: + all_data_mem.append( + [key] + [memory_values.get(d, 0) / denom for d in all_descriptions] + ) + all_data_run.append( + [key] + + [ + runtime_values.get(all_descriptions[0], 0) + / runtime_values.get(d, math.inf) + for d in all_descriptions + ] + ) + if all_descriptions[0] == "": + all_descriptions[0] = "baseline" + else: + all_descriptions[0] = f"{all_descriptions[0]} (baseline)" + + for data, filename, title in [ + (all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"), + ( + all_data_run, + "runtime.png", + "Runtime speedup (vs baseline, higher is better)", + ), + ]: + df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions) + df.plot( + x="Configuration", + kind="bar", + stacked=False, + title=title, + ) + plt.tight_layout() + filename_full = os.path.join(store_results_folder, filename) + plt.savefig(filename_full) + print(f"Saved plot: {filename_full}") + + +def benchmark_main_helper(benchmark_fn, cases: List[Dict[str, Any]], **kwargs) -> None: + """ + Helper function to run benchmarks. + Supports loading previous results for comparison, and saving current results to file. + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + "--fn", default=None, type=str, help="Only benchmark this function" + ) + parser.add_argument( + "--label", default=None, type=str, help="Store results to a file" + ) + parser.add_argument( + "--fail_if_regression", + action="store_true", + help="Enabled in CI to check against performance regressions", + ) + parser.add_argument( + "--compare", + default=None, + type=str, + help="Compare to previously stored benchmarks (coma separated)", + ) + parser.add_argument( + "--omit-baselines", + action="store_true", + help="Do not run the (potentially slow) baselines", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Skip intermediate results and progress bar", + ) + args = parser.parse_args() + + if args.fn is not None and args.fn != benchmark_fn.__name__: + print(f'Skipping benchmark "{benchmark_fn.__name__}"') + return + benchmark_run_and_compare( + benchmark_fn=benchmark_fn, + cases=cases, + optimized_label="optimized" if args.label is None else args.label, + fail_if_regression=args.fail_if_regression, + compare=args.compare.split(",") if args.compare is not None else [], + quiet=args.quiet, + omit_baselines=args.omit_baselines, + **kwargs, + ) + + +def benchmark_run_and_compare( + benchmark_fn, + cases: List[Dict[str, Any]], + compare: List[str], + omit_baselines: bool = False, + fail_if_regression: bool = False, + quiet: bool = False, + optimized_label: str = "optimized", + *, + min_run_time: int = 2, + atol_s: float = 30e-6, + rtol: float = 0.05, +) -> None: + SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True + results_compare_to = [] + results = [] + + store_results_folder = os.path.expanduser( + os.path.join( + os.environ.get( + "XFORMERS_BENCHMARKS_CACHE", + os.path.join("~", ".cache", "xformers", "benchmarks"), + ), + benchmark_fn.__name__, + ) + ) + + try: + env = ( + torch.cuda.get_device_name(torch.cuda.current_device()) + .replace(" ", "_") + .replace("-", "_") + .replace(".", 
"_") + ) + except (RuntimeError, AssertionError): # No GPU + env = "cpu" + assert ( + "." not in optimized_label + ), f"label=`{optimized_label}` should not contain dots" + assert "." not in env, f"env=`{env}` should not contain dots" + + os.makedirs(store_results_folder, exist_ok=True) + + # Load runs that we want to compare to + skip_vanilla_tasks = set() + for cmp_name in compare: + name_with_env = cmp_name if "." in cmp_name else f"{cmp_name}.*" + for filename in glob.glob( + os.path.join(store_results_folder, f"{name_with_env}.csv") + ): + loaded = _benchmark_results_from_csv(filename) + for m, r in loaded: + if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE: + skip_vanilla_tasks.add( + (r.task_spec.sub_label, r.task_spec.num_threads) + ) + results_compare_to += loaded + + if not quiet: + pbar = tqdm.tqdm(cases, leave=False) + cases = pbar + for case in cases: + if quiet: + print(str(case)) + else: + pbar.write(f"====== {str(case)} ======") + try: + benchmarks_generator = benchmark_fn(**case) + except NotImplementedError: + # pbar.write(f"Skipped (NotImplementedError)") + continue + except RuntimeError as e: + if "CUDA out of memory" not in str(e): + raise + if not quiet: + pbar.write("Skipped (OOM)") + continue + + name = None + try: + for benchmark_object in benchmarks_generator: + is_optimized = ( + benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS + ) + metadata = {} + if is_optimized: + metadata[META_ALGORITHM] = benchmark_object._task_spec.description + benchmark_object._task_spec = replace( + benchmark_object._task_spec, description=optimized_label + ) + elif ( + omit_baselines + or ( + benchmark_object._task_spec.sub_label, + benchmark_object._task_spec.num_threads, + ) + in skip_vanilla_tasks + ): + continue + + memory = math.inf + try: + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + mem_begin = torch.cuda.max_memory_allocated() / 2**20 + benchmark_object._task_spec = replace( + benchmark_object._task_spec, env=env + ) + measurement = benchmark_object.blocked_autorange( + min_run_time=min_run_time + ) + torch.cuda.synchronize() + results.append((metadata, measurement)) + name = measurement.task_spec.description + memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin + measurement.mem_use = memory + except RuntimeError as e: + if "CUDA out of memory" not in str(e): + raise + if not quiet: + pbar.write("Skipped (OOM)") + finally: + del benchmark_object + if not quiet: + pbar.write(f"{name}: memory used: {memory} MB") + except RuntimeError as e: + if "CUDA out of memory" not in str(e): + raise + if not quiet: + pbar.write("Skipped (OOM)") + # Display results for benchmarks we just calculated + if name is not None and not quiet: + + def matches_current(r): + return ( + r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label + and r[1].task_spec.label == results[-1][1].task_spec.label + ) + + pbar.write( + str( + benchmark.Compare( + _finalize_results( + list(filter(matches_current, results)) + + list(filter(matches_current, results_compare_to)) + ) + ) + ) + ) + + results_for_print = _finalize_results(results + results_compare_to) + benchmark.Compare(results_for_print).print() + _render_bar_plot(results_for_print, store_results_folder) + + # Save runs to a file + if results and optimized_label is not None: + write_to_path = os.path.join( + store_results_folder, f"{optimized_label}.{env}.csv" + ) + _benchmark_results_to_csv(write_to_path, results) + print(f"Saved results to {write_to_path}") + + if 
fail_if_regression: + _fail_if_regressions( + results, reference=results_compare_to, atol_s=atol_s, rtol=rtol + ) + + +def _fail_if_regressions( + results: List[Any], reference: List[Any], atol_s: float, rtol: float +) -> None: + def get_measurement_id(r): + return ( + r[0].get(META_ALGORITHM, ""), + r[1].task_spec.label, + r[1].task_spec.sub_label, + r[1].task_spec.env, + ) + + id_to_result = {} + for r in results: + id_to_result[get_measurement_id(r)] = r[1] + + num_better = 0 + num_worse = 0 + num_nochange = 0 + num_unk = 0 + reference_set = set() + for ref in reference: + if ref[1].task_spec.description in BASELINE_DESCRIPTIONS: + continue + benchmark_id = get_measurement_id(ref) + if benchmark_id in reference_set: + raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}") + reference_set.add(benchmark_id) + if benchmark_id not in id_to_result: + num_unk += 1 + continue + res = id_to_result[benchmark_id] + # If significative change + if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s: + is_now_better = res.mean < ref[1].mean + if is_now_better: + num_better += 1 + else: + num_worse += 1 + cmp = "IMPROVED" if is_now_better else "REGRESS " + print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}") + else: + num_nochange += 1 + + print("Regression test summary:") + print(f" Better : {num_better}") + print(f" No change: {num_nochange}") + print(f" Worse : {num_worse}") + if num_unk > 0: + print(f" (no ref) : {num_unk}") + if num_worse > 1: + raise RuntimeError("At least one benchmark regressed!") + if num_nochange == 0: + raise RuntimeError("No reference found") diff --git a/pkgs/xformers/checkpoint.py b/pkgs/xformers/checkpoint.py new file mode 100644 index 0000000..5b4aabf --- /dev/null +++ b/pkgs/xformers/checkpoint.py @@ -0,0 +1,233 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import weakref +from collections import defaultdict +from contextlib import nullcontext + +import torch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map +from torch.utils.checkpoint import get_device_states, set_device_states + + +def _detach(x): + if isinstance(x, torch.Tensor): + return x.detach() + return x + + +class CachingTorchDispatchMode(TorchDispatchMode): + def __init__(self, policy_fn, storage): + self.policy_fn = policy_fn + self.storage = storage + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if self.policy_fn(func, *args, **kwargs): + out = func(*args, **kwargs) + out_detached = tree_map(_detach, out) + self.storage[func].append(out_detached) + return out + return func(*args, **kwargs) + + +class CachedTorchDispatchMode(TorchDispatchMode): + def __init__(self, policy_fn, storage): + self.policy_fn = policy_fn + self.storage = storage + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if self.policy_fn(func, *args, **kwargs): + # this is a basic guard if there are additional ops + # that were present in the second forward. 
An example
+            # is for ops like `detach`, which can be added if our
+            # policy is too loose
+            if self.storage[func]:
+                out = self.storage[func].pop(0)
+                return out
+        return func(*args, **kwargs)
+
+
+def _get_default_policy(allow_list=None):
+    _default_allow_list = [
+        "xformers.efficient_attention_forward_cutlass.default",
+        "xformers_flash.flash_fwd.default",
+        "aten.addmm.default",
+        "aten.mm.default",
+    ]
+    if allow_list is None:
+        allow_list = _default_allow_list
+
+    def _default_policy(func, *args, **kwargs):
+        return str(func) in allow_list
+
+    return _default_policy
+
+
+class VerboseTorchDispatchMode(TorchDispatchMode):
+    def __init__(self):
+        self.operators = []
+
+    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        self.operators.append(func)
+        return func(*args, **kwargs)
+
+
+def list_operators(function, *args, **kwargs):
+    """
+    Returns the list of operators used inside `function` with
+    *args and **kwargs
+    """
+    verbose_mode = VerboseTorchDispatchMode()
+    with verbose_mode:
+        function(*args, **kwargs)
+    return verbose_mode.operators
+
+
+def checkpoint(function, *args, preserve_rng_state=True, policy_fn=None, **kwargs):
+    """Checkpointing with a custom policy function for selectively deciding
+    what to store and what to recompute.
+
+    Args:
+        function: describes what to run in the forward pass of the model or
+            part of the model. It should also know how to handle the inputs
+            passed as the tuple. For example, in LSTM, if the user passes
+            ``(activation, hidden)``, :attr:`function` should correctly use the
+            first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional): stash and restore the RNG state
+            during each checkpoint, so that the recomputation replays the same
+            random behaviour as the original forward pass.
+            Default: ``True``
+        policy_fn(Union[List[Op], callable]): policy for deciding what to
+            store (instead of recompute). If it's a function, it should
+            be of the form (func, *args, **kwargs) -> bool, indicating whether
+            the outputs of func called with *args and **kwargs should be
+            stored or not.
+            Additionally, a list[Op] is also supported for easier cases.
+            The op should be in the format `torch.ops.***`, where the `***`
+            names of operators can be obtained with `list_operators`.
+        *args: Arguments to pass in to the given ``function``.
+        **kwargs: Keyword arguments to pass into the given ``function``.
+    """
+
+    # Requires PyTorch 1.13 at least
+    from torch.utils.checkpoint import _get_autocast_kwargs
+
+    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
+    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()
+
+    if preserve_rng_state:
+        fwd_cpu_state = torch.get_rng_state()
+        # Don't eagerly initialize the cuda context by accident.
+        # (If the user intends that the context is initialized later, within their
+        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
+        # we have no way to anticipate this will happen before we run the function.
+        # If they do so, we raise an error.)
+        had_cuda_in_fwd = False
+        if torch.cuda._initialized:
+            had_cuda_in_fwd = True
+            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)
+
+    # Custom class to be able to take weak references
+    class Holder:
+        pass
+
+    # The Holder object for each of the saved objects is saved directly on the
+    # SavedVariable and is cleared when reset_data() is called on it.
We MUST make + # sure that this is the only object having an owning reference to ensure that + # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable + # data is cleared. + storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + weak_holder_list = [] + + if policy_fn is None: + policy_fn = _get_default_policy() + elif isinstance(policy_fn, list): + policy_fn = _get_default_policy(policy_fn) + else: + assert callable(policy_fn), "policy_fn should be None, list or a callable" + + temp_storage = defaultdict(list) + # assumption: grad_mode doesn't change inside function + if torch.is_grad_enabled(): + caching_mode = CachingTorchDispatchMode(policy_fn, temp_storage) + else: + caching_mode = nullcontext() + cached_mode = CachedTorchDispatchMode(policy_fn, temp_storage) + + def pack(x): + # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as + # size, device, ...) to catch certain cases of undeterministic behavior of the forward + res = Holder() + weak_holder_list.append(weakref.ref(res)) + return res + + def unpack(x): + unpack_counter = 0 + if len(storage) == 0: + + def inner_pack(inner): + nonlocal unpack_counter + unpack_counter += 1 + # If the holder went out of scope, the SavedVariable is dead and so + # the value will never be read from the storage. Skip filling it. + if weak_holder_list[unpack_counter - 1]() is None: + return + # Use detach here to ensure we don't keep the temporary autograd + # graph created during the second forward + storage[weak_holder_list[unpack_counter - 1]()] = inner.detach() + return + + def inner_unpack(packed): + raise RuntimeError( + "You are calling backwards on a tensor that is never exposed. Please open an issue." + ) + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. + rng_devices = [] + if preserve_rng_state and had_cuda_in_fwd: + rng_devices = fwd_gpu_devices + with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state): + if preserve_rng_state: + torch.set_rng_state(fwd_cpu_state) + if had_cuda_in_fwd: + set_device_states(fwd_gpu_devices, fwd_gpu_states) + + with cached_mode, torch.enable_grad(), torch.cuda.amp.autocast( + **gpu_autocast_kwargs + ), torch.cpu.amp.autocast( + **cpu_autocast_kwargs + ), torch.autograd.graph.saved_tensors_hooks( + inner_pack, inner_unpack + ): + _unused = function(*args, **kwargs) # noqa: F841 + + if x not in storage: + raise RuntimeError( + "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint" + " recomputation being triggered in between, this is not currently supported. Please" + " open an issue with details on your use case so that we can prioritize adding this." + ) + + return storage[x] + + with caching_mode, torch.autograd.graph.saved_tensors_hooks(pack, unpack): + output = function(*args, **kwargs) + if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd: + # Cuda was not initialized before running the forward, so we didn't + # stash the CUDA state. + raise RuntimeError( + "PyTorch's CUDA state was initialized in the forward pass " + "of a Checkpoint, which is not allowed. Please open an issue " + "if you need this feature." 
+ ) + + return output diff --git a/pkgs/xformers/components/__init__.py b/pkgs/xformers/components/__init__.py new file mode 100644 index 0000000..d264361 --- /dev/null +++ b/pkgs/xformers/components/__init__.py @@ -0,0 +1,77 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import fields +from pathlib import Path +from typing import Any, Dict, Union + +from xformers.utils import import_all_modules + +from .activations import Activation, build_activation # noqa +from .attention import Attention, build_attention # noqa +from .input_projection import InputProjection, InputProjectionConfig # noqa +from .multi_head_dispatch import MultiHeadDispatch # noqa +from .multi_head_dispatch import MultiHeadDispatchConfig +from .patch_embedding import PatchEmbeddingConfig # noqa +from .patch_embedding import build_patch_embedding # noqa +from .residual import NormalizationType # noqa +from .residual import PostNorm # noqa +from .residual import PreNorm # noqa +from .residual import RequiresWrappedInputs # noqa +from .residual import Residual # noqa +from .residual import ResidualNormStyle # noqa + +# automatically import any Python files in the directory +import_all_modules(str(Path(__file__).parent), "xformers.components") + + +def build_multi_head_attention( + multi_head_config: Union[MultiHeadDispatchConfig, Dict[str, Any]], +): + """Builds a multihead attention from a config. + + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. For instance, a config `{"name": "my_attention", + "foo": "bar"}` will find a class that was registered as "my_attention" + (see :func:`register_attention`) and call .from_config on it.""" + + if not isinstance(multi_head_config, MultiHeadDispatchConfig): + # Extract the required fields + field_names = list(map(lambda x: x.name, fields(MultiHeadDispatchConfig))) + + # The missing fields get Noned + for k in field_names: + if k not in multi_head_config.keys(): + multi_head_config[k] = None + + # Could be that the attention needs to be instantiated + if not isinstance(multi_head_config["attention"], Attention): + # Convenience: fill in possible missing fields + if "num_heads" not in multi_head_config["attention"]: + multi_head_config["attention"]["num_heads"] = multi_head_config[ + "num_heads" + ] + + if "dim_model" not in multi_head_config["attention"]: + multi_head_config["attention"]["dim_model"] = multi_head_config[ + "dim_model" + ] + + if ( + "dim_features" not in multi_head_config["attention"] + or multi_head_config["attention"]["dim_features"] is None + ): + multi_head_config["attention"]["dim_features"] = ( + multi_head_config["dim_model"] // multi_head_config["num_heads"] + ) + + multi_head_config["attention"] = build_attention( + multi_head_config["attention"] + ) + + multi_head_config = MultiHeadDispatchConfig(**multi_head_config) + + return MultiHeadDispatch.from_config(multi_head_config) diff --git a/pkgs/xformers/components/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/components/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..2d5cb35 Binary files /dev/null and b/pkgs/xformers/components/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/activations.cpython-310.pyc b/pkgs/xformers/components/__pycache__/activations.cpython-310.pyc new file mode 100644 index 
0000000..e7d995d Binary files /dev/null and b/pkgs/xformers/components/__pycache__/activations.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/input_projection.cpython-310.pyc b/pkgs/xformers/components/__pycache__/input_projection.cpython-310.pyc new file mode 100644 index 0000000..0a5780b Binary files /dev/null and b/pkgs/xformers/components/__pycache__/input_projection.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/multi_head_dispatch.cpython-310.pyc b/pkgs/xformers/components/__pycache__/multi_head_dispatch.cpython-310.pyc new file mode 100644 index 0000000..9738eaf Binary files /dev/null and b/pkgs/xformers/components/__pycache__/multi_head_dispatch.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/patch_embedding.cpython-310.pyc b/pkgs/xformers/components/__pycache__/patch_embedding.cpython-310.pyc new file mode 100644 index 0000000..c60a5b0 Binary files /dev/null and b/pkgs/xformers/components/__pycache__/patch_embedding.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/residual.cpython-310.pyc b/pkgs/xformers/components/__pycache__/residual.cpython-310.pyc new file mode 100644 index 0000000..28460df Binary files /dev/null and b/pkgs/xformers/components/__pycache__/residual.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/reversible.cpython-310.pyc b/pkgs/xformers/components/__pycache__/reversible.cpython-310.pyc new file mode 100644 index 0000000..1c53c1c Binary files /dev/null and b/pkgs/xformers/components/__pycache__/reversible.cpython-310.pyc differ diff --git a/pkgs/xformers/components/__pycache__/simplicial_embedding.cpython-310.pyc b/pkgs/xformers/components/__pycache__/simplicial_embedding.cpython-310.pyc new file mode 100644 index 0000000..b29bcc4 Binary files /dev/null and b/pkgs/xformers/components/__pycache__/simplicial_embedding.cpython-310.pyc differ diff --git a/pkgs/xformers/components/activations.py b/pkgs/xformers/components/activations.py new file mode 100644 index 0000000..69c6edb --- /dev/null +++ b/pkgs/xformers/components/activations.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
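+# Activation functions used across xFormers components, together with the
+# `Activation` enum and the `build_activation` factory mapping enum values to
+# the corresponding nn.Module.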
+ + +from enum import Enum +from typing import Optional + +import torch +from torch import nn + + +class Activation(str, Enum): + SquaredReLU = "squared_relu" + GeLU = "gelu" + LeakyReLU = "leaky_relu" + ReLU = "relu" + SmeLU = "smelu" + StarReLU = "star_relu" + + +# For unit testing / parity comparisons, probably not the fastest way +class SquaredReLU(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_ = torch.nn.functional.relu(x) + return x_ * x_ + + +class StarReLU(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_ = torch.nn.functional.relu(x) + return 0.8944 * x_ * x_ - 0.4472 + + +class SmeLU(nn.Module): + def __init__(self, beta: float = 2.0) -> None: + super().__init__() + self.beta = beta + + def forward(self, x: torch.Tensor) -> torch.Tensor: + relu = torch.where( + x >= self.beta, + x, + torch.tensor([0.0], device=x.device, dtype=x.dtype), + ) + return torch.where( + torch.abs(x) <= self.beta, + ((x + self.beta) ** 2).type_as(x) / (4.0 * self.beta), + relu, + ) + + +def build_activation(activation: Optional[Activation]): + if not activation: + return nn.Identity() + + return { + Activation.ReLU: nn.ReLU, + Activation.GeLU: nn.GELU, + Activation.LeakyReLU: nn.LeakyReLU, + Activation.SquaredReLU: SquaredReLU, + Activation.StarReLU: StarReLU, + Activation.SmeLU: SmeLU, + }[activation]() diff --git a/pkgs/xformers/components/attention/__init__.py b/pkgs/xformers/components/attention/__init__.py new file mode 100644 index 0000000..a542689 --- /dev/null +++ b/pkgs/xformers/components/attention/__init__.py @@ -0,0 +1,133 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from pathlib import Path +from typing import Any, Callable, Dict, Set, Union + +import torch + +from xformers.utils import ( + generate_matching_config, + get_registry_decorator, + import_all_modules, +) + +from ._sputnik_sparse import SparseCS +from .attention_mask import AttentionMask +from .base import Attention, AttentionConfig # noqa + +logger = logging.getLogger("xformers") + + +# CREDITS: Classy Vision registry mechanism + +ATTENTION_REGISTRY: Dict[str, Any] = {} +ATTENTION_CLASS_NAMES: Set[str] = set() + +# Arbitrary threshold for now, +# in between dense and sparse matrix algorithms for the attention mechanism +_DENSITY_THRESHOLD = 0.30 # noqa # from the sputnik paper, vs. +_USE_SPUTNIK = True + + +def build_attention(config: Union[Dict[str, Any], AttentionConfig]): + """Builds an attention from a config. + + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. For instance, a config `{"name": "my_attention", + "foo": "bar"}` will find a class that was registered as "my_attention" + (see :func:`register_attention`) and call .from_config on it.""" + + if not isinstance(config, AttentionConfig): + try: + config_instance = generate_matching_config( + config, ATTENTION_REGISTRY[config["name"]].config + ) + except KeyError as e: + name = config["name"] + logger.warning(f"{name} not available among {ATTENTION_REGISTRY.keys()}") + raise e + else: + config_instance = config + + return ATTENTION_REGISTRY[config_instance.name].constructor.from_config( + config_instance + ) + + +"""Registers an Attention subclass. 
+ + This decorator allows xFormers to instantiate a subclass of Attention + from a configuration file, even if the class itself is not part of the + xFormers library. To use it, apply this decorator to an Attention + subclass, like this: + + .. code-block:: python + + @dataclass + class MyConfig: + ... + + @register_attention('my_attention', MyConfig) + class MyAttention(Attention): + ... + + To instantiate an attention from a configuration file, see :func:`build_attention`.""" +register_attention: Callable[[str, Any], Callable[[Any], Any]] = get_registry_decorator( + ATTENTION_REGISTRY, ATTENTION_CLASS_NAMES, Attention, AttentionConfig +) + + +def maybe_sparsify(matrix) -> Any: + # Sparsify if that makes sense + if torch.count_nonzero(matrix).item() / matrix.numel() > _DENSITY_THRESHOLD: + # If not sparse, then AttentionMask is the reference type + return AttentionMask.from_bool(matrix) + + return sparsify(matrix) + + +def sparsify(matrix): + if _USE_SPUTNIK: + return SparseCS(matrix) + return matrix.to_sparse() + + +from .favor import FavorAttention # noqa +from .global_tokens import GlobalAttention # noqa +from .linformer import LinformerAttention # noqa +from .local import LocalAttention # noqa +from .nystrom import NystromAttention # noqa +from .ortho import OrthoFormerAttention # noqa +from .random import RandomAttention # noqa +from .scaled_dot_product import ScaledDotProduct # noqa + +__all__ = [ + "ScaledDotProduct", + "LocalAttention", + "LinformerAttention", + "NystromAttention", + "RandomAttention", + "OrthoFormerAttention", + "GlobalAttention", + "FavorAttention", + "Attention", + "AttentionMask", + "build_attention", + "register_attention", +] + +# Optionally expose the BlockSparse attention +try: + from .blocksparse import BlockSparseAttention # noqa + + __all__ += ["BlockSparseAttention"] +except ImportError: + pass + + +# automatically import any Python files in the directory +import_all_modules(str(Path(__file__).parent), "xformers.components.attention") diff --git a/pkgs/xformers/components/attention/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..00358b2 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-310.pyc new file mode 100644 index 0000000..5dfaa52 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/_sputnik_sparse.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/attention_mask.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/attention_mask.cpython-310.pyc new file mode 100644 index 0000000..c72562a Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/attention_mask.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/attention_patterns.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/attention_patterns.cpython-310.pyc new file mode 100644 index 0000000..16472d0 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/attention_patterns.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/base.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..f81db80 Binary files /dev/null and 
b/pkgs/xformers/components/attention/__pycache__/base.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/blocksparse.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/blocksparse.cpython-310.pyc new file mode 100644 index 0000000..d2464fe Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/blocksparse.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/compositional.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/compositional.cpython-310.pyc new file mode 100644 index 0000000..ce04f43 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/compositional.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/core.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/core.cpython-310.pyc new file mode 100644 index 0000000..ef89956 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/core.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/favor.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/favor.cpython-310.pyc new file mode 100644 index 0000000..1809f89 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/favor.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/fourier_mix.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/fourier_mix.cpython-310.pyc new file mode 100644 index 0000000..71ecc38 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/fourier_mix.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/global_tokens.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/global_tokens.cpython-310.pyc new file mode 100644 index 0000000..3c3d1f4 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/global_tokens.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/lambda_layer.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/lambda_layer.cpython-310.pyc new file mode 100644 index 0000000..549babf Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/lambda_layer.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/linformer.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/linformer.cpython-310.pyc new file mode 100644 index 0000000..9c26866 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/linformer.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/local.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/local.cpython-310.pyc new file mode 100644 index 0000000..72b8fba Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/local.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/nystrom.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/nystrom.cpython-310.pyc new file mode 100644 index 0000000..5af5a2d Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/nystrom.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/ortho.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/ortho.cpython-310.pyc new file mode 100644 index 0000000..efe2c4f Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/ortho.cpython-310.pyc differ diff --git 
a/pkgs/xformers/components/attention/__pycache__/pooling.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/pooling.cpython-310.pyc new file mode 100644 index 0000000..ae05381 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/pooling.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/random.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/random.cpython-310.pyc new file mode 100644 index 0000000..5e7e1b3 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/random.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/scaled_dot_product.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/scaled_dot_product.cpython-310.pyc new file mode 100644 index 0000000..a042601 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/scaled_dot_product.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/sparsity_config.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/sparsity_config.cpython-310.pyc new file mode 100644 index 0000000..a12fbdb Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/sparsity_config.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/utils.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..a1edbb5 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/utils.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/__pycache__/visual.cpython-310.pyc b/pkgs/xformers/components/attention/__pycache__/visual.cpython-310.pyc new file mode 100644 index 0000000..e547ba7 Binary files /dev/null and b/pkgs/xformers/components/attention/__pycache__/visual.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/_sputnik_sparse.py b/pkgs/xformers/components/attention/_sputnik_sparse.py new file mode 100644 index 0000000..d6b92d7 --- /dev/null +++ b/pkgs/xformers/components/attention/_sputnik_sparse.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
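+# Usage sketch (illustrative, not part of the upstream file): SparseCS wraps a dense
+# [batch, seq, seq] (or [seq, seq]) mask into a sputnik-style CSR tensor, e.g.
+#     mask = (torch.rand(1, 64, 64) > 0.9).float()
+#     sp = SparseCS(mask, device=torch.device("cuda"))
+#     dense_again = sp.to_dense()
+# This is the type produced by sparsify()/maybe_sparsify() in the attention __init__.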
+ + +import torch + +from xformers.ops import masked_matmul +from xformers.sparse import SparseCSRTensor + +# TODO: this is here for BC +from xformers.sparse.utils import _csr_to_coo, _dense_to_sparse # noqa: F401 + + +class SparseCS: + def __init__(self, matrix, device=None): + if device is None: + device = torch.device("cpu") + if matrix.ndim == 2: + matrix = matrix[None] + assert matrix.ndim == 3 + self._mat = SparseCSRTensor.from_dense(matrix).to(device) + + @property + def device(self): + return self._mat.device + + @property + def ndim(self): + return self._mat.ndim + + @property + def dtype(self): + return self._mat.dtype + + @property + def is_sparse(self): + return True + + @property + def shape(self): + return self._mat.shape[1:] + + @property + def values(self): + return self._mat.values() + + @property + def row_indices(self): + return self._mat._csr_row_indices + + @property + def column_indices(self): + return self._mat._csr_column_indices + + @property + def row_offsets(self): + return self._mat._csr_row_offsets + + @property + def _transp_info(self): + return self._mat._csr_transp_info + + @classmethod + def wrap( + cls, shape, values, row_indices, row_offsets, column_indices, _transp_info + ): + matrix = cls.__new__(cls) + _shape = (values.shape[0],) + shape + csr_matrix = SparseCSRTensor._wrap( + _shape, values, row_indices, row_offsets, column_indices, _transp_info + ) + matrix._mat = csr_matrix + return matrix + + @classmethod + def _wrap(cls, csr_matrix): + assert isinstance(csr_matrix, SparseCSRTensor) + matrix = cls.__new__(cls) + matrix._mat = csr_matrix + return matrix + + def __mul__(self, other): + assert isinstance(other, (int, float)) + return type(self)._wrap(self._mat * other) + + def __add__(self, other): + assert isinstance(other, type(self)) + return type(self)._wrap(self._mat + other._mat) + + def matmul_with_mask(self, a, b): + return type(self)._wrap(masked_matmul(a, b, self._mat)) + + def softmax(self): + out = torch.nn.functional.softmax(self._mat, -1) + return type(self)._wrap(out) + + def spmm(self, b): + out = torch.bmm(self._mat, b) + return out + + def transpose(self): + out = torch.transpose(self._mat, -2, -1) + return type(self)._wrap(out) + + def to(self, device): + assert isinstance(device, torch.device) + out = self._mat.to(device) + return type(self)._wrap(out) + + def to_dense(self): + return self._mat.to_dense() + + def logical_and(self, other: torch.Tensor): + assert not isinstance(other, SparseCS) + out = torch.logical_and(self._mat, other) + return type(self)._wrap(out) + + def __and__(self, other): + return self.logical_and(other) diff --git a/pkgs/xformers/components/attention/attention_mask.py b/pkgs/xformers/components/attention/attention_mask.py new file mode 100644 index 0000000..7006e97 --- /dev/null +++ b/pkgs/xformers/components/attention/attention_mask.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Type, TypeVar + +import torch + +Self = TypeVar("Self", bound="AttentionMask") + + +class AttentionMask: + """ + Holds an attention mask, along with a couple of helpers and attributes. + + .. note: this is an additive mask, meaning that coefficients which should be computed hold the '0.' value, + and coefficients which should be skipped hold the '-inf' value. 
Any other value is possible if the purpose + is to bias the attention computation for instance + + .. note: the attention mask dimensions are expected to be `[batch, to_sequence, from_sequence]`, + `[to_sequence, from_sequence]`, or anything broadcastable in between + """ + + def __init__(self, additive_mask: torch.Tensor, is_causal: bool = False): + assert additive_mask.is_floating_point(), additive_mask.dtype + assert not additive_mask.requires_grad + + if additive_mask.ndim == 2: + additive_mask = additive_mask.unsqueeze(0) + + self.values = additive_mask + self.is_causal = is_causal + self.seq_len = additive_mask.shape[1] + self.to_seq_len = additive_mask.shape[0] + + def to_bool(self) -> torch.Tensor: + """ + .. warning: we assume here that True implies that the value should be computed + """ + return self.values != float("-inf") + + @classmethod + def from_bool(cls: Type[Self], x: torch.Tensor) -> Self: + """ + Create an AttentionMask given a boolean pattern. + .. warning: we assume here that True implies that the value should be computed + """ + assert x.dtype == torch.bool + + additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device) + additive_mask.masked_fill_(x, 0.0) + additive_mask.masked_fill_(~x, float("-inf")) + + return cls(additive_mask) + + @classmethod + def from_multiplicative(cls: Type[Self], x: torch.Tensor) -> Self: + """ + Create an AttentionMask given a multiplicative attention mask. + """ + assert not x.dtype == torch.bool + + additive_mask = torch.empty_like(x, dtype=torch.float, device=x.device) + x = x.bool() + + additive_mask.masked_fill_(x, 0.0) + additive_mask.masked_fill_(~x, float("-inf")) + + return cls(additive_mask) + + @classmethod + def make_causal( + cls: Type[Self], + seq_len: int, + to_seq_len: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> Self: + if not to_seq_len: + to_seq_len = seq_len + + additive_mask = torch.triu( + torch.ones(seq_len, to_seq_len, device=device, dtype=dtype) * float("-inf"), + diagonal=1, + ) + return cls(additive_mask=additive_mask, is_causal=True) + + def make_crop( + self, seq_len: int, to_seq_len: Optional[int] = None + ) -> "AttentionMask": + """ + Return a cropped attention mask, whose underlying tensor is a view of this one + """ + + if not to_seq_len: + to_seq_len = seq_len + + return AttentionMask( + self.values[:, :seq_len, :to_seq_len], is_causal=self.is_causal + ) + + def __repr__(self): + return f"AttentionMask - causal {self.is_causal} - mask " + str(self.values) + + @property + def device(self): + return self.values.device + + @property + def is_sparse(self): + return False + + @property + def ndim(self): + return len(self.values.shape) + + @property + def dtype(self): + return self.values.dtype + + @property + def shape(self): + return self.values.shape + + def __add__(self, other): + return AttentionMask(self.values + other.values, is_causal=False) + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ) -> "AttentionMask": + assert device is None or isinstance(device, torch.device) + assert dtype is None or isinstance(dtype, torch.dtype) + assert device is not None or dtype is not None + + # Noop if we don't need to create another instance + if ((device and device == self.device) or not device) and ( + (dtype and dtype == self.dtype) or not dtype + ): + return self + + return AttentionMask(self.values.to(device=device, dtype=dtype), self.is_causal) diff --git 
a/pkgs/xformers/components/attention/attention_patterns.py b/pkgs/xformers/components/attention/attention_patterns.py new file mode 100644 index 0000000..9c817de --- /dev/null +++ b/pkgs/xformers/components/attention/attention_patterns.py @@ -0,0 +1,295 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from typing import List + +import numpy as np +import torch + +from xformers.components.attention.sparsity_config import ( + BigBirdSparsityConfig, + BSLongformerSparsityConfig, + FixedSparsityConfig, + VariableSparsityConfig, +) + + +# generic nd cases +def _generate_nd_grid(*sizes): + coords = [torch.arange(s) for s in sizes] + return torch.meshgrid(*coords) + + +def local_nd_distance(*sizes, p=2.0, weights=None): + if weights is None: + weights = (1,) * len(sizes) + assert len(sizes) == len(weights) + grid = _generate_nd_grid(*sizes) + grid = [i.flatten() * w for i, w in zip(grid, weights)] + grid = torch.stack(grid, dim=1).float() + d = torch.cdist(grid, grid, p=p) + return d + + +def local_nd_gaussian_distribution(*sizes, sigma=1): + d = local_nd_distance(*sizes, p=2.0) ** 2 + d = torch.exp(-0.5 * sigma ** (-2.0) * d) + return d + + +def local_nd_pattern(*sizes, distance, p=2.0): + d = local_nd_distance(*sizes, p=p) + return d < distance + + +def axial_nd_pattern(*sizes): + # axial is a special case with p=0 and distance=2 + d = local_nd_distance(*sizes, p=0) + return d < 2 + + +def random_pattern_from_probability_matrix(dist_matrix, nnz): + att = torch.zeros_like(dist_matrix, dtype=torch.bool) + # PyTorch multinomial wrongly doesn't support sampling when number of categories + # is > 2^24, arguing that it's because it's the max representable consecutive element + # in fp32 and that the kernels use float32. This is actually not true, and the kernels + # should work fine if double tensor is passed on CPU. This is a bug that was introduced + # in https://github.com/pytorch/pytorch/commit/bf04c2ca2f591d98ce57816f0ef0cd20a21bbf66 + # when unifying the checks between CPU and CUDA. 
For now, just fall-back to numpy + if dist_matrix.numel() > 2**24: + dist_matrix = dist_matrix.double() + dist_matrix /= dist_matrix.sum() + idxs = np.random.choice( + dist_matrix.numel(), nnz, p=dist_matrix.flatten(), replace=False + ) + idxs = torch.as_tensor(idxs) + else: + idxs = torch.multinomial(dist_matrix.flatten(), nnz, replacement=False) + att.view(-1)[idxs] = True + return att + + +def global_token_pattern(attention_query_mask: torch.Tensor) -> torch.Tensor: + assert attention_query_mask.ndim == 1 + assert attention_query_mask.dtype == torch.bool + attention_query_mask = attention_query_mask[None, :] + mask = attention_query_mask | attention_query_mask.transpose(1, 0) + return mask + + +def random_pattern(attn_size: int, sparsity: float) -> torch.Tensor: + assert 0 < sparsity < 1 + mask = torch.rand(attn_size, attn_size) > sparsity + return mask + + +# 1d-specific cases +def local_1d_pattern(attn_size: int, window_size: int) -> torch.Tensor: + assert ( + window_size % 2 == 1 + ), "The window size is assumed to be odd (counts self-attention + 2 wings)" + h_win_size = window_size // 2 + 1 + return local_nd_pattern(attn_size, distance=h_win_size, p=1.0) + + +def causal_1d_pattern(attn_size: int) -> torch.Tensor: + mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool)) + return mask + + +# 2d-specific cases +def horizontal_axial_2d_distance(H, W, p=2.0): + d = local_nd_distance(H, W, p=p, weights=(1, 0)) + return d + + +def vertical_axial_2d_distance(H, W, p=2.0): + d = local_nd_distance(H, W, p=p, weights=(0, 1)) + return d + + +def local_2d_distance(H, W, p=2.0): + return local_nd_distance(H, W, p=p) + + +def local_2d_gausian_distribution(H, W, sigma=1): + return local_nd_gaussian_distribution(H, W, sigma=sigma) + + +def local_2d_pattern(H, W, distance, p=2.0): + return local_nd_pattern(H, W, distance=distance, p=p) + + +def axial_2d_pattern(H, W): + return axial_nd_pattern(H, W) + + +def swin_attention_pattern(H, W, window_size, shift_size=0): + assert H % window_size == 0 + assert W % window_size == 0 + assert 0 <= shift_size < window_size, "shift_size must in 0-window_size" + + # input grid + i, j = _generate_nd_grid(H, W) + i, j = i + 0.5, j + 0.5 + + # anchors grid + # if shift is present, add extra element to the grid + # to account for the uneven partitioning + extra = int(shift_size % window_size != 0) + grid_h = H // window_size + extra + grid_w = W // window_size + extra + + ii, jj = _generate_nd_grid(grid_h, grid_w) + # convert shift to be compatible with the paper representation + s = (-shift_size) % window_size + offset = window_size / 2 - s + ii = ii * window_size + offset + jj = jj * window_size + offset + + input_coords = torch.stack([i.flatten(), j.flatten()], 1).float() + anchors_coords = torch.stack([ii.flatten(), jj.flatten()], 1).float() + + anchor_id = torch.cdist(input_coords, anchors_coords, p=2).argmin(1) + mask = anchor_id[:, None] == anchor_id[None, :] + return mask + + +def dilated_2d_pattern(H, W, k=2): + """ + Returns a 2d pattern that samples 1 every k elements in the attention mask. + Can be seen as a form of downsampling, where every pixel attends to a downsampled + version of the input. 
+ """ + d_h = local_nd_distance(H, W, p=1, weights=(1, 0)) + d_w = local_nd_distance(H, W, p=1, weights=(0, 1)) + d = (d_h.floor() % k == 0) & (d_w.floor() % k == 0) + return d + + +# Block sparse utils +def block_sparsify_tensor(x, mask, block_size): + """ + Block sparsify a tensor, given a mask and block size + """ + ret = torch.empty( + (x.size(0), mask.sum(), block_size, block_size), dtype=x.dtype, device=x.device + ) + + for idx, (h, i, j) in enumerate(zip(*mask.nonzero(as_tuple=True))): + ret[:, idx, :, :] = x[ + :, + h, + i * block_size : (i + 1) * block_size, + j * block_size : (j + 1) * block_size, + ] + return ret + + +def pattern_to_layout(mask: torch.Tensor, block_size: int) -> torch.Tensor: + r""" + Given a mask pattern and blocksize, return the corresponding layout + which makes sure that all the positives in the mask are covered + """ + assert mask.ndim >= 2, "We're expecting [Heads, Seq, Seq] or [Seq, Seq]" + _should_squeeze = False + + if mask.ndim == 2: + mask = mask.unsqueeze(0) + _should_squeeze = True + + assert ( + mask.shape[1] % block_size == 0 and mask.shape[2] % block_size == 0 + ), "We're only handling masks divisible by block_size" + + # Now mark the mask + layout = torch.nn.functional.max_pool2d( + mask.to(torch.float), kernel_size=block_size, stride=block_size + ) + layout = layout.to(torch.long) + + if _should_squeeze: + layout.squeeze_(0) + + return layout + + +def alibi_pattern(threshold: float, mask_shape: torch.Size) -> torch.Tensor: + r""" + Use the additive bias computation from ALiBi_ to generate a mask. + Note that this mask can in turn be used to generate a blocksparse attention computation layout + + .. note: mask_shape is expected to hold the [heads, seq, seq] dimensions + + .. _ALiBi: https://arxiv.org/pdf/2108.12409.pdf + """ + + # CREDITS: code snippet from Ofir Press, one of the authors + + def get_slopes(n: int): + def get_slopes_power_of_2(n: int) -> List[float]: + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + # In the paper, we only train models that have 2^a heads for some a. This function has + # some good properties that only occur when the input is a power of 2. To maintain that even + # when the number of heads is not a power of 2, we use this workaround. + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + maxpos = mask_shape[1] + attn_heads = mask_shape[0] + slopes = torch.Tensor(get_slopes(attn_heads)) + + # In the next line, the part after the * is what constructs the diagonal matrix + # (right matrix in Figure 3 in the paper). + # If you run it you'll see that it doesn't exactly print out the same matrix as we have in Figure 3, + # but one where all rows are identical. + # This works because the softmax operation is invariant to translation, + # and our bias functions are always linear. 
+ alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(maxpos).unsqueeze( + 0 + ).unsqueeze(0).expand(attn_heads, -1, -1) + alibi = alibi.view(attn_heads, 1, maxpos) + + # Now threshold arbitrarily, report the mask + return alibi < threshold + + +def quick_fixed_layout(num_heads: int, block_size: int, seq_len: int): + config = FixedSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def quick_variable_layout(num_heads: int, block_size: int, seq_len: int): + config = VariableSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def quick_bigbird_layout(num_heads: int, block_size: int, seq_len: int): + config = BigBirdSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def quick_bslongformer_layout(num_heads: int, block_size: int, seq_len: int): + config = BSLongformerSparsityConfig(num_heads=num_heads, block_size=block_size) + return config.make_layout(seq_len) + + +def layout_to_pattern(layout: torch.Tensor, block_size: int): + r""" + create a pattern of shape [heads, seq, seq] out of a blocksparse + layout of shape [heads, seq/block_size, seq/block_size] + """ + return torch.kron(layout, torch.ones(block_size, block_size)) diff --git a/pkgs/xformers/components/attention/base.py b/pkgs/xformers/components/attention/base.py new file mode 100644 index 0000000..376cd54 --- /dev/null +++ b/pkgs/xformers/components/attention/base.py @@ -0,0 +1,93 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from abc import ABCMeta, abstractmethod +from dataclasses import asdict, dataclass +from typing import Optional, Type, TypeVar + +import torch +import torch.nn as nn + +from xformers.components.attention import AttentionMask + + +@dataclass +class AttentionConfig: + """Parameters required for all Attentions. + Can accept and store extra parameters. 
+ """ + + name: str # the registered name for this attention mechanism + dropout: float # dropout probability + + +Self = TypeVar("Self", bound="Attention") + + +# Define the common interface, every attention block needs to derive from it +class Attention(nn.Module, metaclass=ABCMeta): + r"""The base Attention mechanism, which is typically a sub-part of the multi-head attention""" + + _causal_mask: Optional[AttentionMask] = None + + @abstractmethod + def __init__(self, dropout: Optional[float] = None, *args, **kwargs): + super().__init__() + + # Requires the inputs to be projected + self.requires_input_projection = True + + # Whether the head dimension needs to be present (if not it can be folded into the batch dimension) + self.requires_head_dimension = False + + # key padding mask and attention mask must be passed in as separate arguments instead of a merged attention mask + self.requires_separate_masks = False + + # Requires that K and Q have the same sequence length + self.requires_same_k_q_dimensions = False + + # Whether the attention owns the single head/multihead mechanism + # so that the MHA wrapper should skip it + self.requires_skip_multi_head = False + + # This attention requires a context length which is squared, often due to 2D pooling + self.requires_squared_context = False + + # Whether this attention mechanism supports attention masks + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + @classmethod + def from_config(cls: Type[Self], config: AttentionConfig) -> Self: + # Generate the class inputs from the config + fields = asdict(config) + + # Skip all Nones so that default values are used + fields = {k: v for k, v in fields.items() if v is not None} + + return cls(**fields) + + @abstractmethod + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def _maybe_pad_sequence(x: torch.Tensor, mask: torch.Tensor): + """ + If the sequence is shorter than the mask, return a padded view + """ + if x.shape[-2] != mask.shape[-1]: + assert x.shape[-2] < mask.shape[-1], ( + "Sequence is bigger than the provided mask, cannot infer what to do with it." + " Please update your attention mask" + ) + + pad_size = (0, 0, 0, mask.shape[-1] - x.shape[-2], 0, 0) + return torch.nn.functional.pad(x, pad_size, mode="constant", value=0.0) + + return x diff --git a/pkgs/xformers/components/attention/blocksparse.py b/pkgs/xformers/components/attention/blocksparse.py new file mode 100644 index 0000000..f099628 --- /dev/null +++ b/pkgs/xformers/components/attention/blocksparse.py @@ -0,0 +1,190 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import logging +import math +from dataclasses import dataclass + +import torch + +from xformers import _is_triton_available +from xformers.components.attention import Attention, AttentionConfig, register_attention + +logger = logging.getLogger("xformers") + + +_is_blocksparse_available = _is_triton_available() + + +if _is_blocksparse_available: + from triton.ops.blocksparse import matmul as blocksparse_matmul # type: ignore + from triton.ops.blocksparse import softmax as blocksparse_softmax # type: ignore + + from xformers.triton.utils import gpu_capabilities_older_than_70 + + # Blocksparse requires Tensor cores + if gpu_capabilities_older_than_70(): + logger.warning( + "Blocksparse is not available: the current GPU does not expose Tensor cores" + ) + _is_blocksparse_available = False + + +if _is_blocksparse_available: + + @dataclass + class BlockSparseAttentionConfig(AttentionConfig): + layout: torch.Tensor # The dimensions of the random features + block_size: int + dropout: float + num_heads: int + + @register_attention("blocksparse", BlockSparseAttentionConfig) + class BlockSparseAttention(Attention): + r""" + Thin wrap over the Triton blocksparse computations. The sparsity pattern is determined through the layout. + + .. warning: the layout is assumed to have the dimensions [heads, seq, seq]. + If some dimensions are missing, we assume that the same layout is to be used across heads. + + .. warning: for now, the sequence (context) length has to be a power of two. This constraint could + be relaxed in the future. + + .. warning: the block size has to be picked from [16, 32, 64]. Some speed is gained from bigger blocks. + It is of course possible to reproduce coarser patterns given these primitives, as the user sees fit. + + """ + + def __init__( + self, + layout: torch.Tensor, + block_size: int = 16, + dropout: float = 0.0, + num_heads: int = 1, # optional, used to adapt the layout if in need + causal: bool = False, + *args, + **kwargs, + ): + if layout.dim() == 2: + logger.warning( + "The layout passed is lacking a head dimension and a batch dimension" + ) + logger.warning( + "Now assuming that the same layout is to be used across all heads" + ) + layout = layout.unsqueeze(0).expand(num_heads, -1, -1) + logger.warning(f"New layout dimensions: {layout.shape}") + + assert block_size in ( + 16, + 32, + 64, + 128, + ), "Only block sizes in [16, 32, 64, 128] are supported" + + super().__init__() + + self.causal = causal + + self.attn_drop = torch.nn.Dropout(dropout, inplace=False) + + # Pure blocksparse data + self.layout = layout + self.block_size = block_size + + # make sure that the head dimension is not folded down with the batch + self.requires_head_dimension = True + + # key padding mask and attention mask must be passed in separately + self.requires_same_k_q_dimensions = True + + # The underlying triton op does not support per element attention mask + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + def create_triton_kernels(self, device): + # blocksparse operators + self.sparse_dot_sdd = blocksparse_matmul( + self.layout, + self.block_size, + "sdd", + trans_a=False, + trans_b=True, + device=device, + ) + + self.sparse_dot_dsd = blocksparse_matmul( + self.layout, + self.block_size, + "dsd", + trans_a=False, + trans_b=False, + device=device, + ) + + self.sparse_softmax = blocksparse_softmax( + self.layout, + self.block_size, + device=device, + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float = 1.0, 
+ *args, + **kwargs, + ) -> torch.Tensor: + assert ( + "att_mask" not in kwargs.keys() and "att_mask" not in args + ), "This attention does not support an attention mask, but you can specify causality." + + r""" + A thin wrap around the Triton blockparse attention operation + + .. note: Per element attention mask is not supported, but you can specify causality + """ + + # Delayed triton init, to make sure that we get the right device + # Infer device from query + if not hasattr(self, "sparse_dot_sdd"): + self.create_triton_kernels(q.device) + + assert ( + q.shape[-2] == k.shape[-2] + ), "Blocksparse requires the same dimensions for K and Q for now" + + assert ( + q.shape[-2] == self.layout.shape[-2] * self.block_size + ), "Actual sequence size and layout are inconsistent" + assert ( + k.shape[-2] == self.layout.shape[-2] * self.block_size + ), "Actual sequence size and layout are inconsistent" + + assert ( + q.shape[-2] % self.block_size + ) == 0, "Sequence length {} must be a multiple of block size {}".format( + q.shape[-2], self.block_size + ) + + # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S) + # When the computations are block sparse, the matrix types change along the way: + # - (sparse) attention matrix = (dense) Kt * (dense) Q + q = q / math.sqrt(q.size(-1)) + sparse_att_mat = self.sparse_dot_sdd(q, k) + + # - softmax on the sparse attention matrix + sparse_att_mat = self.sparse_softmax( + sparse_att_mat, scale=scale, is_causal=self.causal + ) + + sparse_att_mat = self.attn_drop(sparse_att_mat) + + # - then (dense) attention is (sparse) attention matrix * dense (value) + a = self.sparse_dot_dsd(sparse_att_mat, v) + return a diff --git a/pkgs/xformers/components/attention/compositional.py b/pkgs/xformers/components/attention/compositional.py new file mode 100644 index 0000000..a06053c --- /dev/null +++ b/pkgs/xformers/components/attention/compositional.py @@ -0,0 +1,341 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# Credits: this is heavily inspired by the official implementation, present in +# https://github.com/sarthmit/Compositional-Attention +# Original author: Sarthak Mittal + +# This is a simplified version, for the sake of clarity, and because some features could be exposed later +# via the library directly. 
+# In particular, code paths for TPUs, quantization and gumbel softmax have been removed +# We're also following the same dimension ordering as in the rest of the xformers library +# which is to say [Batch, Sequence, Embedding] wherever possible + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + register_attention, +) +from xformers.components.attention.core import _softmax +from xformers.components.input_projection import InputProjection, InputProjectionConfig + + +def _either_or(a: Optional[int], b: int) -> int: + return a if a is not None else b + + +@dataclass +class CompositionalAttentionConfig(AttentionConfig): + dim_model: int + num_heads: int + dim_attn: Optional[int] = None + num_rules: Optional[int] = None + dim_key: Optional[int] = None + dim_value: Optional[int] = None + dim_selection: Optional[int] = None + dropout: float + qk_rule: bool = False + nonlinear: bool = False + q_compose: bool = False + bias: bool = True + causal: Optional[bool] = False + in_proj_container: Optional[InputProjection] = None + use_separate_proj_weight: Optional[bool] = False + + +@register_attention("compositional", CompositionalAttentionConfig) +class CompositionalAttention(Attention): + """Compositional Attention, as proposed in + "Compositional Attention: Disentangling search and retrieval"_, S. Mittal et al. + + A key insight from this proposal is that the attention mechanism can be conceived as two steps: + a search and a retrieval operation. When queried, the model can search for the most relevant information + (Softmax(QKt)), then retrieve information given the Value. + + Contrary to the original attention proposal, which does not consider interactions in between heads, + the compositional attention will consider all possible interactions and softmax over that dimension, + so that the information retrieved covers the most relevant dimensions. The number of heads and rules to + use is thus typically smaller than for a comparable traditional Transformer, and asking for the same number of heads + may not fit in memory. 
+ + Args: + dim_model: dimension of the incoming latent space + num_heads: number of heads *for the search operation* + dim_attn: dimension (embedding) of the attention + num_rules: number of rules to consider *for the retrieval operation* + dim_selection: dimension of the scoring/selection space for the retrievals + dim_key, dim_value: dimensions of K and V, if different from Q + dropout: attention dropout probability + qk_rule: QK product will drive the retrieval process + nonlinear: use a non linear method to score the retrievals + bias: use bias in the initial projection step + causal: causal computations (attend to the past only) + + _"Compositional Attention: Disentangling search and retrieval": https://arxiv.org/pdf/2110.09419v1.pdf + """ + + def __init__( + self, + dim_model: int, + num_heads: int, + dim_attn: Optional[int] = None, + num_rules: Optional[int] = None, + dim_selection: Optional[int] = None, + dim_key: Optional[int] = None, + dim_value: Optional[int] = None, + dropout=0.0, + qk_rule=False, + nonlinear=False, + q_compose=False, + in_proj_container: Optional[InputProjection] = None, + use_separate_proj_weight: Optional[bool] = False, + bias=True, + causal=False, + *_, + **__, + ): + super().__init__() + + # Define the inherited flags + self.requires_skip_multi_head = ( + True # This attention owns the multi-head mechanism + ) + + # Handle defaults / undefined values + self.dim_model = dim_model + num_rules = _either_or(num_rules, num_heads) + dim_selection = _either_or(dim_selection, dim_model // num_heads) + + # All the initial definition plumbing + dim_attn = _either_or(dim_attn, dim_model) + dim_key = _either_or(dim_key, dim_model) + dim_value = _either_or(dim_value, dim_model) + + self.in_proj_container = ( + in_proj_container + if in_proj_container is not None + else InputProjection( + query_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias), + key_proj_params=InputProjectionConfig(dim_model, dim_key, bias=bias) + if use_separate_proj_weight + else None, + value_proj_params=InputProjectionConfig(dim_model, dim_value, bias=bias) + if use_separate_proj_weight + else None, + ) + ) + + self.num_heads = num_heads + self.num_rules = num_rules + self.qk_rule = qk_rule + self.dim_selection = dim_selection + self.nonlinear = nonlinear + self.q_compose = q_compose + + self.dropout_module = nn.Dropout(dropout) + self.dim_head = dim_model // num_heads + self.value_dim = dim_attn // num_rules + + assert ( + self.value_dim * num_rules == dim_attn + ), "value_dim must be divisible by num_rules" + + self.scaling = self.dim_head**-0.5 + self.scaling_values = self.dim_selection**-0.5 + + self.out_proj = nn.Linear(self.num_heads * self.value_dim, dim_model, bias=bias) + + if self.qk_rule: + self.value_k = nn.Linear(self.value_dim, self.dim_selection, bias=bias) + if self.q_compose: + self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias) + else: + self.value_q = nn.Linear( + dim_model, self.dim_selection * self.num_heads, bias=bias + ) + else: + if self.q_compose: + self.value_q = nn.Linear(self.dim_head, self.dim_selection, bias=bias) + else: + self.value_q = nn.Linear( + dim_model, self.dim_selection * self.num_heads, bias=bias + ) + if self.nonlinear: + self.score_network: nn.Module = nn.Sequential( + nn.Linear( + self.dim_selection + self.value_dim, + self.dim_selection, + bias=bias, + ), + nn.ReLU(), + nn.Linear(self.dim_selection, 1, bias=bias), + ) + else: + self.score_network = nn.Linear( + self.dim_selection + self.value_dim, 1, bias=bias + ) + + 
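+        # Causal (auto-regressive) masking, applied to the search softmax (Softmax(QK^T)) in forward()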
self.causal = causal + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + self._reset_parameters() + + def _reset_parameters(self): + # NOTE: in_proj_container is already initialized + + if self.qk_rule: + nn.init.xavier_uniform_(self.value_k.weight, gain=1 / math.sqrt(2)) + nn.init.xavier_uniform_(self.value_q.weight, gain=1 / math.sqrt(2)) + else: + nn.init.xavier_uniform_(self.value_q.weight) + if self.nonlinear: + nn.init.xavier_uniform_(self.score_network[0].weight) + nn.init.xavier_uniform_(self.score_network[2].weight) + else: + nn.init.xavier_uniform_(self.score_network.weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.out_proj.bias is not None: + nn.init.constant_(self.out_proj.bias, 0.0) + + def forward( + self, + q: Tensor, + k: Tensor, + v: Tensor, + att_mask: Optional[Tensor] = None, + *args, + **kwargs, + ) -> Tensor: + """ + Input shape: Time x Batch x Channel + + Args: + att_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + """ + + B, Sq, E = q.shape + _, Sk, _ = k.shape + + assert E == self.dim_model + + # First define projected query/key/values + # We keep the projected and original tensors in flight, + # depending on the options the original values could be reused + q_unprojected = q + q, k, v = self.in_proj_container(query=q, key=k, value=v) + q *= self.scaling + + # Init causal mask if needed, now that we know the context length + if self.causal and ( + self._causal_mask is None or self._causal_mask.shape[0] != Sk + ): + self._causal_mask = AttentionMask.make_causal(Sq, Sq, device=q.device) + + # Convenience, create an attention mask if a tensor was passed + # This sanitizes different mask types being passed, from now on it's additive + if isinstance(att_mask, torch.Tensor): + # By default we don't know of the causality, and a check would be expensive + att_mask_additive: Optional[AttentionMask] = ( + AttentionMask.from_bool(att_mask) + if att_mask.dtype == torch.bool + else AttentionMask(att_mask, is_causal=False) + ) + else: + att_mask_additive = None + + # Handle the attention and key padding masks + if self._causal_mask is not None: + # Optionally add the causal mask + if att_mask_additive is not None: + att_mask_additive += self._causal_mask + else: + att_mask_additive = self._causal_mask + + # Flatten the heads or the rules + q = ( + q.view(B, Sq, self.num_heads, self.dim_head) + .movedim(2, 1) + .flatten(0, 1) # [B * num_heads, Sq, dim_head] + ) + k = ( + k.view(B, Sk, self.num_heads, self.dim_head).movedim(2, 1).flatten(0, 1) + ) # [B * num_heads, Sk, dim_head] + v = v.view(B, -1, self.num_rules, self.value_dim).movedim(2, 1).flatten(0, 1) + + # Compute the search: Softmax(QKt) + attn_weights = torch.bmm(q, k.transpose(1, 2)) # [B * self.num_heads, Sq, Sk] + + if att_mask_additive is not None: + attn_weights += att_mask_additive.values + + attn_weights = _softmax(attn_weights, causal=self.causal) + + attn_weights = attn_weights.view(B, self.num_heads, Sq, Sk) + attn_probs = self.dropout_module(attn_weights) + + # Now compute the information retrieval + # keep all the heads in flight, we'll score the different possibilities + # - compute all the possible retrievals + v = v.view(B, 1, self.num_rules, Sk, self.value_dim) + attn_probs = attn_probs.unsqueeze(2) + attn = torch.matmul(attn_probs, v).view( + B, self.num_heads, self.num_rules, Sq, 
self.value_dim + ) + + attn = attn.movedim(3, 1) # [B, Sq, H, Rules, Values] + + # - search the most appropriate retrieval among all the values + if self.q_compose: + v_q = self.value_q(q.transpose(0, 1)).view( + B, Sq, self.num_heads, 1, self.dim_selection + ) + else: + v_q = self.value_q(q_unprojected).view( + B, Sq, self.num_heads, 1, self.dim_selection + ) + + if self.qk_rule: + v_q *= self.scaling_values + v_k = ( + self.value_k(attn) + .view(B, Sq, self.num_heads, self.num_rules, self.dim_selection) + .transpose(4, 3) + .contiguous() + ) + v_score = torch.matmul(v_q, v_k).view( + B, Sq, self.num_heads, self.num_rules, 1 + ) + else: + v_q = v_q.expand(-1, -1, -1, self.num_rules, -1) + v_in = torch.cat([attn, v_q], dim=-1) + v_score = self.score_network(v_in).view( + B, Sq, self.num_heads, self.num_rules, 1 + ) + + v_score = F.softmax(v_score, dim=3) + + # - extracted values are the original attention (inc. all the values) weighted by value score + attn = (attn * v_score).sum(dim=3).view(B, Sq, self.num_heads * self.value_dim) + + # Final attention projection, same as other mechanisms + attn = self.out_proj(attn) + + return attn diff --git a/pkgs/xformers/components/attention/core.py b/pkgs/xformers/components/attention/core.py new file mode 100644 index 0000000..0ab54cb --- /dev/null +++ b/pkgs/xformers/components/attention/core.py @@ -0,0 +1,342 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import math +from contextlib import nullcontext +from functools import lru_cache +from typing import Optional, Union + +import torch + +from xformers import _has_cpp_library, _is_triton_available +from xformers.components.attention.attention_mask import AttentionMask + +if _has_cpp_library: + from ._sputnik_sparse import SparseCS + +if _is_triton_available(): + from xformers.triton.softmax import softmax as triton_softmax + from xformers.triton.utils import gpu_capabilities_older_than_70 + +_is_blocksparse_available = ( + _is_triton_available() and not gpu_capabilities_older_than_70() +) + +if _is_blocksparse_available: + from xformers.components.attention.blocksparse import BlockSparseAttention + + +logger = logging.getLogger("xformers") + + +def _create_random_sparsity(matrix, sparsity, divisible_by=4): + assert matrix.ndim == 3 + keep = torch.rand_like(matrix[0], dtype=torch.float32) > sparsity + nonzero = torch.nonzero(keep) + nnz = nonzero.shape[0] + # NOTE: need to make it a multiple of 4 for sputnik + nonzero = nonzero[: (nnz - nnz % divisible_by)] + i, j = nonzero.unbind(1) + output = torch.zeros_like(matrix) + bdim = torch.arange(matrix.shape[0], device=matrix.device)[:, None] + output[bdim, i, j] = matrix[bdim, i, j] + return output + + +def _broadcast_batch(mask, batch_size): + if mask.ndim == 3: + return mask + assert mask.ndim == 2 + + mask = mask.coalesce() + values = mask.values() + indices = mask.indices() + nnz = len(values) + # strategy: repeat the indices and append the extra batch dimension to the indices + indices = indices.repeat(1, batch_size) + # now create the batch indices + batch_indices = torch.arange(batch_size, device=indices.device) + batch_indices = batch_indices[:, None].expand(batch_size, nnz).flatten() + + # put them together + indices = torch.cat([batch_indices[None, :], indices], dim=0) + + # now repeat the values + values = values.repeat(batch_size) + + size = (batch_size,) + 
mask.shape + + return torch.sparse_coo_tensor(indices, values, size) + + +def _matmul_with_mask( + a: torch.Tensor, + b: torch.Tensor, + mask: Optional[Union[torch.Tensor, "SparseCS"]], +) -> torch.Tensor: + if mask is None: + return a @ b + + if _has_cpp_library and mask.dtype == torch.bool: + if isinstance(mask, SparseCS): + return mask.matmul_with_mask(a, b) + if mask.is_sparse: + # perform broadcasting if needed + mask = _broadcast_batch(mask, a.shape[0]) + + # coalesced is not implemented for bool tensors, so need to cast + mask = mask.to(dtype=a.dtype) # type: ignore # mypy is missing the catch above + + return torch.ops.xformers.matmul_with_mask(a, b, mask) + + # Non optimized codepath + if _has_cpp_library: + assert not isinstance(mask, SparseCS) + + att = a @ b + if mask.dtype == torch.bool: + assert not isinstance(mask, SparseCS) + if mask.ndim == 2: + mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1) + # mask is presumed false == ignore + att[~mask] = float("-inf") + else: + # mask is presumed additive + # repeat if batch sizes don't match + if ( + not isinstance(mask, SparseCS) + and mask.ndim == 3 + and mask.shape[0] != att.shape[0] + and (att.shape[0] % mask.shape[0]) == 0 + ): + repeat_factor = att.shape[0] // mask.shape[0] + mask = mask.repeat([repeat_factor, 1, 1]) + logger.info("Mismatched batch dimensions for mask, repeating mask.") + att += mask + return att + + +def _softmax(a: torch.Tensor, causal: bool = False) -> torch.Tensor: + if _has_cpp_library and isinstance(a, SparseCS): + return a.softmax() + + if a.is_sparse: + return torch.sparse.softmax(a, dim=a.ndim - 1) + + if _is_triton_available(): + return triton_softmax(a, mask=None, causal=causal) + else: + return torch.softmax(a, dim=a.ndim - 1) + + +if _has_cpp_library: + + class SparseBMM(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b): + a = a.coalesce() + r = torch.bmm(a, b) + ctx.save_for_backward(a, b) + return r + + @staticmethod + def backward(ctx, grad): + a, b = ctx.saved_tensors + + # gradients w.r.t. a + ga = None + if ctx.needs_input_grad[0]: + ga = torch.ops.xformers.matmul_with_mask(grad, b.transpose(-2, -1), a) + + # gradients w.r.t. 
b + gb = None + if ctx.needs_input_grad[1]: + gb = a.transpose(1, 2).bmm(grad) + + return ga, gb + + def _sparse_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Batch matrix multiply between a sparse matrix and a dense matrix + """ + assert a.ndim == b.ndim == 3 + assert a.shape[0] == b.shape[0] + assert a.shape[2] == b.shape[1] + return SparseBMM.apply(a, b) + + +def bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + if _has_cpp_library: + if isinstance(a, SparseCS): + return a.spmm(b) + if a.is_sparse: + return _sparse_bmm(a, b) + return a @ b + + +def _apply_dropout(att, dropout): + if dropout is None: + return att + + # Dropout chokes on sparse tensors + if _has_cpp_library: + if isinstance(att, SparseCS): + values = att.values.clone() + values = dropout(values) + att = SparseCS.wrap( + att.shape, + values, + att.row_indices, + att.row_offsets, + att.column_indices, + att._transp_info, + ) + elif att.is_sparse: + att = att.coalesce() + values = att.values().clone() # protect against in-place dropout + values = dropout(values) + att = torch.sparse_coo_tensor(att.indices(), values, att.shape) + else: + # Simple dense case + att = dropout(att) + + return att + + # Non optimized vanilla dropout + att = dropout(att) + return att + + +def scaled_query_key_softmax( + q: torch.Tensor, + k: torch.Tensor, + att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]], +) -> torch.Tensor: + # TODO assume we have (N, S, hs) instead of (B, nh, S, hs), with N = B x nh + # this is needed due to limitations in sparse_bmm for now + + # Self-attend: (N, S, hs) x (N, hs, S) -> (N, S, S) + q = q / math.sqrt(k.size(-1)) + + # Matmul with mask + if att_mask is not None and isinstance(att_mask, AttentionMask): + # Additive mask + mask: Optional[Union[SparseCS, torch.Tensor]] = att_mask.values + else: + mask = att_mask + + att = _matmul_with_mask(q, k.transpose(-2, -1), mask) + + # Softmax to get the attention probabilities + is_causal = isinstance(att_mask, AttentionMask) and att_mask.is_causal + att = _softmax(att, causal=is_causal) + return att + + +if _is_blocksparse_available: + # 128 is default maxsize + @lru_cache(maxsize=128) + def _retrieve_blocksparse( + num_heads: int, seq_len: int, block_size: int + ) -> BlockSparseAttention: + # Checks if blocksparse object exists in cache + + blocks = seq_len // block_size + layout_fill = torch.ones((num_heads, blocks, blocks), dtype=torch.long) + return BlockSparseAttention( + layout=layout_fill, block_size=block_size, causal=True + ) + + +def blocksparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout: Optional[torch.nn.Module] = None, + block_size: int = 128, +) -> torch.Tensor: + + orig_dim = q.dim() + seq_len = q.shape[-2] + # Layout head dimension: 1 or batch size (q.shape[0]) + layout_heads = 1 + + # TODO perhaps add functionality to pad qkv if sequence length is not divisible by block size? 
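+    # Note: _retrieve_blocksparse is lru_cache'd, so the Triton kernels for a given
+    # (num_heads, seq_len, block_size) combination are built once and then reused across calls.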
+ assert seq_len % block_size == 0, "Sequence length must be divisible by block size" + + if orig_dim == 3: + # Reshape from (N, S, hs) to (B, nh, S, hs) where N = B x nh, hs = D / nh + # Assuming num_heads = 1, (N, S, hs) to (B, 1, S, hs) + if layout_heads == 1: + q = q.unsqueeze(1) + k = k.unsqueeze(1) + v = v.unsqueeze(1) + else: + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + + blocksparse_attention = _retrieve_blocksparse(layout_heads, seq_len, block_size) + # Dropout is a no-op in evaluation mode + if isinstance(dropout, torch.nn.Dropout): + blocksparse_attention.attn_drop = dropout + else: + blocksparse_attention.attn_drop = torch.nn.Dropout(0.0) + att = blocksparse_attention(q, k, v) + + # Reshape attention (B, nh, S, hs) back to (N, S, hs) + if orig_dim == 3: + return att.flatten(0, 1) + return att + + +def scaled_dot_product_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[AttentionMask, "SparseCS", torch.Tensor]], + dropout: Optional[torch.nn.Module] = None, + block_size: int = 128, +) -> torch.Tensor: + autocast_disabled = ( + _has_cpp_library + and isinstance(att_mask, SparseCS) + or (att_mask is not None and att_mask.is_sparse) + ) + seq_len = q.shape[-2] + + # switch if: + # causal is required but mask is not sparse + # fp16 or under amp context + # sequence length is divisible by block size + # same seq len for K and Q + switch_to_blocksparse = ( + _is_blocksparse_available + and (att_mask is not None and not att_mask.is_sparse) + and (isinstance(att_mask, AttentionMask) and att_mask.is_causal) + and (q.dtype == torch.float16 or torch.is_autocast_enabled()) + and not seq_len % block_size + and q.shape[-2] == k.shape[-2] + ) + + if switch_to_blocksparse: + logger.info("Switching causal attention to Triton blocksparse...") + return blocksparse_attention(q, k, v, dropout, block_size) + + with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore + if autocast_disabled: + q, k, v = q.float(), k.float(), v.float() + + att = scaled_query_key_softmax(q, k, att_mask=att_mask) + + # Optional dropout, could be part of the masking in the future + att = _apply_dropout(att, dropout) + + # Get to the predicted values, for all heads + # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs) + y = bmm(att, v) + return y diff --git a/pkgs/xformers/components/attention/favor.py b/pkgs/xformers/components/attention/favor.py new file mode 100644 index 0000000..57bd776 --- /dev/null +++ b/pkgs/xformers/components/attention/favor.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch.cuda.amp import autocast + +from xformers.components.attention import Attention, AttentionConfig, register_attention +from xformers.components.attention.feature_maps import ( + FeatureMap, + FeatureMapType, + SMHyperbolic, + SMOrf, + SMReg, +) + +logger = logging.getLogger("xformers") + + +@dataclass +class FavorAttentionConfig(AttentionConfig): + causal: Optional[bool] + dim_features: Optional[int] = None # The dimensions of the random features + dim_head: Optional[ + int + ] = None # The embedding dimension of the inputs. 
Only useful to get a dim_features estimate + iter_before_redraw: Optional[ + int + ] = None # The number of iterations before the random features are re-drawn from scratch + feature_map: Optional[FeatureMapType] = None + + +@register_attention("favor", FavorAttentionConfig) +class FavorAttention(Attention): + def __init__( + self, + causal: bool = False, + dropout: float = 0.0, + dim_features: Optional[int] = None, + dim_head: Optional[int] = None, + iter_before_redraw: Optional[int] = None, + feature_map_type: FeatureMapType = FeatureMapType.SMReg, + normalize_inputs: bool = False, + *_, + **__, + ): + r""" + Kernelized attention, as proposed in Performers_ + ("Rethinking attention with performers." K. Choromanski et al. (2020).). + + FAVOR stands for "Fast Attention Via positive Orthogonal Random features" + + Args: + dropout (float): the probability of an output to be randomly dropped at training time + dim_features (int): the dimension of the random features space + iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features + feature_map_type (FeatureMapType): the type of feature map being used, + for instance orthogonal random features. + + .. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf + """ + super().__init__() + + self.causal = causal + self.iter_before_redraw = ( + (2 * iter_before_redraw) + if iter_before_redraw is not None + else iter_before_redraw + ) # This will be used for both key and query + self.normalize_inputs = normalize_inputs + self.feature_map_type = feature_map_type + self.attn_drop = nn.Dropout(dropout, inplace=True) + + # Setup dimension-dependent variables + # Reasonable dimension default + if dim_features is None: + assert dim_head is not None, "dim_features or dim_head needs to be passed" + self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head))) + self.dim_features = 2 * ( + self.dim_features // 2 + ) # needs to be even for some variants + logger.info( + f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}" + ) + else: + self.dim_features = dim_features + + feature_map_constructor = { + FeatureMapType.SMHyp: SMHyperbolic, + FeatureMapType.SMReg: SMReg, + FeatureMapType.SMOrf: SMOrf, + }[self.feature_map_type] + + feature_settings = { + "dim_features": self.dim_features, + "iter_before_redraw": self.iter_before_redraw, + "normalize_inputs": self.normalize_inputs, + } + + self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore + + # Properties specific to this attention mechanism + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + @staticmethod + def _maybe_promote(x: torch.Tensor) -> torch.Tensor: + # Only promote fp16 buffers, bfloat16 would be fine for instance + return x.float() if x.dtype == torch.float16 else x + + @staticmethod + def _causal_attention( + k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Algorithm 1 in the paper + ref_v = torch.ones_like(v.unsqueeze(2)) # BATCH x SEQ x 1 x EMB + Gps = k_prime.unsqueeze(3) * v.unsqueeze(2) + Grenorm = k_prime.unsqueeze(3) * ref_v + + # Consolidate against the feature dimension + att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime) + att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime) + + # Cumulative sum over the sequence + att_raw = att_raw.cumsum(2) + att_norm = att_norm.cumsum(2) + + return att_raw, att_norm + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: 
torch.Tensor, + *_, + **__, + ): + + # Project key and queries onto the feature map space + k_prime = self.feature_map(k) + q_prime = self.feature_map(q) + + with autocast(enabled=False): + # The softmax kernel approximation for Favor will easily overflow + # Force the computations here to stay in fp32 for numerical stability + # Note that the dimensions are vastly reduced when compared to scaled_dot_product + k_prime = self._maybe_promote(k_prime) + q_prime = self._maybe_promote(q_prime) + v = self._maybe_promote(v) + + if not self.causal: + att_normalization = q_prime @ ( + k_prime.transpose(-2, -1) @ torch.ones_like(v) + ) + att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v) + else: + # Actually compute attention + att_raw, att_normalization = self._causal_attention(k_prime, q_prime, v) + + # Normalize + att = att_raw / att_normalization + + if self.attn_drop is not None: + att = self.attn_drop(att) + + return att diff --git a/pkgs/xformers/components/attention/feature_maps/__init__.py b/pkgs/xformers/components/attention/feature_maps/__init__.py new file mode 100644 index 0000000..ed308d1 --- /dev/null +++ b/pkgs/xformers/components/attention/feature_maps/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from enum import Enum + +from .base import FeatureMap, FeatureMapConfig +from .softmax import NormDistribution, SMHyperbolic, SMOrf, SMReg + + +class FeatureMapType(str, Enum): + SMOrf = "sm_orf" + SMHyp = "sm_hyp" + SMReg = "sm_reg" # regularized softmax kernel + + +__all__ = [ + "SMOrf", + "SMReg", + "SMHyperbolic", + "NormDistribution", + "FeatureMapConfig", + "FeatureMap", +] diff --git a/pkgs/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..1310bc8 Binary files /dev/null and b/pkgs/xformers/components/attention/feature_maps/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/feature_maps/__pycache__/base.cpython-310.pyc b/pkgs/xformers/components/attention/feature_maps/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..c7c67dc Binary files /dev/null and b/pkgs/xformers/components/attention/feature_maps/__pycache__/base.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-310.pyc b/pkgs/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-310.pyc new file mode 100644 index 0000000..3df32c8 Binary files /dev/null and b/pkgs/xformers/components/attention/feature_maps/__pycache__/softmax.cpython-310.pyc differ diff --git a/pkgs/xformers/components/attention/feature_maps/base.py b/pkgs/xformers/components/attention/feature_maps/base.py new file mode 100644 index 0000000..8d41de8 --- /dev/null +++ b/pkgs/xformers/components/attention/feature_maps/base.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from abc import abstractmethod +from dataclasses import asdict, dataclass +from typing import Optional, Type, TypeVar + +import torch + +""" +Feature maps allow for a given query or key to be encoded in a different space. 
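In the non-causal branch of FavorAttention above, the S x S attention matrix is never materialized: once queries and keys are pushed through a positive feature map, the output is q' (k'^T v) normalized by q' (k'^T 1). A rough sketch of that identity with a stand-in feature map (a plain exp of a shared random projection, purely illustrative, not the SMOrf/SMHyperbolic/SMReg maps defined later in this patch):

import torch

def favor_noncausal(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dim_features: int = 64) -> torch.Tensor:
    # Illustrative positive feature map: exp(x @ w - ||x||^2 / 2)
    proj = torch.randn(q.size(-1), dim_features)
    q_prime = torch.exp(q @ proj - q.pow(2).sum(-1, keepdim=True) / 2)
    k_prime = torch.exp(k @ proj - k.pow(2).sum(-1, keepdim=True) / 2)

    # O(S * m * hs) instead of O(S^2 * hs): contract keys with values first
    att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v)                   # (B, S, hs)
    att_norm = q_prime @ k_prime.transpose(-2, -1).sum(-1, keepdim=True)  # (B, S, 1)
    return att_raw / att_norm

out = favor_noncausal(torch.randn(2, 32, 16), torch.randn(2, 32, 16), torch.randn(2, 32, 16))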
+""" + +Self = TypeVar("Self", bound="FeatureMap") + + +@dataclass +class FeatureMapConfig: + name: str + dim_features: int + iter_before_redraw: Optional[int] + normalize_inputs: Optional[bool] + epsilon: Optional[float] + + +class FeatureMap(torch.nn.Module): + def __init__( + self, + dim_features: int, + iter_before_redraw: Optional[int] = None, + normalize_inputs: bool = False, + epsilon: float = 1e-6, + ): + super().__init__() + + self.dim_features = dim_features + self.dim_feature_map = dim_features + + self.iter_before_redraw = iter_before_redraw + self.features: Optional[torch.Tensor] = None + self.epsilon = epsilon + self.normalize_inputs = normalize_inputs + + self._iter_counter = 0 + + @abstractmethod + def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): + raise NotImplementedError() + + @classmethod + def from_config(cls: Type[Self], config: FeatureMapConfig) -> Self: + # Generate the class inputs from the config + fields = asdict(config) + + # Skip all Nones so that default values are used + fields = {k: v for k, v in fields.items() if v is not None} + + return cls(**fields) diff --git a/pkgs/xformers/components/attention/feature_maps/softmax.py b/pkgs/xformers/components/attention/feature_maps/softmax.py new file mode 100644 index 0000000..d0dd1df --- /dev/null +++ b/pkgs/xformers/components/attention/feature_maps/softmax.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from enum import Enum, auto +from typing import Optional + +import torch +from torch.autograd.profiler import record_function + +from .base import FeatureMap + +""" +A set of feature maps which approximate the softmax kernel, as per the Performers_ paper. + +_Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). + https://arxiv.org/pdf/2009.14794v1.pdf +""" + + +class NormDistribution(Enum): + Xi = auto() + Uniform = auto() + + +class SoftMaxPositiveEstimators(FeatureMap): + def __init__( + self, + dim_features: int, + iter_before_redraw: Optional[int], + normalize_inputs: bool = False, + epsilon: float = 1e-6, + softmax_temp: float = -1, + ): + super().__init__(dim_features, iter_before_redraw, normalize_inputs, epsilon) + self.softmax_temp = softmax_temp + + # Handle the scaling from all kernels by √m. 
+ # This normalizes for all the feature maps involved + self.h_scale = math.log(math.sqrt(self.dim_features)) + + def pre_scale(self, x: torch.Tensor) -> torch.Tensor: + with record_function("feature_map::pre_scale"): + # Re-draw counting logic + if ( + ( + self.iter_before_redraw is not None + and self._iter_counter > self.iter_before_redraw + ) + or self.features is None + or self.features.device != x.device + ): + # The feature map is actually using half the dimension, we'll concatenate + and - features + self._iter_counter = 1 + self.features = self._get_feature_map( + x.shape[-1], self.dim_feature_map, x.device + ) + + features = self.features + assert features is not None + + if features.dtype != x.dtype: + self.features = features.to(x.dtype) + + self._iter_counter += 1 + + # Normalization / softmax + if self.softmax_temp < 0: + # A = exp(QK.t/√d), so each input will be scaled by √√d + self.softmax_temp = x.shape[-1] ** -0.25 + + x_scaled = x * self.softmax_temp + + # Compute the scaling factors in logspace, applied from within the exponential + # - dimnish possible exponential overflow + # - remove a multiply across the batch, replace by an addition + norm_x_2 = torch.einsum("...d,...d->...", x_scaled, x_scaled).unsqueeze(-1) + self.offset = -0.5 * norm_x_2 - self.h_scale + self.epsilon + + if self.normalize_inputs: + # L0 normalize the exponential term, can be useful for numerical stability + # This ensures that features +- offset is below 1 + self.offset -= norm_x_2.max(1, keepdim=True)[0] + + # Return the scaled inputs, the rest depends on the kernel being used + return x_scaled + + @staticmethod + @torch.no_grad() + def _get_random_ortho_matrix( + blocks: int, + dim: int, + device: torch.device, + norm_distribution: NormDistribution = NormDistribution.Uniform, + ) -> torch.Tensor: + r""" + Generate a random matrix whose rows are exactly orthonormal + + "How to generate random matrices from the classical compact groups", Mezzadri, 2007 + https://arxiv.org/pdf/math-ph/0609050v2.pdf + + .. note: the typical qr decomposition does not give uniform results, qr decomposition is not + unique and the qr decomposition routines are biased towards numerical stability. See the above + paper for more information. + + .. note: this does not follow the original implementation from the Performers authors. + see docs/assets/kde plots to visualize the impact of using the R signs to correct Q + """ + + H = torch.randn((blocks, dim, dim), device=device, requires_grad=False) + + # Randomly scale the norms of the features, Xi distributed + if norm_distribution == NormDistribution.Xi: + # NOTE: This averages to sqrt(d) + norms = torch.sqrt(torch.einsum("...d,...d->...", H, H)) + + Q, R = torch.linalg.qr(H) + Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=1, dim2=2))) @ Q + + # Normalize if need be. Uniform NormDistribution does nothing, Q is already orthonormal + if norm_distribution == NormDistribution.Xi: + return torch.diag_embed(norms) @ Q + + return Q + + +class SMOrf(SoftMaxPositiveEstimators): + """ + "Positive random orthogonal features" softmax estimator, + SM_ort^m+, as proposed in the Performers_ paper, Lemma 1. + + _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). + https://arxiv.org/pdf/2009.14794v1.pdf + """ + + @torch.no_grad() + def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): + """ + Generate the projection matrix onto the random features + + .. 
note: The heads dimension needs to be taken into account, hence the per-block random matrix + and not uniformally random. + """ + + # Get per block random unitary matrices. + # We need enough of them to project the whole input dimension, regardless of the + # requested dimension of the features + features = self._get_random_ortho_matrix( + math.ceil(dim_input / dim_features), + dim_features, + norm_distribution=NormDistribution.Xi, + device=device, + ) + + return features.flatten(0, 1)[:dim_input] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Softmax-dimension related scaling, shared for all kernels + x_scaled = super().pre_scale(x) + assert self.features is not None + + # Project onto the random feature map. + x_scaled = x_scaled @ self.features + return torch.exp(x_scaled + self.offset) + + +class SMHyperbolic(SoftMaxPositiveEstimators): + """ + "Positive random features hyperbolic" estimator, SMHyp+, + as proposed in the Performers_ paper, Lemma 1. + + _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). + https://arxiv.org/pdf/2009.14794v1.pdf + """ + + def __init__( + self, + dim_features: int, + iter_before_redraw: Optional[int], + normalize_inputs: bool = False, + epsilon: float = 1e-6, + softmax_temp: float = -1, + ): + super().__init__( + dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp + ) + + assert ( + dim_features % 2 == 0 + ), "The feature dimension needs to be even with this kernel" + self.dim_feature_map = self.dim_features // 2 + + @torch.no_grad() + def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): + """ + Generate the projection matrix onto the random features + + .. note: The heads dimension needs to be taken into account, hence the per-block random matrix + and not uniformally random. + """ + + # Get per block random unitary matrices. + # We need enough of them to project the whole input dimension, regardless of the + # requested dimension of the features + features = self._get_random_ortho_matrix( + math.ceil(dim_input / dim_features), + dim_features, + norm_distribution=NormDistribution.Xi, + device=device, + ) + + return features.flatten(0, 1)[:dim_input] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Softmax-dimension related scaling, shared for all kernels + x_scaled = super().pre_scale(x) + + # Project onto the random feature map, concatenate both + and - results + # This follows Lemma 1 in the original Performers Paper to best approximate a + # softmax kernel (cosh representation) + x_scaled = x_scaled @ self.features + return torch.cat( + [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)], + dim=-1, + ) + + +class SMReg(SoftMaxPositiveEstimators): + """ + "Regularized softmax kernel" estimator, SMREG+, as proposed in the Performers_ paper. + + _Performers: "Rethinking attention with performers." K. Choromanski et al. (2020). 
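The QR recipe used by _get_random_ortho_matrix above gives rows that are exactly orthonormal, but only after multiplying by the signs of R's diagonal (the Mezzadri 2007 correction mentioned in the docstring). A self-contained check of that step, with arbitrary shapes:

import torch

blocks, dim = 2, 16
H = torch.randn(blocks, dim, dim)
Q, R = torch.linalg.qr(H)
# Remove the sign ambiguity of the QR factorization so Q is uniform over the orthogonal group
Q = torch.diag_embed(torch.sign(torch.diagonal(R, dim1=-2, dim2=-1))) @ Q

# Rows are orthonormal: Q @ Q^T is the identity, per block
err = (Q @ Q.transpose(-2, -1) - torch.eye(dim)).abs().max()
print(f"max deviation from identity: {err.item():.2e}")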
+ https://arxiv.org/pdf/2009.14794v1.pdf + """ + + def __init__( + self, + dim_features: int, + iter_before_redraw: Optional[int], + normalize_inputs: bool = False, + epsilon: float = 1e-6, + softmax_temp: float = -1, + ): + super().__init__( + dim_features, iter_before_redraw, normalize_inputs, epsilon, softmax_temp + ) + + assert ( + dim_features % 2 == 0 + ), "The feature dimension needs to be even with this kernel" + self.dim_feature_map = self.dim_features // 2 + + @torch.no_grad() + def _get_feature_map(self, dim_input: int, dim_features: int, device: torch.device): + """ + Generate the projection matrix onto the random features + + .. note: The heads dimension needs to be taken into account, hence the per-block random matrix + and not uniformally random. + """ + + # Get per block random unitary matrices. + # We need enough of them to project the whole input dimension, regardless of the + # requested dimension of the features + features = self._get_random_ortho_matrix( + math.ceil(dim_input / dim_features), + dim_features, + norm_distribution=NormDistribution.Uniform, + device=device, + ).flatten(0, 1) + norms = math.sqrt(dim_input) * torch.ones(features.shape[0], device=device) + return (torch.diag(norms) @ features)[:dim_input] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Softmax-dimension related scaling, shared for all kernels + x_scaled = super().pre_scale(x) + + # Project onto the random feature map, concatenate both + and - results + # This follows Lemma 1 in the original Performers Paper to best approximate a + # softmax kernel (cosh representation + sample regularization) + x_scaled = x_scaled @ self.features + return torch.cat( + [torch.exp(x_scaled + self.offset), torch.exp(-x_scaled + self.offset)], + dim=-1, + ) diff --git a/pkgs/xformers/components/attention/fourier_mix.py b/pkgs/xformers/components/attention/fourier_mix.py new file mode 100644 index 0000000..8ceaa01 --- /dev/null +++ b/pkgs/xformers/components/attention/fourier_mix.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.cuda.amp import autocast + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +@register_attention("fourier_mix", AttentionConfig) +class FourierMix(Attention): + def __init__(self, dropout: float, *_, **__): + """ + FFT-based pseudo-attention mechanism, from + " + "FNet: Mixing Tokens with Fourier Transforms" + Lee-Thorp et al., 2021, https://arxiv.org/pdf/2105.03824.pdf + """ + super().__init__() + self.attn_drop = torch.nn.Dropout(dropout, inplace=False) + + # Properties specific to this attention mechanism + self.supports_attention_mask = False + self.requires_input_projection = False + + def forward(self, q: torch.Tensor, *_, **__): + # Guard against autocast / fp16, not supported by torch.fft.fft2 + with autocast(enabled=False): + att = torch.fft.fft2(q).real + + att = self.attn_drop(att) + + return att diff --git a/pkgs/xformers/components/attention/global_tokens.py b/pkgs/xformers/components/attention/global_tokens.py new file mode 100644 index 0000000..c6a5284 --- /dev/null +++ b/pkgs/xformers/components/attention/global_tokens.py @@ -0,0 +1,122 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + maybe_sparsify, + register_attention, + sparsify, +) +from xformers.components.attention.attention_patterns import ( + causal_1d_pattern, + global_token_pattern, +) +from xformers.components.attention.core import scaled_dot_product_attention + + +@dataclass +class GlobalAttentionConfig(AttentionConfig): + attention_query_mask: torch.Tensor # Mark the queries which have global attention + causal: Optional[bool] + force_sparsity: Optional[bool] + + +@register_attention("global", GlobalAttentionConfig) +class GlobalAttention(Attention): + def __init__( + self, + dropout: float, + attention_query_mask: torch.Tensor, + causal: bool = False, + force_sparsity: bool = False, + *_, + **__, + ): + r""" + Global attention, as proposed for instance in BigBird_ or Longformer_. + + Global means in that case that the queries positively labelled in the ```attention_query_mask``` can attend + to all the other queries. The queries negatively labelled in the ```attention_query_mask``` cannot attend to + any other query. + + This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory. + + Args: + dropout (float): probability of an element to be zeroed + attention_query_mask (torch.Tensor): if true, this query can attend to all the others + + """ + super().__init__() + + assert attention_query_mask.dtype == torch.bool, "A boolean mask is expected" + assert ( + attention_query_mask.shape[1] == 1 + and attention_query_mask.shape[0] > attention_query_mask.shape[1] + ), "A N x 1 query mask is expected" + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.attention_mask = global_token_pattern(attention_query_mask[:, 0]) + self.force_sparsity = force_sparsity + + if causal: + self.attention_mask &= causal_1d_pattern(attention_query_mask.shape[1]) + + self.attention_mask = ( + sparsify(self.attention_mask) + if self.force_sparsity + else maybe_sparsify(self.attention_mask) + ) + + # Properties specific to this attention mechanism + self.requires_same_k_q_dimensions = True + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + *_, + **__, + ): + # Make sure that the mask is on the right device + if self.attention_mask.device != q.device: + self.attention_mask = self.attention_mask.to(q.device) + + # Mask-aware attention + if att_mask is not None: + if att_mask.dtype == torch.bool and isinstance( + self.attention_mask, AttentionMask + ): + if not isinstance(att_mask, AttentionMask): + att_mask = AttentionMask.from_bool(att_mask) + mask = self.attention_mask + att_mask + else: + mask = self.attention_mask & att_mask + else: + mask = self.attention_mask + + # Handle q/k/v which would not fit the mask + seq_len = q.shape[-2] + q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v)) + + # Normal attention with the global tokens mask + att = scaled_dot_product_attention( + q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop + ) + + # Take into account an hypothetical padding + return att[:, :seq_len, :] diff --git 
a/pkgs/xformers/components/attention/lambda_layer.py b/pkgs/xformers/components/attention/lambda_layer.py new file mode 100644 index 0000000..0002a20 --- /dev/null +++ b/pkgs/xformers/components/attention/lambda_layer.py @@ -0,0 +1,78 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass + +import torch + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +def calc_rel_pos(n: int): + # Adapted from LucidRains + # https://github.com/lucidrains/lambda-networks/blob/main/lambda_networks/lambda_networks.py + rel_pos = torch.arange(n)[None, :] - torch.arange(n)[:, None] # [n, n] + rel_pos += n - 1 # shift value range from [-n+1, n-1] to [0, 2n-2] + return rel_pos + + +@dataclass +class LambdaLayerConfig(AttentionConfig): + seq_len: int # dimension of the input sequence + dim_head: int + + +@register_attention("lambda", LambdaLayerConfig) +class LambdaLayer(Attention): + def __init__(self, dropout: float, seq_len: int, dim_head: int, *_, **__): + """ + Attention approximation using Lambda layers, from + "Lambda networks: modeling long-range interactions without attention.", Bello, I. (2021). + """ + super().__init__() + + # Possible extensions: + # - support different dimensions for key and queries + # - support varying dimensions in between inputs and outputs + # - support u hyperparam + + self.rel_pos_emb = torch.nn.Parameter( + torch.randn(2 * seq_len - 1, int(dim_head)) + ) + self.rel_pos = calc_rel_pos(seq_len) + self.attn_drop = torch.nn.Dropout(dropout, inplace=True) + + # Properties specific to this attention mechanism + self.requires_same_k_q_dimensions = True + self.supports_attention_mask = False + self.supports_key_padding_mask = False + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs + ): + """..NOTE: We're reusing the einsum notation suggested by the paper, changed in that + heads are folded in the batch dimension""" + + content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v) + content_output = torch.einsum("bnk,bkv->bnv", q, content_lambda) + + rel_pos_emb = self.rel_pos_emb[self.rel_pos] + + # Handle real sequence length being possibly smaller + seq_len = q.shape[1] + rel_pos_emb = rel_pos_emb[:seq_len, :seq_len, :] + + # Compute the position lambda for every possible combination in one go, then compute the + # position related contribution + position_lambdas = torch.einsum( + "mnk,bnv->bnkv", rel_pos_emb, v + ) # one lambda per position + position_output = (q.unsqueeze(2) @ position_lambdas).squeeze() + att = content_output + position_output + + att = self.attn_drop(att) + + return att diff --git a/pkgs/xformers/components/attention/linformer.py b/pkgs/xformers/components/attention/linformer.py new file mode 100644 index 0000000..af6f20b --- /dev/null +++ b/pkgs/xformers/components/attention/linformer.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
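The content path of LambdaLayer above never forms a token-to-token attention matrix; keys and values are collapsed into a single (k, v) "lambda" shared by every query position. A stripped-down sketch of just that content term (no positional lambdas, no dropout), reusing the einsum notation from the source:

import torch

def content_lambda_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q, k: (B, N, K), v: (B, N, V); heads are assumed folded into the batch dimension
    content_lambda = torch.einsum("bnk,bnv->bkv", torch.softmax(k, dim=-1), v)  # (B, K, V)
    return torch.einsum("bnk,bkv->bnv", q, content_lambda)                      # (B, N, V)

out = content_lambda_attention(torch.randn(2, 10, 8), torch.randn(2, 10, 8), torch.randn(2, 10, 16))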
+ + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from xformers.components.attention import Attention, AttentionConfig, register_attention +from xformers.components.attention.core import scaled_dot_product_attention + + +@dataclass +class LinformerSelfAttentionConfig(AttentionConfig): + seq_len: int # dimension of the input sequence + k: Optional[int] # dimension of the internal space + + +@register_attention("linformer", LinformerSelfAttentionConfig) +class LinformerAttention(Attention): + def __init__( + self, dropout: float, seq_len: int, k: Optional[int] = None, *args, **kwargs + ): + """ + Linformer attention mechanism, + from `Linformer: Self-Attention with Linear Complexity`_, Wang et al (2020). + The original notation is kept as is. + + .. _`Linformer: Self-Attention with Linear Complexity` : https://arxiv.org/abs/2006.04768v2 + """ + super().__init__() + + if k is None: + k = seq_len // 4 + + self.k = k + self.E = nn.Linear(seq_len, k, bias=False) + self.F = nn.Linear(seq_len, k, bias=False) + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.seq_len = seq_len + + # MHA related flags: + # kq need to have the same dimension + self.requires_same_k_q_dimensions = True + + # This attention does not support attention masks + self.supports_attention_mask = False + + def forward( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, *args, **kwargs + ): + # Handle a smaller dimension than expected + padding = 0 + if q.shape[1] < self.seq_len: + padding = self.seq_len - q.shape[1] + pad_dims = (0, 0, 0, padding) + q = torch.nn.functional.pad(q, pad_dims) + k = torch.nn.functional.pad(k, pad_dims) + v = torch.nn.functional.pad(v, pad_dims) + + k_projected = self.E(k.transpose(-2, -1)).transpose(-2, -1) + v_projected = self.F(v.transpose(-2, -1)).transpose(-2, -1) + + y = scaled_dot_product_attention( + q=q, k=k_projected, v=v_projected, att_mask=None, dropout=self.attn_drop + ) + + y = self.attn_drop(y) + + return y[:, :-padding, :] if padding > 0 else y diff --git a/pkgs/xformers/components/attention/local.py b/pkgs/xformers/components/attention/local.py new file mode 100644 index 0000000..3220a8d --- /dev/null +++ b/pkgs/xformers/components/attention/local.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
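LinformerAttention above brings the cost down from O(S^2) to O(S * k) by learning two projections, E and F, that shrink the sequence axis of the keys and values before ordinary attention. A minimal sketch of that projection step with made-up sizes:

import torch
import torch.nn as nn

seq_len, k_dim, d = 128, 32, 64
E = nn.Linear(seq_len, k_dim, bias=False)   # projects the sequence axis of K
F = nn.Linear(seq_len, k_dim, bias=False)   # projects the sequence axis of V

keys, values = torch.randn(4, seq_len, d), torch.randn(4, seq_len, d)

# (B, S, D) -> (B, D, S) -> (B, D, k) -> (B, k, D)
k_projected = E(keys.transpose(-2, -1)).transpose(-2, -1)
v_projected = F(values.transpose(-2, -1)).transpose(-2, -1)
print(k_projected.shape, v_projected.shape)  # both torch.Size([4, 32, 64])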
+ + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + maybe_sparsify, + register_attention, + sparsify, +) +from xformers.components.attention.attention_patterns import ( + causal_1d_pattern, + local_1d_pattern, +) +from xformers.components.attention.core import scaled_dot_product_attention + + +@dataclass +class LocalAttentionConfig(AttentionConfig): + causal: Optional[bool] = None + window_size: Optional[int] = None + force_sparsity: Optional[bool] = None + + +@register_attention("local", LocalAttentionConfig) +class LocalAttention(Attention): + def __init__( + self, + dropout: float = 0.0, + causal: bool = False, + window_size: int = 5, + force_sparsity: bool = False, + *args, + **kwargs, + ): + + r""" + An implementation of a sliding window attention, as proposed in RoutingTransformer_, LongFormer_ or BigBird_ + + + Args: + dropout (float): the probability of an output to be randomly dropped at training time + causal (bool): apply a causal mask, in that the attention cannot be applied to the future + window_size (int): the overall window size for local attention. + Odd number is expected if the mask is not causal, as the window size will be evenly + distributed on both sides of each query + + + .. _RoutingTransformer: https://arxiv.org/pdf/2003.05997.pdf + + .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf + + .. _Longformer: https://arxiv.org/pdf/2004.05150.pdf + + """ + super().__init__() + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.causal = causal + self.force_sparsity = force_sparsity + + if not self.causal: + assert ( + window_size % 2 == 1 + ), "The window size is assumed to be odd (counts self-attention + 2 wings)" + + self.window_size = window_size + self.attention_mask: Optional[torch.Tensor] = None + self.requires_same_k_q_dimensions = True + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + def _get_local_mask(self, shape: torch.Size) -> torch.Tensor: + window_size = self.window_size * 2 + 1 if self.causal else self.window_size + mask = local_1d_pattern(shape[1], window_size) + + if self.causal: + mask &= causal_1d_pattern(shape[1]) + + mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask) + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + *args, + **kwargs, + ): + # Local window attention masking + if self.attention_mask is None or self.attention_mask.shape[1] != q.shape[1]: + self.attention_mask = self._get_local_mask(q.shape).to(q.device) + + # Take into account the optional user mask + if att_mask is None: + mask = self.attention_mask + else: + if isinstance(att_mask, AttentionMask): + # Needed because & op not defined for SparseCS with AttentionMask + att_mask = att_mask.to_bool() + mask = self.attention_mask & att_mask + + return scaled_dot_product_attention( + q=q, k=k, v=v, att_mask=mask, dropout=self.attn_drop + ) diff --git a/pkgs/xformers/components/attention/nystrom.py b/pkgs/xformers/components/attention/nystrom.py new file mode 100644 index 0000000..93e40b7 --- /dev/null +++ b/pkgs/xformers/components/attention/nystrom.py @@ -0,0 +1,295 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
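The sliding-window mask built by LocalAttention._get_local_mask keeps only key positions close to the query's own position, via local_1d_pattern plus an optional causal cut. An equivalent boolean mask can be written directly; this is a simplified sketch (it does not reproduce the window doubling that the class applies in the causal case):

import torch

def local_window_mask(seq_len: int, window_size: int, causal: bool = False) -> torch.Tensor:
    # True where query i may attend to key j
    idx = torch.arange(seq_len)
    dist = idx[:, None] - idx[None, :]
    mask = dist.abs() <= window_size // 2
    if causal:
        mask &= dist >= 0
    return mask

print(local_window_mask(6, 3).int())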
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from xformers.components.attention import Attention, AttentionConfig, register_attention +from xformers.components.attention.core import ( + scaled_dot_product_attention, + scaled_query_key_softmax, +) +from xformers.components.attention.utils import ( + bool_mask_to_additive, + iterative_pinv, + reshape_key_padding_mask, +) + +logger = logging.getLogger("xformers") + + +@dataclass +class NystromSelfAttentionConfig(AttentionConfig): + """ + num_heads Number of heads. + num_landmarks Number of landmarks to use for softmax approximation. 64 often sufficient for a good + approximation according to https://arxiv.org/pdf/2102.03902.pdf. + causal Apply a causal mask, in that the attention cannot be applied to the future. + use_razavi_pinverse If true, use iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose + inverse, otherwise use standard torch inverse. + pinverse_original_init True if using original initialization when calculating Moore-Penrose pseudo inverse using + method from (Razavi et al. 2014). + False if using exact coefficient computation (leads to faster convergence). + inv_iterations Number of iterations for calculating the Moore-Penrose pseudo inverse. + v_skip_connection A module that will take V as input and will be added as a skip connection to the + softmax approximation. A skip connection is added in the paper to help with training. + conv_kernel_size Kernel size for convolution optionally added to help in training. + If v_skip_connection is not specified, this will be used to define the default + depth wise convolution used as a skip connection. + If both conv_kernel_size and v_skip_connection are None, no skip connection will + be added. + landmark_pooling Which module to use when computing landmarks. Default is AdaptiveAvgPool2d. 
+ """ + + num_heads: int + num_landmarks: Optional[int] + landmark_pooling: Optional[nn.Module] + causal: Optional[bool] + pinverse_original_init: Optional[bool] + inv_iterations: Optional[int] + v_skip_connection: Optional[nn.Module] + conv_kernel_size: Optional[int] + use_razavi_pinverse: Optional[bool] + + +class AvgPool(nn.Module): + def __init__(self, n: int): + super().__init__() + self.n = n + + def forward(self, x: torch.Tensor): + # Average independently for every segment in the sequence dimension + seq_len = x.shape[1] + head_dim = x.shape[2] + segments = seq_len // self.n + assert segments > 0, "num_landmarks should be smaller than the sequence length" + + # Dimensions are a match + if seq_len % self.n == 0: + return x.reshape( + -1, + self.n, + segments, + head_dim, + ).mean(dim=-2) + + # Handle the last segment boundary being off + n_round = self.n - seq_len % self.n + + x_avg_round = ( + x[:, : n_round * segments, :] + .reshape(-1, n_round, segments, head_dim) + .mean(dim=-2) + ) + x_avg_off = ( + x[:, n_round * segments :, :] + .reshape(-1, self.n - n_round, segments + 1, head_dim) + .mean(dim=-2) + ) + return torch.cat((x_avg_round, x_avg_off), dim=-2) + + +@register_attention("nystrom", NystromSelfAttentionConfig) +class NystromAttention(Attention): + # TODO: update defaults for use_razavi_pinverse and inv_iterations + def __init__( + self, + dropout: float, + num_heads: int, + num_landmarks: int = 64, + landmark_pooling: Optional[nn.Module] = None, + causal: bool = False, + use_razavi_pinverse: bool = True, + pinverse_original_init: bool = False, + inv_iterations: int = 6, # recommended default in paper was 6. + v_skip_connection: Optional[nn.Module] = None, + conv_kernel_size: Optional[int] = None, + *args, + **kwargs, + ): + """ + Nystrom attention mechanism, from Nystromformer_. + :: + + "A Nystrom-based Algorithm for Approximating Self-Attention." + Xiong, Y., Zeng, Z., Chakraborty, R., Tan, M., Fung, G., Li, Y., Singh, V. (2021) + + Reference codebase: https://github.com/mlpen/Nystromformer + + .. 
_Nystromformer: https://arxiv.org/pdf/2102.03902.pdf + + """ + super().__init__() + # merged key padding mask and attention mask is not accepted + self.requires_separate_masks = True + self.num_landmarks = num_landmarks + # TODO: should be able to not have to pass in num_heads + self.num_heads = num_heads + self.use_razavi_pinverse = use_razavi_pinverse + self.pinverse_original_init = pinverse_original_init + self.inv_iterations = inv_iterations + self.attn_drop = nn.Dropout(dropout) + self.skip_connection = v_skip_connection + self.causal = causal + + if self.skip_connection is None and conv_kernel_size is not None: + self.skip_connection = nn.Conv2d( + in_channels=self.num_heads, + out_channels=self.num_heads, + kernel_size=(conv_kernel_size, 1), + padding=(conv_kernel_size // 2, 0), + bias=False, + groups=self.num_heads, + ) + + if landmark_pooling is not None: + self.landmark_pooling = landmark_pooling + else: + self.landmark_pooling = AvgPool(n=self.num_landmarks) + + # Optional lower triangular masks for causal attention + self.causal_mask_1: Optional[torch.Tensor] = None + self.causal_mask_2: Optional[torch.Tensor] = None + self.causal_mask_3: Optional[torch.Tensor] = None + + # This attention does not support attention masks + self.supports_attention_mask = False + self.supports_key_padding_mask = True + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + *args, + **kwargs, + ): + r""" + key_padding_mask Only a key padding mask is accepted here. The size must be (batch size, sequence length) or + (batch size * num_heads, 1, sequence length). If dimensions are not correct, the mask will + be ignored. An additive mask is expected, meaning float values using "-inf" to mask values + """ + + batched_dim = k.size(0) + seq_len = k.size(-2) + tt = {"dtype": q.dtype, "device": q.device} + + if key_padding_mask is not None: + if key_padding_mask.dtype == torch.bool: + logger.warning( + "Bool mask found, but an additive mask is expected. Converting but this is slow" + ) + + key_padding_mask = bool_mask_to_additive(key_padding_mask) + + if key_padding_mask.ndim == 2: + key_padding_mask = reshape_key_padding_mask( + key_padding_mask, batched_dim + ) + + zeros = torch.zeros_like(key_padding_mask) + ones = torch.ones_like(key_padding_mask) + is_masked = torch.isinf(-key_padding_mask) + + # _mask takes 1 if the token is not padded, otherwise 0. + _mask = torch.where(is_masked, zeros, ones) + _mask = _mask.transpose(2, 1) + assert _mask.shape == (batched_dim, q.shape[1], 1) + + # Mask q and k before pooling + # https://github.com/mlpen/Nystromformer/blob/main/code/attention_nystrom.py#L31 + q = q * _mask + k = k * _mask + + assert key_padding_mask.size() == (batched_dim, 1, seq_len), ( + f"key_padding_mask has invalid dimensions {key_padding_mask.size()}." + f" Must have dimensions {batched_dim, 1, seq_len} or (batch_size, {seq_len})." 
+ ) + + if self.num_landmarks >= seq_len: + mask: Optional[torch.Tensor] = None + + if self.causal: + mask = self._triu_mask(batched_dim, seq_len, seq_len, **tt) + + if key_padding_mask is not None: + mask = key_padding_mask if mask is None else mask + key_padding_mask + + x = scaled_dot_product_attention(q=q, k=k, v=v, att_mask=mask) + + else: + q_landmarks = self.landmark_pooling(q) + k_landmarks = self.landmark_pooling(k) + + if self.causal and ( + self.causal_mask_1 is None + or (batched_dim, seq_len, self.num_landmarks) + != self.causal_mask_1.size() + ): + self.causal_mask_1 = self._triu_mask( + batched_dim, seq_len, self.num_landmarks, **tt + ) + self.causal_mask_2 = self._triu_mask( + batched_dim, self.num_landmarks, self.num_landmarks, **tt + ) + self.causal_mask_3 = self._triu_mask( + batched_dim, self.num_landmarks, seq_len, **tt + ) + + mask_3: Optional[torch.Tensor] = self.causal_mask_3 + if key_padding_mask is not None: + mask_3 = ( + key_padding_mask if mask_3 is None else mask_3 + key_padding_mask + ) + + kernel_1 = scaled_query_key_softmax(q=q, k=k_landmarks, att_mask=None) + kernel_2 = scaled_query_key_softmax( + q=q_landmarks, k=k_landmarks, att_mask=None + ) + kernel_3 = scaled_dot_product_attention( + q=q_landmarks, k=k, v=v, att_mask=mask_3 + ) + + kernel_2_inv = ( + iterative_pinv( + kernel_2, self.inv_iterations, self.pinverse_original_init + ) + if self.use_razavi_pinverse + else torch.linalg.pinv(kernel_2) + ) + + x = torch.matmul( + torch.matmul( + kernel_1, + kernel_2_inv, + ), + kernel_3, + ) + + if self.skip_connection: + # Assumption here is that v is 3D. + v_conv = self.skip_connection( + v.reshape(-1, self.num_heads, v.size(-2), v.size(-1)) + ) + x += v_conv.reshape(-1, v_conv.size(-2), v_conv.size(-1)) + x = self.attn_drop(x) + return x + + def _triu_mask(self, dim_1: int, dim_2: int, dim_3: int, **kwargs) -> torch.Tensor: + device = kwargs["device"] + dtype = kwargs["dtype"] + + return torch.triu( + torch.ones(dim_2, dim_3, dtype=dtype, device=device) * float("-inf"), + diagonal=1, + ).expand( + dim_1, -1, -1 + ) # micro optim, save memory on the batch dimension diff --git a/pkgs/xformers/components/attention/ortho.py b/pkgs/xformers/components/attention/ortho.py new file mode 100644 index 0000000..3d6de43 --- /dev/null +++ b/pkgs/xformers/components/attention/ortho.py @@ -0,0 +1,324 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + +import torch +import torch.autograd.profiler as profiler +import torch.nn as nn +import torch.nn.functional as Fn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + register_attention, +) +from xformers.components.attention.core import ( + scaled_dot_product_attention, + scaled_query_key_softmax, +) + +logger = logging.getLogger("xformers") + + +class LandmarkSelection(str, Enum): + Orthogonal = "orthogonal" + KMeans = "kmeans" + KMeans_Spherical = "kmeans_spherical" + Random = "random" + + +@dataclass +class OrthoformerAttentionConfig(AttentionConfig): + """ + num_landmarks Number of landmarks to use for softmax approximation. 
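Setting aside masking and the convolutional skip connection, the Nystrom path above factors softmax(QK^T)V through a small set of pooled landmarks: kernel_1 @ pinv(kernel_2) @ (kernel_3 @ v). A compact sketch using segment-mean landmarks and the exact pseudo-inverse (the class can swap in the iterative Razavi scheme instead); it assumes the sequence length is divisible by the landmark count:

import math

import torch

def nystrom_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_landmarks: int = 8) -> torch.Tensor:
    B, S, hs = q.shape
    q_l = q.reshape(B, num_landmarks, S // num_landmarks, hs).mean(dim=-2)
    k_l = k.reshape(B, num_landmarks, S // num_landmarks, hs).mean(dim=-2)

    scale = 1.0 / math.sqrt(hs)
    kernel_1 = torch.softmax(scale * q @ k_l.transpose(-2, -1), dim=-1)    # (B, S, M)
    kernel_2 = torch.softmax(scale * q_l @ k_l.transpose(-2, -1), dim=-1)  # (B, M, M)
    kernel_3 = torch.softmax(scale * q_l @ k.transpose(-2, -1), dim=-1)    # (B, M, S)
    return kernel_1 @ torch.linalg.pinv(kernel_2) @ (kernel_3 @ v)

out = nystrom_attention(torch.randn(2, 64, 32), torch.randn(2, 64, 32), torch.randn(2, 64, 32))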
+ subsample_fraction Percentage of q_samples matrix to sample per iteration + landmark_selection Landmark selection strategy + """ + + num_landmarks: Optional[int] + subsample_fraction: Optional[float] + landmark_selection: Optional[LandmarkSelection] + + +@register_attention("orthoformer", OrthoformerAttentionConfig) +class OrthoFormerAttention(Attention): + def __init__( + self, + dropout: float, + num_landmarks: int = 32, + subsample_fraction: float = 1.0, + landmark_selection: LandmarkSelection = LandmarkSelection.Orthogonal, + *args, + **kwargs, + ): + """ + Orthoformer_ attention mechanism. + :: + + "Keeping Your Eye on the Ball: Trajectory Attention in Video Transformers" + Patrick, M., Campbell, D., Asano, Y., Misra, I., Metze, F., Feichtenhofer, + C., Vedaldi, A., Henriques, J. (2021) + + Reference codebase: https://github.com/facebookresearch/Motionformer + + .. _Orthoformer: https://arxiv.org/abs/2106.05392 + + """ + super().__init__() + + self.num_landmarks = num_landmarks + self.attn_drop = nn.Dropout(dropout) + self.subsample_fraction = subsample_fraction + self.landmark_selection = landmark_selection + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None, + *args, + **kwargs, + ): + N = k.shape[1] + + if self.num_landmarks == N: + # Default attention + x = scaled_dot_product_attention(q, k, v, att_mask) + else: + with torch.no_grad(), profiler.record_function("select landmarks"): + if self.landmark_selection == LandmarkSelection.Orthogonal: + landmarks = self._compute_orthogonal_landmarks(q) + elif self.landmark_selection == LandmarkSelection.Random: + half_L = self.num_landmarks // 2 + landmarks_q = q[:, torch.randint(q.size(1), (half_L,)), :] + landmarks_k = k[:, torch.randint(k.size(1), (half_L,)), :] + landmarks = torch.cat((landmarks_q, landmarks_k), dim=-2) + elif self.landmark_selection == LandmarkSelection.KMeans: + landmarks = self._cluster_landmarks(q) + elif self.landmark_selection == LandmarkSelection.KMeans_Spherical: + landmarks = self._cluster_landmarks(q, spherical=True) + + if att_mask is not None: + logger.warning( + "Orthoformer: attention mask passed alongside with using landmarks to reduce dimensions. \ + The two are typically not compatible" + ) + # FIXME: Should we still accept a mask in that case ? + att_mask = None + + # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems + # like it could be uninitialized. + kernel_1 = scaled_query_key_softmax(q, landmarks, att_mask) + # pyre-ignore[61]: TODO(T103337542): `landmarks` mistakenly seems + # like it could be uninitialized. + kernel_2 = scaled_query_key_softmax(landmarks, k, att_mask) + x = torch.matmul(kernel_1, torch.matmul(kernel_2, v)) + x = self.attn_drop(x) + return x + + def _cluster_landmarks( + self, + q: torch.Tensor, + spherical: bool = False, + num_iters: int = 6, + ) -> torch.Tensor: + """ + Construct set of landmarks by recursively selecting new landmarks + that are maximally orthogonal to the existing set. + Returns near orthogonal landmarks with shape (B, M, D). 
+ """ + + num_landmarks = min(self.num_landmarks, q.shape[1]) + + if self.subsample_fraction < 1.0: + num_samples = max( + int(self.subsample_fraction * q.size(-2)), num_landmarks + ) # Need at least M/2 samples of queries and keys + q_samples = q[:, torch.randint(q.size(-2), (num_samples,)), :] # (B, N, D) + else: + q_samples = q # (B, N, D) + + if spherical: + q_samples_normalized = Fn.normalize( + q_samples, p=2, dim=-1 + ) # may need to change default eps to eps=1e-8 for mixed precision compatibility + landmarks = self._kmeans_spherical( + q_samples_normalized, num_landmarks, num_iters + ) + else: + landmarks = self._kmeans(q_samples, num_landmarks, num_iters) + return landmarks # (B, M, D) + + def _kmeans(self, x: torch.Tensor, K: int, num_iters: int = 10): + """ + Arguments: + x: (B, N, D) + K: number of clusters + num_iters: the number of kmeans updates + """ + + B, N, D = x.size() + assert K <= N, f"{K} > {N}" + + c = x[ + :, torch.randperm(N, device=x.device)[:K], : + ].clone() # initialisation for the centroids + + with profiler.record_function("kmeans"): + x_i = x.view(B, N, 1, D) + c_j = c.view(B, 1, K, D) + counts = c.new_zeros(B, K) + ones = x.new_ones((B, N)) + + for _ in range(num_iters): + # E step: assign points to the nearest cluster + D_ij = ((x_i - c_j) ** 2).sum(-1) # (B, N, K) squared distances + cl = D_ij.argmin( + dim=-1, keepdim=True + ).long() # (B, N, 1) index of point to nearest cluster + + # M step: update the centroids + c.zero_() + c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster + counts.fill_(1e-6) # avoid div0 + counts.scatter_add_( + -1, cl.squeeze(-1), ones + ) # number of points per cluster + c.divide_(counts.unsqueeze(-1)) # compute the average + + return c + + def _kmeans_spherical(self, x: torch.Tensor, K: int, num_iters=10): + """ + Arguments: + x: (B, N, D) + """ + B, N, D = x.size() + assert K <= N, f"{K} > {N}" + + # initialisation for the centroids + c = x[:, torch.randperm(N, device=x.device)[:K], :].clone() + + with profiler.record_function("kmeans_spherical"): + counts = c.new_zeros(B, K) + ones = x.new_ones((B, N)) + + for _ in range(num_iters): + # E step: assign points to the nearest cluster + D_ij = torch.matmul( + x, c.transpose(-2, -1) + ) # (B, N, K) cosine similarity + cl = D_ij.argmax( + dim=-1, keepdim=True + ).long() # (B, N, 1) index of point to nearest cluster + + # M step: update the centroids + c.zero_() + c.scatter_add_(-2, cl.repeat(1, 1, D), x) # sum of points per cluster + counts.fill_(1e-6) # avoid div0 + counts.scatter_add_( + -1, cl.squeeze(-1), ones + ) # number of points per cluster + c.divide_(counts.unsqueeze(-1)) # compute the average + c = Fn.normalize(c, p=2, dim=-1) # renormalise + return c + + def _compute_orthogonal_landmarks(self, q: torch.Tensor) -> torch.Tensor: + """ + Construct set of landmarks by recursively selecting new landmarks + that are maximally orthogonal to the existing set. + Returns near orthogonal landmarks with shape (B, M, D). 
+ """ + + if self.subsample_fraction < 1.0: + # Need at least M samples of queries + num_samples = max( + int(self.subsample_fraction * q.size(-2)), self.num_landmarks + ) + q_samples = q[ + :, torch.randint(q.size(-2), (num_samples,), device=q.device), : + ] + else: + # (B, N, D) + q_samples = q + + # may need to change default eps to eps=1e-8 for mixed precision compatibility + q_samples_normalized = Fn.normalize(q_samples, p=2, dim=-1) + B, N, D = q_samples_normalized.shape + + selected_mask = torch.zeros((B, N, 1), device=q_samples_normalized.device) + landmark_mask = torch.ones( + (B, 1, 1), dtype=selected_mask.dtype, device=q_samples_normalized.device + ) + + #  Get initial random landmark + random_idx = torch.randint( + q_samples_normalized.size(-2), (B, 1, 1), device=q_samples_normalized.device + ) + selected_mask.scatter_(-2, random_idx, landmark_mask) + + #  Selected landmarks + selected_landmarks = torch.empty( + (B, self.num_landmarks, D), + device=q_samples_normalized.device, + dtype=q_samples_normalized.dtype, + ) + selected_landmarks[:, 0, :] = q_samples_normalized[ + torch.arange(q_samples_normalized.size(0)), random_idx.view(-1), : + ].view(B, D) + + # Store computed cosine similarities + cos_sims = torch.empty( + (B, N, self.num_landmarks), + device=q_samples_normalized.device, + dtype=q_samples_normalized.dtype, + ) + + for M in range(1, self.num_landmarks): + with profiler.record_function("find new landmark"): + #  Calculate absolute cosine similarity between selected and unselected landmarks + # (B, N, D) * (B, D) -> (B, N) + cos_sims[:, :, M - 1] = torch.einsum( + "b n d, b d -> b n", + q_samples_normalized, + selected_landmarks[:, M - 1, :], + ).abs() + + # (B, N, M) cosine similarities of current set of landmarks wrt all queries and keys + cos_sim_set = cos_sims[:, :, :M] + + #  Get orthogonal landmark: landmark with smallest absolute cosine similarity: + # set cosine similarity for already selected landmarks to > 1 + cos_sim_set.view(-1, M)[selected_mask.flatten().bool(), :] = 10 + + # (B,) - want max for non + selected_landmark_idx = cos_sim_set.amax(-1).argmin(-1) + + #  Add most orthogonal landmark to selected landmarks: + selected_landmarks[:, M, :] = q_samples_normalized[ + torch.arange(q_samples_normalized.size(0)), selected_landmark_idx, : + ].view(B, D) + + #  Removed selected indices from non-selected mask: + selected_mask.scatter_( + -2, selected_landmark_idx.unsqueeze(-1).unsqueeze(-1), landmark_mask + ) + + # (B, M, D) + landmarks = torch.masked_select(q_samples, selected_mask.bool()).reshape( + B, -1, D + ) + return landmarks # (B, M, D) diff --git a/pkgs/xformers/components/attention/pooling.py b/pkgs/xformers/components/attention/pooling.py new file mode 100644 index 0000000..6c93193 --- /dev/null +++ b/pkgs/xformers/components/attention/pooling.py @@ -0,0 +1,82 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +@dataclass +class PoolingAttentionConfig(AttentionConfig): + pool_size: int # dimension of the input sequence + stride: Optional[int] # dimension of the internal space + padding: Optional[int] + + +@register_attention("pooling", PoolingAttentionConfig) +class Pooling(Attention): + def __init__( + self, + pool_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + *_, + **__, + ): + """ + Pooling token mixing mechanism, as proposed in + `Metaformer is actually what you need for vision`_, Yu et al (2021). + + The original notation is kept as is. + + .. _`Metaformer is actually what you need for vision` : https://arxiv.org/pdf/2111.11418v1.pdf + """ + super().__init__() + + padding = padding if padding is not None else pool_size // 2 + self.pool = nn.AvgPool2d( + pool_size, + stride=stride, + padding=pool_size // 2, + count_include_pad=False, + ) + + # MHA related flags: + # kq need to have the same dimension + self.requires_same_k_q_dimensions = False + + # This attention does not support attention masks + self.supports_attention_mask = False + + # This "attention" (token mixing) skips the multihead attention altogether + self.requires_skip_multi_head = True + self.requires_input_projection = False + + # This operator does not really handle q,k,v + self.requires_same_k_q_dimensions = True + + # This attention requires the 2d structure out of the context, + # implictly assumed to be a squared length + self.requires_squared_context = True + + def forward(self, q: torch.Tensor, *_, **__): + # Expose the 2D token structure + B, HW, C = q.shape + H = int(math.sqrt(HW)) + assert H * H == HW + + q = q.transpose(-2, -1).reshape(B, C, H, H) + + # 2D pool + x_pool = self.pool(q) - q # compensate for the residual path + + # Get back to B HW C + return x_pool.flatten(2, 3).transpose(-2, -1) diff --git a/pkgs/xformers/components/attention/random.py b/pkgs/xformers/components/attention/random.py new file mode 100644 index 0000000..e07e6c8 --- /dev/null +++ b/pkgs/xformers/components/attention/random.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Optional, Union + +import torch +import torch.nn as nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + maybe_sparsify, + register_attention, + sparsify, +) +from xformers.components.attention.attention_patterns import ( + causal_1d_pattern, + random_pattern, +) +from xformers.components.attention.core import scaled_dot_product_attention + + +@dataclass +class RandomAttentionConfig(AttentionConfig): + r: Optional[ + float + ] # the ratio of keys that the query can attend to. 
1.0 means dense attention + constant_masking: Optional[ + bool + ] # whether the randomness is per query or defined at construction time + force_sparsity: Optional[bool] # use sparsity in any case (potentially slower) + + +@register_attention("random", RandomAttentionConfig) +class RandomAttention(Attention): + def __init__( + self, + dropout: float, + causal: bool = False, + r: float = 0.01, + constant_masking: bool = True, + force_sparsity: bool = False, + *args, + **kwargs, + ): + """ + "Random" attention, as proposed for instance in BigBird_. + Random means in that case that each query can attend to a random set of keys. + This implementation is sparse-aware, meaning that the empty attention parts will not be represented in memory. + + Args: + r (float): the ratio in [0,1] of keys that the query can attend to + constant_masking (bool): if true, keep the same random set for all queries. + + .. _BigBird: https://arxiv.org/pdf/2007.14062.pdf + + """ + super().__init__() + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.causal = causal + self.r = r + self.rand_attention_mask: Optional[torch.Tensor] = None + self.constant_masking = constant_masking + self.force_sparsity = force_sparsity + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + self.requires_same_k_q_dimensions = True + + def _get_rand_mask(self, shape: torch.Size) -> torch.Tensor: + sparsity = 1 - self.r + mask = random_pattern(shape[1], sparsity=sparsity) + + if self.causal: + mask &= causal_1d_pattern(shape[1]) + + mask = sparsify(mask) if self.force_sparsity else maybe_sparsify(mask) + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + *args, + **kwargs, + ): + # Rand masking + if not self.constant_masking or self.rand_attention_mask is None: + self.rand_attention_mask = self._get_rand_mask(q.shape).to(q.device) + + # Mask-aware attention + if att_mask is not None: + if att_mask.dtype == torch.bool and isinstance( + self.rand_attention_mask, AttentionMask + ): + mask = self.rand_attention_mask + AttentionMask.from_bool(att_mask) + else: + if isinstance(att_mask, AttentionMask): + # Needed because & op not defined for SparseCS with AttentionMask + att_mask = att_mask.to_bool() + mask = self.rand_attention_mask & att_mask + else: + mask = self.rand_attention_mask + + # Handle q/k/v which would not fit the mask + seq_len = q.shape[-2] + q_, k_, v_ = map(lambda x: self._maybe_pad_sequence(x, mask), (q, k, v)) + + # Normal attention with the random mask + att = scaled_dot_product_attention( + q=q_, k=k_, v=v_, att_mask=mask, dropout=self.attn_drop + ) + + # Take into account an hypothetical padding + return att[:, :seq_len, :] diff --git a/pkgs/xformers/components/attention/scaled_dot_product.py b/pkgs/xformers/components/attention/scaled_dot_product.py new file mode 100644 index 0000000..16fd32a --- /dev/null +++ b/pkgs/xformers/components/attention/scaled_dot_product.py @@ -0,0 +1,134 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
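RandomAttention above delegates the pattern to random_pattern: each query may attend to roughly a fraction r of the keys, optionally intersected with a causal pattern and sparsified when that pays off. The same boolean mask can be sketched in a few lines:

import torch

def random_attention_mask(seq_len: int, r: float, causal: bool = False) -> torch.Tensor:
    # True means "query i may attend to key j"; roughly a fraction r of pairs are kept
    mask = torch.rand(seq_len, seq_len) < r
    if causal:
        mask &= torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask

m = random_attention_mask(16, r=0.25, causal=True)
print(m.float().mean())  # about r / 2 once the causal cut removes the upper triangle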
+ +import logging +from dataclasses import dataclass +from typing import Optional, Union + +import torch +from torch import nn + +from xformers.components.attention import ( + Attention, + AttentionConfig, + AttentionMask, + register_attention, +) +from xformers.components.attention.core import scaled_dot_product_attention + +logger = logging.getLogger("xformers") + + +@dataclass +class ScaledDotProductConfig(AttentionConfig): + causal: Optional[bool] + seq_len: Optional[int] + to_seq_len: Optional[int] + + +@register_attention("scaled_dot_product", ScaledDotProductConfig) +class ScaledDotProduct(Attention): + r""" + Implementing the Scaled Dot-Product attention proposed in + `Attention is all you need`_, Vaswani et al. + + .. _`Attention is all you need`: https://arxiv.org/abs/1706.03762v5 + """ + + mask: Optional[AttentionMask] + + def __init__( + self, + dropout: float = 0.0, + causal: bool = False, + seq_len: Optional[int] = None, + to_seq_len: Optional[int] = None, + *args, + **kwargs, + ): + super().__init__() + + self.attn_drop = nn.Dropout(dropout, inplace=False) + self.causal = causal + self.seq_len = seq_len + + if causal and seq_len is not None: + self.mask = AttentionMask.make_causal(seq_len, to_seq_len) + else: + self.mask = None + + # Properties specific to this attention mechanism + self.supports_attention_mask = True + self.supports_key_padding_mask = False + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_mask: Optional[Union[AttentionMask, torch.Tensor]] = None, + *args, + **kwargs, + ) -> torch.Tensor: + r""" + att_mask A 2D or 3D mask which ignores attention at certain positions. + + - If the mask is boolean, a value of True will keep the value, + while a value of False will mask the value. + + Key padding masks (dimension: batch x sequence length) and attention masks + (dimension: sequence length x sequence length OR batch x sequence length x sequence length) + can be combined and passed in here. Method maybe_merge_masks provided in the utils can be + used for that merging. + + - If the mask has the float type, then an additive mask is expected (masked values are -inf) + + """ + + # Convenience, create an attention mask if a tensor was passed + if att_mask is not None and isinstance(att_mask, torch.Tensor): + # By default we don't know of the causality, and a check would be expensive + att_mask = ( + AttentionMask.from_bool(att_mask) + if att_mask.dtype == torch.bool + else AttentionMask(att_mask, is_causal=False) + ) + + # Handle a possibly deferred causal mask handling + mask = self.mask + if self.causal and self.mask is None: + mask = AttentionMask.make_causal( + seq_len=q.shape[-2], + to_seq_len=q.shape[-2], + device=q.device, + dtype=q.dtype, + ) + + # Merge the optional causal mask and the user-provided mask + if mask is not None: + mask = mask.to(dtype=q.dtype, device=q.device) + + att_mask = att_mask + mask if att_mask is not None else mask + + # Try to handle a case where the sequence is smaller than the mask + if ( + att_mask is not None + and q.shape[-2] == k.shape[-2] + and q.shape[-2] < att_mask.shape[1] + ): + if isinstance(att_mask, AttentionMask): + att_mask = att_mask.make_crop(seq_len=q.shape[-2]) + else: + logger.error( + "Mismatching sparse attention mask and sequence length." 
+ + " Please pad the inputs or adjust the attention mask" + ) + raise NotImplementedError + + # Attend: (B x nh, S, hs) x (B x nh, hs, S) -> (B x nh, S, S) + y = scaled_dot_product_attention( + q=q, k=k, v=v, att_mask=att_mask, dropout=self.attn_drop + ) + return y diff --git a/pkgs/xformers/components/attention/sparsity_config.py b/pkgs/xformers/components/attention/sparsity_config.py new file mode 100644 index 0000000..727a7ff --- /dev/null +++ b/pkgs/xformers/components/attention/sparsity_config.py @@ -0,0 +1,812 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +""" +The code has been adopted from DeepSpeed +(https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/sparse_attention/sparsity_config.py) +""" + +import random + +import torch + + +class SparsityConfig: + """Abstract Configuration class to store `sparsity configuration of a self attention layer`. + It contains shared property of different block-sparse sparsity patterns. However, each class + needs to extend it based on required property and functionality. + """ + + def __init__(self, num_heads, block_size=16, different_layout_per_head=False): + """Initialize the Sparsity Pattern Config. + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be + assigned a different sparsity layout; default is false and this will be satisfied + based on availability. + """ + + self.num_heads = num_heads + self.block_size = block_size + self.different_layout_per_head = different_layout_per_head + self.num_layout_heads = num_heads if different_layout_per_head else 1 + + def setup_layout(self, seq_len): + """Create layout tensor for the given sequence length + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) for sparsity layout + of all head; initialized with zero + """ + + if seq_len % self.block_size != 0: + raise ValueError( + f"Sequence Length, {seq_len}, needs to be dividable by Block size {self.block_size}!" + ) + num_blocks = seq_len // self.block_size + # TODO Currently we allocate layout per head; needs to be updated if heads share a single layout. + layout = torch.zeros( + (self.num_heads, num_blocks, num_blocks), dtype=torch.int64 + ) + return layout + + def check_and_propagate_first_head_layout(self, layout): + """If all heads require same sparsity layout, it propagate first head layout to all heads + Arguments: + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head + """ + + if not self.different_layout_per_head: + layout[1 : self.num_heads, :, :] = layout[0, :, :] + return layout + + +class DenseSparsityConfig(SparsityConfig): + """Configuration class to store `Dense` configuration. + In reality, this is not sparse and all blocks are used. 
We keep it for the sake of comparison and + comprehension. + """ + + def __init__(self, num_heads, block_size=16, different_layout_per_head=False): + """Initialize the Dense Sparsity Pattern Config. + In reality, this is not sparse and all blocks are used. We keep it for the sake of comparison + and comprehension. + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: this is just for the sake of consistency with + other sparsity formats; can ignore it for DenseSparsityConfig + """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + def make_layout(self, seq_len): + """Set 1 to all blocks of the layout meanins the pattern is dense; not sparse. + Arguments: + seq_len: required: an integer determining the underling sequence length; + must be <= max sequence length + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head; for dense everything is 1 + """ + + layout = self.setup_layout(seq_len) + layout[:, :, :] = 1 + return layout + + +class FixedSparsityConfig(SparsityConfig): + """Configuration class to store `Fixed` sparsity configuration. + For more details about this sparsity config, please see `Generative Modeling with + Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. + This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. + """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_local_blocks=4, + num_global_blocks=1, + attention="bidirectional", + horizontal_global_attention=False, + num_different_global_patterns=1, + ): + """Initialize `Fixed` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be + assigned a different sparsity layout; default is false and this will be satisfied + based on availability. + num_local_blocks: optional: an integer determining the number of blocks in local attention + window. + num_global_blocks: optional: an integer determining how many consecutive blocks in a local + window is used as the representative of the window for global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + horizontal_global_attention: optional: a boolean determining if blocks that are global + representative of a local window, also attend to all other blocks. 
This is valid only if + attention type is `bidirectional`. Looking at the attention matrix, that means global + attention not only includes the vertical blocks, but also horizontal blocks. + num_different_global_patterns: optional: an integer determining number of different global + attentions layouts. While global attention can be fixed by which block/s are representative + of any local window, since there are multi-heads, each head can use a different global representative. + For example, with 4 blocks local window and global attention size of 1 block, we can have 4 different + versions in which the first, Second, third, or forth block of each local window can be global + representative of that window. This parameter determines how many of such patterns we want. + Of course, there is a limitation based on num_local_blocks and num_global_blocks. + """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_local_blocks = num_local_blocks + + if num_local_blocks % num_global_blocks != 0: + raise ValueError( + f"""Number of blocks in a local window, {num_local_blocks}, + must be dividable by number of global blocks, {num_global_blocks}!""" + ) + self.num_global_blocks = num_global_blocks + + if attention != "unidirectional" and attention != "bidirectional": + raise NotImplementedError( + 'only "uni/bi-directional" attentions are supported for now!' + ) + self.attention = attention + + if attention != "bidirectional" and horizontal_global_attention: + raise ValueError( + 'only "bi-directional" attentions can support horizontal global attention!' + ) + self.horizontal_global_attention = horizontal_global_attention + + if num_different_global_patterns > 1 and not different_layout_per_head: + raise ValueError( + """Number of different layouts cannot be more than one when you have set a single layout + for all heads! Set different_layout_per_head to True.""" + ) + if num_different_global_patterns > (num_local_blocks // num_global_blocks): + raise ValueError( + f"""Number of layout versions (num_different_global_patterns), {num_different_global_patterns}, + cannot be larger than number of local window blocks divided by number of global blocks, + {num_local_blocks} / {num_global_blocks} = {num_local_blocks//num_global_blocks}!""" + ) + self.num_different_global_patterns = num_different_global_patterns + + def set_local_layout(self, h, layout): + """Sets local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local layout is set + """ + + num_blocks = layout.shape[1] + for i in range(0, num_blocks, self.num_local_blocks): + end = min(i + self.num_local_blocks, num_blocks) + for row in range(i, end): + for col in range( + i, (row + 1 if self.attention == "unidirectional" else end) + ): + layout[h, row, col] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Currently we set global blocks starting from the last block of a local window to the first one. + That means if a local window consists of 4 blocks and global attention size is one block, we use + block #4 in each local window as global. 
If we have different layout per head, then other heads + will get #3, #2, and #1. And if we have more heads (and different layout has set) than num of global + attentions, multiple head may have same global attentions. + Note) if horizontal_global_attention is set, global blocks will be set both horizontally and + vertically. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + first_global_block_idx = ( + self.num_local_blocks + - (1 + h % self.num_different_global_patterns) * self.num_global_blocks + ) + + # set all global blocks except the last one if (in last local window) + end = num_blocks - (num_blocks % self.num_local_blocks) + for i in range(first_global_block_idx, end, self.num_local_blocks): + + # vertical global attention + first_row = 0 if self.attention == "bidirectional" else i + # (((i // self.num_local_blocks) + 1) * self.num_local_blocks) + # if (first_row < num_blocks): + layout[h, first_row:, i : i + self.num_global_blocks] = 1 + + # horizontal global attention; only in bidirectional attention + if self.horizontal_global_attention: + layout[h, i : i + self.num_global_blocks, :] = 1 + + # set last global blocks; handle possible short last local window + if end < num_blocks: + start = min( + end + first_global_block_idx, num_blocks - self.num_global_blocks + ) + end = start + self.num_global_blocks + + # vertical global attention + first_row = 0 if self.attention == "bidirectional" else start + # (((start // self.num_local_blocks) + 1) * self.num_local_blocks) + # if (first_row < num_blocks): + layout[h, first_row:, start:end] = 1 + + # horizontal global attention + if self.horizontal_global_attention: + layout[h, start:end, :] = 1 + return layout + + def make_layout(self, seq_len): + """Generates `Fixed` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Fixed` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_local_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class VariableSparsityConfig(SparsityConfig): + """Configuration class to store `Variable` sparsity configuration. + This layout is an extension of FixedSparsityConfig in which: + - user can set random layout; default value is zero means no random block + - user can provide a list of local block sizes + - user can provide a list of global block indices. + For more details about `Fixed` sparsity config, please see `Generative Modeling with + Sparse Transformers`: https://arxiv.org/abs/1904.10509; this has been customized. + This class extends parent class of `SparsityConfig` and customizes it for `Fixed` sparsity. 
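
To make the `Fixed` pattern above concrete, here is a small sketch (an illustration, not part of the vendored file; all values are arbitrary) that builds a layout for a 128-token sequence.

    from xformers.components.attention.sparsity_config import FixedSparsityConfig

    config = FixedSparsityConfig(
        num_heads=4,
        block_size=16,          # 128 tokens -> 8 block rows/columns
        num_local_blocks=4,     # each local window spans 4 blocks
        num_global_blocks=1,    # the last block of each window is global
        attention="bidirectional",
    )
    layout = config.make_layout(seq_len=128)
    print(layout.shape)   # torch.Size([4, 8, 8]), entries are 0/1
    print(layout[0])      # dense 4x4 local windows plus global columns at blocks 3 and 7
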
+ """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_random_blocks=0, + local_window_blocks=[4], + global_block_indices=[0], + global_block_end_indices=None, + attention="bidirectional", + horizontal_global_attention=False, + ): + """Initialize `Variable` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of sparse + self-attention is based on blocked sparse matrices. In which this parameter defines + size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a + different sparsity layout; default is false and this will be satisfied based on + availability. Currently this sparsity config can only assign single layout to all heads; + needs to be extended for different layout per head. + num_random_blocks: optional: an integer determining the number of random blocks in each block row. + local_window_blocks: optional: a list of integers determining the number of blocks in each + local attention window. It assumes first number determines # of blocks in the first local + window, second the second window, ..., and the last number determines the number of blocks + in the remaining local windows. + global_block_indices: optional: a list of integers determining which blocks are considered + as global attention. Given indices, determine the blocks that all other token blocks + attend to and they attend to all other token blocks. Default value is only index 0. + Notice that if global_block_end_indices parameter is set, this parameter is used as + starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global + window blocks. By default this is not used. But if it is set, it must have the same size + of global_block_indices parameter, and combining this two parameters, for each index i, + blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are + considered as global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + horizontal_global_attention: optional: a boolean determining if blocks that are global + representative of a local window, also attend to all other blocks. This is valid only if + attention type is `bidirectional`. Looking at the attention matrix, that means global + attention not only includes the vertical blocks, but also horizontal blocks. 
+ """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_random_blocks = num_random_blocks + self.local_window_blocks = local_window_blocks + self.global_block_indices = global_block_indices + + if global_block_end_indices is not None: + if len(global_block_indices) != len(global_block_end_indices): + raise ValueError( + f"""Global block start indices length, {len(global_block_indices)}, must be same as + global block end indices length, {len(global_block_end_indices)}!""" + ) + for _, (start_idx, end_idx) in enumerate( + zip(global_block_indices, global_block_end_indices) + ): + if start_idx >= end_idx: + raise ValueError( + f"""Global block start index, {start_idx}, must be smaller than global block end + index, {end_idx}!""" + ) + self.global_block_end_indices = global_block_end_indices + + if attention != "unidirectional" and attention != "bidirectional": + raise NotImplementedError( + 'only "uni/bi-directional" attentions are supported for now!' + ) + self.attention = attention + + if attention != "bidirectional" and horizontal_global_attention: + raise ValueError( + 'only "bi-directional" attentions can support horizontal global attention!' + ) + self.horizontal_global_attention = horizontal_global_attention + + def set_random_layout(self, h, layout): + """Sets random attention layout used by the given head in the sparse attention. + Note) By default, it assumes there will be a unique random block layout for all heads; unless + `different_layout_per_head` parameter is set in which each head can have a different random + layout. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which random layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_random_blocks: + raise ValueError( + f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" + ) + for row in range(0, num_blocks): + rnd_cols = random.sample(range(0, num_blocks), self.num_random_blocks) + layout[h, row, rnd_cols] = 1 + return layout + + def set_local_layout(self, h, layout): + """Sets local attention layout used by the given head in the sparse attention. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local layout is set + """ + + num_blocks = layout.shape[1] + start_block_idx = 0 + end_block_idx = 0 + for block_size in self.local_window_blocks: + end_block_idx += block_size + end_block_idx = min(end_block_idx, num_blocks) + for row in range(start_block_idx, end_block_idx): + for col in range( + start_block_idx, + (row + 1 if self.attention == "unidirectional" else end_block_idx), + ): + layout[h, row, col] = 1 + start_block_idx += block_size + + # if there is any remaining not attended part, use the lats local window block size as local + # window for the remaining applicable local windows + for i in range(start_block_idx, num_blocks, block_size): + end_block_idx = min(i + block_size, num_blocks) + for row in range(i, end_block_idx): + for col in range( + i, + (row + 1 if self.attention == "unidirectional" else end_block_idx), + ): + layout[h, row, col] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if self.global_block_end_indices is None: + for idx in self.global_block_indices: + # if global block idx is in the range of the sequence blocks + if idx < num_blocks: + # global rows + if self.horizontal_global_attention: + layout[h, idx, :] = 1 + + # global columns + first_row = 0 if self.attention == "bidirectional" else idx + layout[h, first_row:, idx] = 1 + else: + for _, (start_idx, end_idx) in enumerate( + zip(self.global_block_indices, self.global_block_end_indices) + ): + # if global block idx is in the range of the sequence blocks + if start_idx < num_blocks: + end_idx = min(end_idx, num_blocks) + # global rows + if self.horizontal_global_attention: + layout[h, start_idx:end_idx, :] = 1 + + # global columns + first_row = 0 if self.attention == "bidirectional" else start_idx + layout[h, first_row:, start_idx:end_idx] = 1 + return layout + + def make_layout(self, seq_len): + """Generates `Variable` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `Variable` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_random_layout(h, layout) + layout = self.set_local_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class BigBirdSparsityConfig(SparsityConfig): + """Configuration class to store `BigBird` sparsity configuration. 
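
Analogously, a brief sketch (illustrative values only, not part of the vendored file) of the `Variable` pattern above, mixing random, per-window local, and index-based global blocks:

    from xformers.components.attention.sparsity_config import VariableSparsityConfig

    config = VariableSparsityConfig(
        num_heads=4,
        block_size=16,
        num_random_blocks=1,                # one random block per block-row
        local_window_blocks=[2, 4],         # first window is 2 blocks wide, later ones 4
        global_block_indices=[0],           # block 0 acts as the global block
        attention="bidirectional",
        horizontal_global_attention=True,   # block 0 also attends to every other block
    )
    layout = config.make_layout(seq_len=256)  # 256 / 16 = 16 block rows/columns
    print(layout.shape)                       # torch.Size([4, 16, 16])
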
+ For more details about this sparsity config, please see `Big Bird: Transformers for + Longer Sequences`: https://arxiv.org/pdf/2007.14062.pdf + This class extends parent class of `SparsityConfig` and customizes it for `BigBird` sparsity. + """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_random_blocks=1, + num_sliding_window_blocks=3, + num_global_blocks=1, + attention="bidirectional", + ): + """Initialize the BigBird Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of + sparse self-attention is based on blocked sparse matrices. In which this parameter + defines size of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned + a different sparsity layout; default is false and this will be satisfied based on + availability. + num_random_blocks: optional: an integer determining the number of random blocks in each + block row. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding + local attention window. + num_global_blocks: optional: an integer determining how many consecutive blocks, starting + from index 0, are considered as global attention. Global block tokens will be attended + by all other block tokens and will attend to all other block tokens as well. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. + """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_random_blocks = num_random_blocks + self.num_sliding_window_blocks = num_sliding_window_blocks + self.num_global_blocks = num_global_blocks + + if attention != "unidirectional" and attention != "bidirectional": + raise NotImplementedError( + 'only "uni/bi-directional" attentions are supported for now!' + ) + self.attention = attention + + def set_random_layout(self, h, layout): + """Sets random attention layout used by the given head in the sparse attention. + Note) By default, it assumes there will be a unique random block layout for all heads; unless + `different_layout_per_head` parameter is set in which each head can have a different random layout. 
+ Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which random layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_random_blocks: + raise ValueError( + f"""Number of random blocks, {self.num_random_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" + ) + + for row in range(0, num_blocks): + sample_range = ( + range(0, num_blocks) + if self.attention == "bidirectional" + else range(0, row + 1) + ) + rnd_cols = random.sample(sample_range, self.num_random_blocks) + layout[h, row, rnd_cols] = 1 + return layout + + def set_sliding_window_layout(self, h, layout): + """Sets sliding local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which local sliding window layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_sliding_window_blocks: + raise ValueError( + f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller than + overall number of blocks in a row, {num_blocks}!""" + ) + + w = self.num_sliding_window_blocks // 2 + for row in range(0, num_blocks): + start = max(0, row - w) + end = min(row + w + 1, num_blocks) + layout[h, row, start:end] = 1 + return layout + + def set_global_layout_itc(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout + of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_global_blocks: + raise ValueError( + f"""Number of global blocks, {self.num_global_blocks}, must be smaller than overall number + of blocks in a row, {num_blocks}!""" + ) + + # global rows + layout[h, 0 : self.num_global_blocks, :] = 1 + + # global columns + layout[h, :, 0 : self.num_global_blocks] = 1 + + if self.attention == "unidirectional": + # zero out anything attending to the future + layout = torch.tril(layout) + + return layout + + def make_layout(self, seq_len): + """Generates `BigBird` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. 
+ Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BigBird` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_random_layout(h, layout) + layout = self.set_sliding_window_layout(h, layout) + layout = self.set_global_layout_itc(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout + + +class BSLongformerSparsityConfig(SparsityConfig): + """Configuration class to store edited `Longformer` sparsity configuration. + Note) this is a block-sparse version of the Longformer which is slightly different than original + Longformer; which is element-wise sparsity. + For more details about this sparsity config, please see `Longformer: + The Long-Document Transformer`: https://arxiv.org/pdf/2004.05150.pdf + This class extends parent class of `SparsityConfig` and customizes it for `Longformer` sparsity. + """ + + def __init__( + self, + num_heads, + block_size=16, + different_layout_per_head=False, + num_sliding_window_blocks=3, + global_block_indices=[0], + global_block_end_indices=None, + attention="bidirectional", + ): + """Initialize the edited `Longformer` Sparsity Pattern Config. + For usage example please see, TODO DeepSpeed Sparse Transformer Tutorial + Arguments: + num_heads: required: an integer determining number of attention heads of the layer. + block_size: optional: an integer determining the block size. Current implementation of sparse + self-attention is based on blocked sparse matrices. In which this parameter defines size + of such blocks, `Block X Block`. + different_layout_per_head: optional: a boolean determining if each head should be assigned a + different sparsity layout; default is false and this will be satisfied based on + availability. + num_sliding_window_blocks: optional: an integer determining the number of blocks in sliding + local attention window. + global_block_indices: optional: a list of integers determining which blocks are considered + as global attention. Given indices, determine the blocks that all other token blocks + attend to and they attend to all other token blocks. Default value is only index 0. + Notice that if global_block_end_indices parameter is set, this parameter is used as + starting index of each global window. + global_block_end_indices: optional: a list of integers determining end indices of global + window blocks. By default this is not used. But if it is set, it must have the same size + of global_block_indices parameter, and combining this two parameters, for each index i, + blocks from global_block_indices[i] to global_block_end_indices[i] (exclusive) are + considered as global attention. + attention: optional: a string determining attention type. Attention can be `unidirectional`, + such as autoregressive models, in which tokens attend only to tokens appear before them + in the context. Considering that, the upper triangular of attention matrix is empty as + above figure. Or it can be `bidirectional`, such as BERT, in which tokens can attend to + any other tokens before or after them. Then, the upper triangular part of the attention + matrix is mirror of the lower triangular in the above figure. 
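
A matching sketch (illustrative values only, not part of the vendored file) for the `BigBird` pattern above, which combines random, sliding-window, and global blocks:

    from xformers.components.attention.sparsity_config import BigBirdSparsityConfig

    config = BigBirdSparsityConfig(
        num_heads=8,
        block_size=16,
        num_random_blocks=1,            # one random block per block-row
        num_sliding_window_blocks=3,    # each block sees itself and its two neighbours
        num_global_blocks=1,            # block 0 attends to and is attended by all blocks
        attention="bidirectional",
    )
    layout = config.make_layout(seq_len=512)  # 512 / 16 = 32 block rows/columns
    print(layout.shape)                       # torch.Size([8, 32, 32])
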
+ """ + + super().__init__(num_heads, block_size, different_layout_per_head) + + self.num_sliding_window_blocks = num_sliding_window_blocks + self.global_block_indices = global_block_indices + self.attention = attention + + if global_block_end_indices is not None: + if len(global_block_indices) != len(global_block_end_indices): + raise ValueError( + f"""Global block start indices length, {len(global_block_indices)}, must be same as + global block end indices length, {len(global_block_end_indices)}!""" + ) + for _, (start_idx, end_idx) in enumerate( + zip(global_block_indices, global_block_end_indices) + ): + if start_idx >= end_idx: + raise ValueError( + f"""Global block start index, {start_idx}, must be smaller than global block end + index, {end_idx}!""" + ) + self.global_block_end_indices = global_block_end_indices + + def set_sliding_window_layout(self, h, layout): + """Sets sliding local attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity layout + of all head in which local sliding window layout is set + """ + + num_blocks = layout.shape[1] + if num_blocks < self.num_sliding_window_blocks: + raise ValueError( + f"""Number of sliding window blocks, {self.num_sliding_window_blocks}, must be smaller + than overall number of blocks in a row, {num_blocks}!""" + ) + + w = self.num_sliding_window_blocks // 2 + for row in range(0, num_blocks): + start = max(0, row - w) + end = min(row + w + 1, num_blocks) + layout[h, row, start:end] = 1 + return layout + + def set_global_layout(self, h, layout): + """Sets global attention layout used by the given head in the sparse attention. + Arguments: + h: required: an integer determining head index + layout: required: a tensor of dimension (num_heads, num_blocks, num_blocks) containing + sparsity layout of all head; may not be completely set at this step + Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing sparsity + layout of all head in which global layout is set + """ + + num_blocks = layout.shape[1] + if self.global_block_end_indices is None: + for idx in self.global_block_indices: + # if global block idx is in the range of the sequence blocks + if idx < num_blocks: + # global rows + layout[h, idx, :] = 1 + + # global columns + layout[h, :, idx] = 1 + else: + for _, (start_idx, end_idx) in enumerate( + zip(self.global_block_indices, self.global_block_end_indices) + ): + # if global block idx is in the range of the sequence blocks + if start_idx < num_blocks: + end_idx = min(end_idx, num_blocks) + # global rows + layout[h, start_idx:end_idx, :] = 1 + + # global columns + layout[h, :, start_idx:end_idx] = 1 + if self.attention == "unidirectional": + layout = torch.tril(layout) + return layout + + def make_layout(self, seq_len): + """Generates edited `Longformer` sparsity layout used by each head in the sparse attention. + Arguments: + seq_len: required: an integer determining number of attention heads of the layer. 
+ Return: + layout: a tensor of dimension (num_heads, num_blocks, num_blocks) containing `BSLongformer` + sparsity layout of all head + """ + + layout = self.setup_layout(seq_len) + for h in range(0, self.num_layout_heads): + layout = self.set_sliding_window_layout(h, layout) + layout = self.set_global_layout(h, layout) + + layout = self.check_and_propagate_first_head_layout(layout) + return layout diff --git a/pkgs/xformers/components/attention/utils.py b/pkgs/xformers/components/attention/utils.py new file mode 100644 index 0000000..d6bb06a --- /dev/null +++ b/pkgs/xformers/components/attention/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional + +import torch + + +# Reshapes key padding mask from (batch_size, src_len) -> (batch_size * num_heads 1, src_len) +def reshape_key_padding_mask( + key_padding_mask: torch.Tensor, batched_dim: int +) -> torch.Tensor: + assert key_padding_mask.ndim == 2 + batch_size, src_len = key_padding_mask.size() + num_heads = batched_dim // batch_size + return _reshape_key_padding_mask(key_padding_mask, batch_size, src_len, num_heads) + + +def _reshape_key_padding_mask( + key_padding_mask: torch.Tensor, batch_size: int, src_len: int, num_heads: int +) -> torch.Tensor: + assert key_padding_mask.shape == (batch_size, src_len) + key_padding_mask = ( + key_padding_mask.view(batch_size, 1, 1, src_len) + .expand(-1, num_heads, -1, -1) + .reshape(batch_size * num_heads, 1, src_len) + ) + return key_padding_mask + + +# Combine the attention mask and key padding mask into a single mask +# Taken from https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py +# Additive masking not yet supported +def maybe_merge_masks( + att_mask: Optional[torch.Tensor], + key_padding_mask: Optional[torch.Tensor], + batch_size: int, + src_len: int, + num_heads: int, + tgt_len: Optional[int] = None, +) -> Optional[torch.Tensor]: + if tgt_len is None: + tgt_len = src_len + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, src_len) + key_padding_mask = _reshape_key_padding_mask( + key_padding_mask, batch_size, src_len, num_heads + ) + if att_mask is None: + # make sure dimensions of key padding mask are the same as those expected for att_mask + att_mask = key_padding_mask.expand(-1, tgt_len, -1) + # Assumption is that False means to mask. + elif att_mask.dtype == torch.bool: + att_mask = att_mask.logical_and(key_padding_mask) + else: + att_mask = att_mask.masked_fill(~key_padding_mask, float("-inf")) + + return att_mask + + +# Assumes that matrix passed in has had softmax applied to it. +def iterative_pinv(softmax_mat: torch.Tensor, n_iter=6, pinverse_original_init=False): + """ + Computing the Moore-Penrose inverse. + Use an iterative method from (Razavi et al. 2014) to approximate the Moore-Penrose inverse via efficient + matrix-matrix multiplications. + """ + + i = torch.eye( + softmax_mat.size(-1), device=softmax_mat.device, dtype=softmax_mat.dtype + ) + k = softmax_mat + + # The entries of K are positive and ||K||_{\infty} = 1 due to softmax + if pinverse_original_init: + # This original implementation is more conservative to compute coefficient of Z_0. 
+ v = 1 / torch.max(torch.sum(k, dim=-2)) * k.transpose(-1, -2) + else: + # This is the exact coefficient computation, 1 / ||K||_1, of initialization of Z_0, leading to faster + # convergence. + v = ( + 1 + / torch.max(torch.sum(k, dim=-2), dim=-1).values[:, None, None] + * k.transpose(-1, -2) + ) + + for _ in range(n_iter): + kv = torch.matmul(k, v) + v = torch.matmul( + 0.25 * v, + 13 * i - torch.matmul(kv, 15 * i - torch.matmul(kv, 7 * i - kv)), + ) + return v + + +def bool_mask_to_additive( + mask: torch.Tensor, dtype: Optional[torch.dtype] = torch.float32 +) -> torch.Tensor: + assert ( + mask.dtype == torch.bool + ), "This util is meant to convert in between bool masks and additive ones" + + mask_ = torch.zeros_like(mask, dtype=dtype) + mask_[~mask] = float("-inf") + return mask_ diff --git a/pkgs/xformers/components/attention/visual.py b/pkgs/xformers/components/attention/visual.py new file mode 100644 index 0000000..6ea81f4 --- /dev/null +++ b/pkgs/xformers/components/attention/visual.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from xformers.components.attention import Attention, AttentionConfig, register_attention + + +@dataclass +class VisualAttentionConfig(AttentionConfig): + dim_model: int # dimension of the input sequence + + +class LKA(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + self.conv_spatial = nn.Conv2d( + dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3 + ) + self.conv1 = nn.Conv2d(dim, dim, 1) + + def forward(self, x: torch.Tensor): + u = x.clone() + attn = self.conv0(x) + attn = self.conv_spatial(attn) + attn = self.conv1(attn) + + return u * attn + + +@register_attention("visual", VisualAttentionConfig) +class Visual(Attention): + def __init__( + self, + dim_model: int, + *_, + **__, + ): + """ + Large kernel attention mechanism, as proposed in `Visual Attention Network`_, Guo et al (2022). + The original notation is tentatively kept as is. See https://github.com/Visual-Attention-Network + for the reference implementation + + .. Note: compared to the paper, this block contains the LKA (Large Kernel Attention) + and the prior and posterior transformations (Conv2d and activation) + + .. 
_`Visual Attention Network` : https://arxiv.org/pdf/2202.09741.pdf + """ + super().__init__() + + self.block = nn.Sequential( + nn.Conv2d(dim_model, dim_model, 1), + nn.GELU(), + LKA(dim_model), + nn.Conv2d(dim_model, dim_model, 1), + ) + + # MHA related flags: + self.requires_same_k_q_dimensions = ( + True # This mechanism only really supports self attention + ) + self.supports_attention_mask = False + self.requires_skip_multi_head = ( + True # This mechanism skips the multihead attention altogether + ) + self.requires_squared_context = ( + True # Recovering the 2D structure from context assumes squared content + ) + + self.requires_input_projection = ( + False # This mechanism does not require that the MHA projects inputs + ) + + def forward(self, q: torch.Tensor, *_, **__): + # Expose the 2D token structure + B, HW, C = q.shape + H = int(math.sqrt(HW)) + assert H * H == HW + + x = q.transpose(-2, -1).reshape(B, C, H, H) + + # Large kernel attention + residual = x.clone() + x = self.block(x) + x = x + residual + + # Get back to B HW C + return x.flatten(2, 3).transpose(-2, -1) diff --git a/pkgs/xformers/components/feedforward/__init__.py b/pkgs/xformers/components/feedforward/__init__.py new file mode 100644 index 0000000..7d6e061 --- /dev/null +++ b/pkgs/xformers/components/feedforward/__init__.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from pathlib import Path +from typing import Any, Callable, Dict, Set, Union + +from xformers.utils import ( + generate_matching_config, + get_registry_decorator, + import_all_modules, +) + +from .base import Feedforward, FeedforwardConfig # noqa + +# CREDITS: Classy Vision registry mechanism + +FEEDFORWARD_REGISTRY: Dict[str, Any] = {} +FEEDFORWARD_CLASS_NAMES: Set[str] = set() + + +def build_feedforward(config: Union[Dict[str, Any], FeedforwardConfig]): + """Builds a feedforward from a config. + + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. For instance, a config `{"name": "my_feedforward", + "foo": "bar"}` will find a class that was registered as "my_feedforward" + (see :func:`register_feedforward`) and call .from_config on it.""" + + if not isinstance(config, FeedforwardConfig): + config_instance = generate_matching_config( + config, FEEDFORWARD_REGISTRY[config["name"]].config + ) + else: + config_instance = config + + return FEEDFORWARD_REGISTRY[config_instance.name].constructor.from_config( + config_instance + ) + + +"""Registers a Feedforward subclass. + + This decorator allows xFormers to instantiate a subclass of Feedforward + from a configuration file, even if the class itself is not part of the + xFormers framework. To use it, apply this decorator to a Feedforward + subclass, like this: + + .. code-block:: python + + @dataclass + class MyConfig: + ... + + @register_feedforward('my_ff', MyConfig) + class MyFeedforward(Feedforward): + ... 
+ + To instantiate a feedforward from a configuration file, see :func:`build_feedforward`.""" +register_feedforward: Callable[ + [str, Any], Callable[[Any], Any] +] = get_registry_decorator( + FEEDFORWARD_REGISTRY, FEEDFORWARD_CLASS_NAMES, Feedforward, FeedforwardConfig +) + +try: + from .fused_mlp import FusedMLP # noqa + + _fused_mlp_available = True +except ImportError: + _fused_mlp_available = False +from .mlp import MLP # noqa + +__all__ = [ + "MLP", + "Feedforward", + "build_feedforward", + "register_feedforward", +] + +if _fused_mlp_available: + __all__ += ["FusedMLP"] + +# automatically import any Python files in the directory +import_all_modules(str(Path(__file__).parent), "xformers.components.feedforward") diff --git a/pkgs/xformers/components/feedforward/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/components/feedforward/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..1a55d17 Binary files /dev/null and b/pkgs/xformers/components/feedforward/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/components/feedforward/__pycache__/base.cpython-310.pyc b/pkgs/xformers/components/feedforward/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..dc9c071 Binary files /dev/null and b/pkgs/xformers/components/feedforward/__pycache__/base.cpython-310.pyc differ diff --git a/pkgs/xformers/components/feedforward/__pycache__/conv_mlp.cpython-310.pyc b/pkgs/xformers/components/feedforward/__pycache__/conv_mlp.cpython-310.pyc new file mode 100644 index 0000000..1d20cdb Binary files /dev/null and b/pkgs/xformers/components/feedforward/__pycache__/conv_mlp.cpython-310.pyc differ diff --git a/pkgs/xformers/components/feedforward/__pycache__/fused_mlp.cpython-310.pyc b/pkgs/xformers/components/feedforward/__pycache__/fused_mlp.cpython-310.pyc new file mode 100644 index 0000000..0a1d887 Binary files /dev/null and b/pkgs/xformers/components/feedforward/__pycache__/fused_mlp.cpython-310.pyc differ diff --git a/pkgs/xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-310.pyc b/pkgs/xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-310.pyc new file mode 100644 index 0000000..8753b3a Binary files /dev/null and b/pkgs/xformers/components/feedforward/__pycache__/mixture_of_experts.cpython-310.pyc differ diff --git a/pkgs/xformers/components/feedforward/__pycache__/mlp.cpython-310.pyc b/pkgs/xformers/components/feedforward/__pycache__/mlp.cpython-310.pyc new file mode 100644 index 0000000..95cbfb2 Binary files /dev/null and b/pkgs/xformers/components/feedforward/__pycache__/mlp.cpython-310.pyc differ diff --git a/pkgs/xformers/components/feedforward/base.py b/pkgs/xformers/components/feedforward/base.py new file mode 100644 index 0000000..3c483ce --- /dev/null +++ b/pkgs/xformers/components/feedforward/base.py @@ -0,0 +1,53 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
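
The registry above can be exercised end to end with the stock MLP registered later in this diff. A minimal sketch (arbitrary sizes, not part of the vendored file), assuming the pkgs/xformers tree resolves as the `xformers` package:

    import torch

    from xformers.components import Activation
    from xformers.components.feedforward import build_feedforward

    # "MLP" is looked up in FEEDFORWARD_REGISTRY; config fields left out here (e.g. bias)
    # fall back to the constructor defaults via Feedforward.from_config.
    ff = build_feedforward({
        "name": "MLP",
        "dim_model": 256,
        "dropout": 0.1,
        "activation": Activation.GeLU,
        "hidden_layer_multiplier": 4,
    })

    x = torch.randn(2, 16, 256)   # (batch, seq_len, dim_model)
    print(ff(x).shape)            # torch.Size([2, 16, 256])
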
+ + +from abc import ABCMeta, abstractmethod +from dataclasses import asdict, dataclass +from typing import Optional, Type, TypeVar + +import torch.nn as nn + +from xformers.components import Activation + +Self = TypeVar("Self", bound="Feedforward") + + +@dataclass +class FeedforwardConfig: + name: str + dim_model: int + dropout: float + activation: Activation + + +# Define the common interface, every feedforward block needs to derive from it +class Feedforward(nn.Module, metaclass=ABCMeta): + @abstractmethod + def __init__( + self, + dim_model: Optional[int] = None, + dropout: Optional[float] = None, + activation: Optional[Activation] = None, + *args, + **kwargs, + ): + super().__init__() + + # This feedforward requires a CUDA accelerator + self.requires_cuda = False + + # This feedforward requires a context length which is squared, often due to 2D pooling + self.requires_squared_context = False + + @classmethod + def from_config(cls: Type[Self], config: FeedforwardConfig) -> Self: + # Generate the class inputs from the config + fields = asdict(config) + + # Skip all Nones so that default values are used + fields = {k: v for k, v in fields.items() if v is not None} + + return cls(**fields) diff --git a/pkgs/xformers/components/feedforward/conv_mlp.py b/pkgs/xformers/components/feedforward/conv_mlp.py new file mode 100644 index 0000000..8952119 --- /dev/null +++ b/pkgs/xformers/components/feedforward/conv_mlp.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# CREDITS: Largely reusing the code from the reference VAN implementation +# see https://github.com/Visual-Attention-Network + +import math +from dataclasses import dataclass +from typing import Optional + +import torch.nn as nn + +from xformers.components import Activation, build_activation +from xformers.components.feedforward import Feedforward, FeedforwardConfig + +from . import register_feedforward + + +@dataclass +class ConvMlpConfig(FeedforwardConfig): + hidden_layer_multiplier: int + dim_model: int + dim_model_out: Optional[int] + act_layer: Activation + dropout: float + + +@register_feedforward("Conv2DFeedforward", ConvMlpConfig) +class Conv2DFeedforward(Feedforward): + """ + A Convolutional feed-forward network, as proposed in VAN_ (Vision Attention Network, Guo et al.) + + .. 
_VAN: https://arxiv.org/pdf/2202.09741.pdf + """ + + def __init__( + self, + dim_model: int, + hidden_layer_multiplier: int = 1, + dim_model_out: Optional[int] = None, + activation: Activation = Activation.GeLU, + dropout=0.0, + *args, + **kwargs, + ): + super().__init__() + out_features = dim_model_out or dim_model + hidden_features = hidden_layer_multiplier * dim_model + + self.conv_mlp = nn.Sequential( + nn.Conv2d(dim_model, hidden_features, 1), + nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features, + ), + build_activation(activation), + nn.Conv2d(hidden_features, out_features, 1), + nn.Dropout(dropout), + ) + + # This feedforward requires a context length which is squared, often due to 2D pooling + self.requires_squared_context = True + + def init_weights(self, **kwargs): + # Follow the original init, but also make it possible to initialize from the outside + def init_module(m: nn.Module): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + + self.apply(init_module) + + def forward(self, x): + # The conv layers expect NCHW, we have NLC by default + B, L, C = x.shape + HW = int(math.sqrt(x.shape[-2])) + assert HW**2 == L, "Conv2DFeedforward requires squared context lengths" + + x = x.reshape((B, HW, HW, C)).swapdims(1, -1) + + # The actual FW, including the 2d convolutions + x = self.conv_mlp(x) + + # back to NLC + x = x.transpose(1, -1) + return x.flatten(1, 2) diff --git a/pkgs/xformers/components/feedforward/fused_mlp.py b/pkgs/xformers/components/feedforward/fused_mlp.py new file mode 100644 index 0000000..35b265f --- /dev/null +++ b/pkgs/xformers/components/feedforward/fused_mlp.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from xformers.components import Activation +from xformers.components.feedforward import ( + Feedforward, + FeedforwardConfig, + register_feedforward, +) + +logger = logging.getLogger("xformers") + + +if torch.cuda.is_available(): + try: + from xformers.triton import FusedDropoutBias + + @dataclass + class FusedMlpConfig(FeedforwardConfig): + hidden_layer_multiplier: int + + @register_feedforward("FusedMLP", FusedMlpConfig) + class FusedMLP(Feedforward): + """ + A MLP using fused linear layers. + """ + + def __init__( + self, + dim_model: int, + dropout: float, + activation: Activation, + hidden_layer_multiplier: int, + bias: bool = True, + *args, + **kwargs, + ): + super().__init__() + + dim_mlp = hidden_layer_multiplier * dim_model + + self.mlp = nn.Sequential( + nn.Linear( + in_features=dim_model, out_features=dim_mlp, bias=False + ), # bias is handled in the next layer + # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize + # the `FusedLinear` import. + FusedDropoutBias( + p=dropout, + bias_shape=dim_mlp if bias else None, + activation=activation, + ), + nn.Linear( + in_features=dim_mlp, out_features=dim_model, bias=False + ), # bias is handled in the next layer + # pyre-ignore[16]: TODO(T101400990): Pyre did not recognize + # the `FusedLinear` import. 
+ FusedDropoutBias( + p=dropout, + bias_shape=dim_model if bias else None, + activation=None, + ), + ) + self.requires_cuda = True + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.mlp(inputs) + + except ImportError: + logger.warning("Triton is not available, FusedMLP will not be enabled.") diff --git a/pkgs/xformers/components/feedforward/mixture_of_experts.py b/pkgs/xformers/components/feedforward/mixture_of_experts.py new file mode 100644 index 0000000..b6ab184 --- /dev/null +++ b/pkgs/xformers/components/feedforward/mixture_of_experts.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, Optional, Union + +import torch + +from xformers.components import Activation +from xformers.components.feedforward import ( + Feedforward, + FeedforwardConfig, + register_feedforward, +) + +logger = logging.getLogger("xformers") + + +_is_fairscale_available = True + +try: + import torch.distributed as dist + from fairscale.nn import MOELayer, Top2Gate # type: ignore + + from xformers.components.feedforward import MLP + +except ImportError: + logger.warning( + "Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed." + " Please install them if you would like to use MoE" + ) + _is_fairscale_available = False + + +if _is_fairscale_available: + + # Credits: initially implemented in FairScale for sanity checking + class RoundRobinGate(torch.nn.Module): + def __init__(self, model_dim, num_experts): + super().__init__() + self.model_dim = model_dim + self.num_experts = num_experts + + def forward(self, input): + s = input.shape[0] + assert s % self.num_experts == 0, f"{s} % {self.num_experts} != 0" + capacity = 2 * s // self.num_experts + output = torch.zeros( + s, self.num_experts, capacity, dtype=input.dtype, device=input.device + ) + for i in range(s): + output[i, i % self.num_experts, i // self.num_experts] = 1.0 + return 0.0, output, output.bool() + + class GateConfig(str, Enum): + RoundRobin = "round_robin" + Top2 = "top_2" + # Other gating techniques could be exposed here + + @dataclass + class MoEConfig(FeedforwardConfig): + number_of_experts: int + gate: GateConfig + number_of_local_experts: Optional[int] = None + expert_constructor: Optional[Any] = None + hidden_layer_multiplier: Optional[int] = None + group: Optional[Any] = None + + @register_feedforward("MixtureOfExperts", MoEConfig) + class MixtureOfExperts(Feedforward): + """ + A MLP variant which uses the "Mixture of Experts" paradigm, as described in Gshard_. + xFormers uses the FairScale_ implementation under the hood. + + .. warning: Please note that most of the benefits of MoE are present in a distributed training environmentt + + .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf + .. 
_FairScale: https://github.com/facebookresearch/fairscale/ + """ + + def __init__( + self, + dim_model: int, + dropout: float, + activation: Activation, + number_of_experts: int, + gate: Union[GateConfig, torch.nn.Module], + number_of_local_experts: Optional[int] = None, + expert_constructor: Optional[Callable[[], torch.nn.Module]] = None, + hidden_layer_multiplier: Optional[int] = None, + group: Optional[Any] = None, + *_, + **__, + ): + super().__init__() + + # Handle a possibly uninitialized process group + assert ( + dist.is_initialized() + ), "Mixture of Experts require torch distributed to be initialized" + + if number_of_local_experts is not None: + assert number_of_experts >= number_of_local_experts + else: + if dist.get_world_size() == 1: + logger.warning("Local experts no specified but world size of 1") + logger.warning("Assuming that all experts are local") + number_of_local_experts = number_of_experts + else: + number_of_local_experts = 1 + + # Programatically handle the gating technique + if not isinstance(gate, torch.nn.Module): + gate_constructor = { + GateConfig.RoundRobin: RoundRobinGate, + GateConfig.Top2: Top2Gate, + }[gate] + + self.gate = gate_constructor(dim_model, number_of_experts) + else: + self.gate = gate + + # Programatically handle the experts + if expert_constructor is None: + + multiplier = ( + hidden_layer_multiplier + if hidden_layer_multiplier is not None + else 4 + ) + + def expert_constructor() -> torch.nn.Module: + return MLP(dim_model, dropout, activation, multiplier) + + assert expert_constructor is not None + + local_experts = torch.nn.ModuleList( + [expert_constructor() for _ in range(number_of_local_experts)] + ) + + self.moe = MOELayer(gate=self.gate, experts=local_experts, group=group) + + self.requires_cuda = True + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + # FairScale MoE assumes that the dimensions are [S, B, E] + # xFormers assumes [B, S, E] + return self.moe(inputs.movedim(0, 1)).movedim(0, 1) diff --git a/pkgs/xformers/components/feedforward/mlp.py b/pkgs/xformers/components/feedforward/mlp.py new file mode 100644 index 0000000..fefb328 --- /dev/null +++ b/pkgs/xformers/components/feedforward/mlp.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from xformers.components import Activation, build_activation +from xformers.components.feedforward import Feedforward, FeedforwardConfig + +from . 
import register_feedforward + + +@dataclass +class MlpConfig(FeedforwardConfig): + hidden_layer_multiplier: int + bias: bool + + +@register_feedforward("MLP", MlpConfig) +class MLP(Feedforward): + def __init__( + self, + dim_model: int, + dropout: float, + activation: Activation, + hidden_layer_multiplier: int, + bias: bool = True, + *args, + **kwargs, + ): + super().__init__() + dim_mlp = hidden_layer_multiplier * dim_model + self.mlp = nn.Sequential( + nn.Linear(in_features=dim_model, out_features=dim_mlp, bias=bias), + build_activation(activation), + nn.Dropout(dropout), + nn.Linear(in_features=dim_mlp, out_features=dim_model, bias=bias), + nn.Dropout(dropout), + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return self.mlp(inputs) diff --git a/pkgs/xformers/components/input_projection.py b/pkgs/xformers/components/input_projection.py new file mode 100644 index 0000000..df984f0 --- /dev/null +++ b/pkgs/xformers/components/input_projection.py @@ -0,0 +1,99 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# CREDITS: Inspired by https://github.com/pytorch/text/blob/master/torchtext/nn/modules/multiheadattention.py +# and the MultiHeadAttention implementation from PyTorch + + +import logging +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch import nn + +logger = logging.getLogger("xformers") + + +@dataclass +class InputProjectionConfig: + in_features: int + out_features: int + bias: bool + + +class InputProjection(nn.Module): + """ + Handle all the input projections in one go, opportunistically fuse some operations. + """ + + def __init__( + self, + query_proj_params: InputProjectionConfig, + key_proj_params: Optional[InputProjectionConfig], + value_proj_params: Optional[InputProjectionConfig], + use_separate_proj_weight: bool = True, + ): + + super().__init__() + + self.out_features = query_proj_params.out_features + + # Each input gets a separate projection + self.q_proj = nn.Linear( + query_proj_params.in_features, + query_proj_params.out_features, + query_proj_params.bias, + ) + + if key_proj_params is not None: + self.k_proj = nn.Linear( + key_proj_params.in_features, + key_proj_params.out_features, + key_proj_params.bias, + ) + else: + logger.info( + "No Key projection parameters were passed, assuming that the weights" + + " are shared with the query projection" + ) + self.k_proj = self.q_proj + + if value_proj_params is not None: + self.v_proj = nn.Linear( + value_proj_params.in_features, + value_proj_params.out_features, + value_proj_params.bias, + ) + else: + logger.info( + "No Value projection parameters were passed, assuming that the weights" + + " are shared with the query projection" + ) + self.v_proj = self.q_proj + + if not use_separate_proj_weight: + # Compute optimization used at times, share the parameters in between Q/K/V + with torch.no_grad(): + self.k_proj.weight = self.q_proj.weight + self.v_proj.weight = self.q_proj.weight + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # One projection per input tensor + + # NOTE: Would it make sense to catch self attention + shared weights, to skip a projection step ? 
+ + q, k, v = map( + lambda fn, x: fn(x), + [self.q_proj, self.k_proj, self.v_proj], + [query, key, value], + ) + + return q, k, v diff --git a/pkgs/xformers/components/multi_head_dispatch.py b/pkgs/xformers/components/multi_head_dispatch.py new file mode 100644 index 0000000..363fa21 --- /dev/null +++ b/pkgs/xformers/components/multi_head_dispatch.py @@ -0,0 +1,269 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import asdict, dataclass +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from torch.nn.init import constant_ + +from xformers.components.attention import Attention +from xformers.components.input_projection import InputProjection, InputProjectionConfig +from xformers.components.positional_embedding import RotaryEmbedding + +logger = logging.getLogger("xformers") + + +@dataclass +class MultiHeadDispatchConfig: + dim_model: int + num_heads: int + attention: Attention + bias: bool + residual_dropout: float + dim_key: Optional[int] + dim_value: Optional[int] + in_proj_container: Optional[InputProjection] + use_separate_proj_weight: Optional[bool] + use_rotary_embeddings: Optional[bool] + out_proj: Optional[nn.Module] + + def __getitem__(self, item): + return getattr(self, item) + + +# Move head forward and fold into batch dim. dimensions become (B * nh, S, hs) +def _fold_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int): + return t.view(B, S, H, Hs).transpose(1, 2).flatten(start_dim=0, end_dim=1) + + +# Move head forward and fold into batch dim. dimensions become (B, nh, S, hs) +def _split_heads(t: torch.Tensor, B: int, S: int, H: int, Hs: int): + return t.view(B, S, H, Hs).transpose(1, 2) + + +class MultiHeadDispatch(nn.Module): + """ + A multi-head masked self-attention dispatch mechanism, with a projection at the end, + following the architecture proposed in `Attention is all you need`_, Vaswani et al. + + The actual attention mechanism can vary, as well as the projections. + This can be used to wrap the proposed attention mechanisms and make them multi-head aware, + but it is optional. + + Args: + dim_model: The model/embedding dimension + num_heads: The number of heads being used + attention: The attention mechanism (needs to be registered to the xformers library) + bias: Whether to use bias for the projections : (Q, K, V, Output) + residual_dropout: Amount of dropout on the residual path + use_separate_proj_weight: Use different weights for the Q, K, V projections + dim_key: Optionally use a different dimension for the key + dim_value: Optionally use a different dimension for the value + in_proj_container: Optionally provide the input projection module + use_rotary_embeddings: Use rotary embeddings + out_proj: Optionally provide the output projection module + + + .. 
_`Attention is all you need`: https://arxiv.org/abs/1706.03762v5 + """ + + def __init__( + self, + dim_model: int, + num_heads: int, + attention: Attention, + bias: Tuple[bool, bool, bool, bool] = (True, True, True, True), + residual_dropout: float = 0.0, + use_separate_proj_weight: bool = True, + dim_key: Optional[int] = None, + dim_value: Optional[int] = None, + in_proj_container: Optional[InputProjection] = None, + use_rotary_embeddings: Optional[bool] = False, + out_proj: Optional[nn.Module] = None, + *args, + **kwargs, + ): + super().__init__() + + if isinstance(bias, bool): + logger.warning( + "Single bias value provided for the MHA projections." + + f" Assuming the same parameter ({bias}) is to be used everywhere" + ) + bias = (bias, bias, bias, bias) + + assert ( + dim_model % num_heads == 0 + ) # static preset for now, each head works on 1/d the embeddings, could be relaxed + assert num_heads > 0 + + # Popular default is that all latent dimensions are the same + dim_key, dim_value = map(lambda x: x if x else dim_model, (dim_key, dim_value)) + + self.num_heads = num_heads + self.dim_key_head = dim_key // num_heads + self.dim_value_head = dim_value // num_heads + self.dim_model = dim_model + self.attention = attention + + # key, query, value projections for all heads + # critical options are + # - are we sharing weights ? + # - are we adding biases ? + if attention.requires_input_projection: + self.in_proj_container = ( + in_proj_container + if in_proj_container is not None + else InputProjection( + query_proj_params=InputProjectionConfig( + dim_model, dim_key, bias=bias[0] + ), + key_proj_params=InputProjectionConfig( + dim_model, dim_key, bias=bias[1] + ), + value_proj_params=InputProjectionConfig( + dim_model, dim_value, bias=bias[2] + ), + use_separate_proj_weight=use_separate_proj_weight, + ) + ) + + # Optional rotary embeddings + self.rotary_embeddings = ( + RotaryEmbedding(self.dim_key_head) if use_rotary_embeddings else None + ) + + # Regularization + self.resid_drop = nn.Dropout(residual_dropout, inplace=False) + + # Output projection + self.proj = ( + out_proj if out_proj else nn.Linear(dim_model, dim_model, bias=bias[3]) + ) + if isinstance(self.proj, nn.Linear) and self.proj.bias is not None: + constant_(self.proj.bias, 0.0) + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + att_mask: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Expected input dimensions are [batch size, sequence length, embed dim] + Output dimensions are [batch size, sequence length, embed dim] + """ + + if key is None: + key = query + if value is None: + value = query + + if query.shape[0] != key.shape[0] or query.shape[0] != value.shape[0]: + max_batch = max((query.shape[0], key.shape[0], value.shape[0])) + query, key, value = map( + lambda x: x.expand(max_batch, -1, -1), [query, key, value] + ) + + B, S_Q, _ = query.size() # Batch x Sequence x Embedding (latent) + _, S_K, _ = key.size() # K, Q's sequence length could differ + + # Catch different query and key length but a causal attention + if S_Q != S_K: + assert ( + not self.attention.requires_same_k_q_dimensions + ), "This attention mechanism requires query and key to have the same sequence (context) lengths" + + if hasattr(self.attention, "causal"): + assert not self.attention.causal, ( + "Causal attention is not supported when key and query have different sequence lengths.\n" + + "In that case causality 
is ill-determined. Please pad your sequences accordingly" + ) + + kw_mask_args = {} + if att_mask is not None: + assert ( + self.attention.supports_attention_mask + ), "This attention does not support attention masks" + kw_mask_args["att_mask"] = att_mask + + if key_padding_mask is not None: + assert ( + self.attention.supports_key_padding_mask + ), "This attention does not support key padding masks" + kw_mask_args["key_padding_mask"] = key_padding_mask + + if self.attention.requires_skip_multi_head: + return self.attention(query, key, value, **kw_mask_args) + + # Calculate query, key, values for all heads in batch + if self.attention.requires_input_projection: + q, k, v = self.in_proj_container(query=query, key=key, value=value) + else: + k, q, v = key, query, value + + # Check the dimensions properly + def check(t, name): + assert ( + t.shape[2] % self.num_heads == 0 + ), f"the {name} embeddings need to be divisible by the number of heads" + + check(q, "projected query") + check(v, "projected value") + check(k, "projected key") + + # Optional: rotary embedding, add relative positioning information + if self.rotary_embeddings: + # rotary requires the head dimension + q = _split_heads(q, B, S_Q, self.num_heads, self.dim_key_head) + k = _split_heads(k, B, S_K, self.num_heads, self.dim_key_head) + v = _split_heads(v, B, S_K, self.num_heads, self.dim_value_head) + + q, k = self.rotary_embeddings(q=q, k=k) + + if not self.attention.requires_head_dimension: + q, k, v = q.flatten(0, 1), k.flatten(0, 1), v.flatten(0, 1) + + else: + # Reshape k/q/v to either expose the heads, or fold the head dimension into the batch + reshape_fn = ( + _split_heads if self.attention.requires_head_dimension else _fold_heads + ) + + q = reshape_fn(q, B, S_Q, self.num_heads, self.dim_key_head) + k = reshape_fn(k, B, S_K, self.num_heads, self.dim_key_head) + v = reshape_fn(v, B, S_K, self.num_heads, self.dim_value_head) + + # Self-attend + y = self.attention(q, k, v, **kw_mask_args) + + # Re-assemble all head outputs side by side + y = ( + y.view(B, self.num_heads, S_Q, self.dim_value_head) + .transpose(1, 2) + .flatten(start_dim=2, end_dim=3) + ) + + # Output projection, dropout and good to go + y = self.resid_drop(self.proj(y)) + + # Return the same sequence size as the input + return y + + @classmethod + def from_config(cls, config: MultiHeadDispatchConfig): + # Generate the class inputs from the config + fields = asdict(config) + + # Skip all Nones so that default values are used + fields = {k: v for k, v in fields.items() if v is not None} + + return cls(**fields) diff --git a/pkgs/xformers/components/patch_embedding.py b/pkgs/xformers/components/patch_embedding.py new file mode 100644 index 0000000..23bb601 --- /dev/null +++ b/pkgs/xformers/components/patch_embedding.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass +from enum import Enum + +import torch + + +class PoolType(str, Enum): + Conv2D = "CONV_2D" + # ... + # TODO: Support more cases ? + + +@dataclass +class PatchEmbeddingConfig: + """ + The configuration for the patch embedding layer, which takes the raw token passed in + and returns a pooled representation along a given embedding dimension. 
+ + This typically trades the spatial (context length) representation with the embedding size + + This is canonicaly used by ViT, but other papers (like MetaFormer or other hierarchical transformers) + propose a more general use case for this + """ + + in_channels: int + out_channels: int + kernel_size: int + stride: int + padding: int = 0 + pool_type: PoolType = PoolType.Conv2D + + +class ConditionalReshape(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + if x.ndim == 3: + B, HW, C = x.shape + # NOTE: We're assuming a square sample here + H = int(math.sqrt(HW)) + assert H * H == HW, f"{H, HW}" + x = x.transpose(1, 2).reshape(B, C, H, H) + + return x + + +class PatchToSequence(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x.flatten(2, 3).transpose(1, 2).contiguous() # B HW C + + +def build_patch_embedding(config: PatchEmbeddingConfig): + if not isinstance(config, PatchEmbeddingConfig): + config = PatchEmbeddingConfig(**config) + + if config.pool_type == PoolType.Conv2D: + pool = torch.nn.Conv2d( + config.in_channels, + config.out_channels, + kernel_size=config.kernel_size, + stride=config.stride, + padding=config.padding, + ) + else: + raise NotImplementedError + + # The patch embedding supposes that the input really is 2D in essence + # If this block is in the middle of a stack, we need to reshape + return torch.nn.Sequential(ConditionalReshape(), pool, PatchToSequence()) diff --git a/pkgs/xformers/components/positional_embedding/__init__.py b/pkgs/xformers/components/positional_embedding/__init__.py new file mode 100644 index 0000000..0f7f02c --- /dev/null +++ b/pkgs/xformers/components/positional_embedding/__init__.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from pathlib import Path +from typing import Any, Callable, Dict, Set, Union + +from xformers.utils import ( + generate_matching_config, + get_registry_decorator, + import_all_modules, +) + +from .base import PositionEmbedding, PositionEmbeddingConfig # noqa + +# CREDITS: Classy Vision registry mechanism + +POSITION_EMBEDDING_REGISTRY: Dict[str, Any] = {} +POSITION_EMBEDDING_CLASS_NAMES: Set[str] = set() + + +def build_positional_embedding(config: Union[Dict[str, Any], PositionEmbeddingConfig]): + """Builds a position encoding from a config. + + This assumes a 'name' key in the config which is used to determine what + attention class to instantiate. For instance, a config `{"name": "my_position_encoding", + "foo": "bar"}` will find a class that was registered as "my_position_encoding" + (see :func:`register_positional_embedding`) and call .from_config on it.""" + + if not isinstance(config, PositionEmbeddingConfig): + config_instance = generate_matching_config( + config, POSITION_EMBEDDING_REGISTRY[config["name"]].config + ) + else: + config_instance = config + + return POSITION_EMBEDDING_REGISTRY[config_instance.name].constructor.from_config( + config_instance + ) + + +"""Registers a PositionEncoding subclass. + + This decorator allows xFormers to instantiate a subclass of PositionEncoding + from a configuration file, even if the class itself is not part of the + xFormers framework. To use it, apply this decorator to a `PositionEncoding` + subclass, like this: + + .. code-block:: python + + @dataclass + class MyConfig: + ... 
+ + @register_positional_embedding('my_encoding', MyConfig) + class MyEncoding(PositionEncoding): + ... + + To instantiate a position encoding from a configuration file, see :func:`build_positional_embedding`.""" +register_positional_embedding: Callable[ + [str, Any], Callable[[Any], Any] +] = get_registry_decorator( + POSITION_EMBEDDING_REGISTRY, + POSITION_EMBEDDING_CLASS_NAMES, + PositionEmbedding, + PositionEmbeddingConfig, +) + + +from .rotary import RotaryEmbedding # noqa +from .sine import SinePositionalEmbedding # type: ignore # noqa +from .vocab import VocabEmbedding # noqa + +__all__ = [ + "RotaryEmbedding", + "SinePositionalEmbedding", + "VocabEmbedding", + "build_positional_embedding", + "register_positional_embedding", +] + +# automatically import any Python files in the directory +import_all_modules( + str(Path(__file__).parent), "xformers.components.positional_embedding" +) diff --git a/pkgs/xformers/components/positional_embedding/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/components/positional_embedding/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..49f275d Binary files /dev/null and b/pkgs/xformers/components/positional_embedding/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/components/positional_embedding/__pycache__/base.cpython-310.pyc b/pkgs/xformers/components/positional_embedding/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..7d20801 Binary files /dev/null and b/pkgs/xformers/components/positional_embedding/__pycache__/base.cpython-310.pyc differ diff --git a/pkgs/xformers/components/positional_embedding/__pycache__/param.cpython-310.pyc b/pkgs/xformers/components/positional_embedding/__pycache__/param.cpython-310.pyc new file mode 100644 index 0000000..6284dd3 Binary files /dev/null and b/pkgs/xformers/components/positional_embedding/__pycache__/param.cpython-310.pyc differ diff --git a/pkgs/xformers/components/positional_embedding/__pycache__/rotary.cpython-310.pyc b/pkgs/xformers/components/positional_embedding/__pycache__/rotary.cpython-310.pyc new file mode 100644 index 0000000..069ae9a Binary files /dev/null and b/pkgs/xformers/components/positional_embedding/__pycache__/rotary.cpython-310.pyc differ diff --git a/pkgs/xformers/components/positional_embedding/__pycache__/sine.cpython-310.pyc b/pkgs/xformers/components/positional_embedding/__pycache__/sine.cpython-310.pyc new file mode 100644 index 0000000..0ee5197 Binary files /dev/null and b/pkgs/xformers/components/positional_embedding/__pycache__/sine.cpython-310.pyc differ diff --git a/pkgs/xformers/components/positional_embedding/__pycache__/vocab.cpython-310.pyc b/pkgs/xformers/components/positional_embedding/__pycache__/vocab.cpython-310.pyc new file mode 100644 index 0000000..7037870 Binary files /dev/null and b/pkgs/xformers/components/positional_embedding/__pycache__/vocab.cpython-310.pyc differ diff --git a/pkgs/xformers/components/positional_embedding/base.py b/pkgs/xformers/components/positional_embedding/base.py new file mode 100644 index 0000000..c35f1a2 --- /dev/null +++ b/pkgs/xformers/components/positional_embedding/base.py @@ -0,0 +1,35 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
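For reference, the registry machinery above (register_positional_embedding / build_positional_embedding) and the base class defined in this file fit together as follows. A minimal sketch with illustrative sizes, assuming this vendored xformers package is importable and using the "sine" embedding registered later in this patch:

    import torch

    from xformers.components.positional_embedding import build_positional_embedding

    # "sine" is registered by sine.py; build_positional_embedding resolves the name
    # in POSITION_EMBEDDING_REGISTRY and calls .from_config on the matching class.
    pos_emb = build_positional_embedding(
        {"name": "sine", "dim_model": 512, "seq_len": 1024}
    )

    x = torch.randn(2, 1024, 512)  # (batch, sequence length, embedding dim), illustrative
    y = pos_emb(x)                 # same shape, with the positional signal added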
+ + +from abc import ABCMeta, abstractmethod +from dataclasses import asdict, dataclass +from typing import Type, TypeVar + +import torch.nn as nn + +Self = TypeVar("Self", bound="PositionEmbedding") + + +@dataclass +class PositionEmbeddingConfig: + name: str + dim_model: int + seq_len: int + + +class PositionEmbedding(nn.Module, metaclass=ABCMeta): + @abstractmethod + def __init__(self, *args, **kwargs) -> None: + super().__init__() + + @classmethod + def from_config(cls: Type[Self], config: PositionEmbeddingConfig) -> Self: + # Generate the class inputs from the config + fields = asdict(config) + + # Skip all Nones so that default values are used + fields = {k: v for k, v in fields.items() if v is not None} + return cls(**fields) diff --git a/pkgs/xformers/components/positional_embedding/param.py b/pkgs/xformers/components/positional_embedding/param.py new file mode 100644 index 0000000..bc96cf6 --- /dev/null +++ b/pkgs/xformers/components/positional_embedding/param.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass + +import torch + +from xformers.components.positional_embedding import ( + PositionEmbedding, + PositionEmbeddingConfig, + register_positional_embedding, +) + + +@dataclass +class LearnablePositionalEmbeddingConfig(PositionEmbeddingConfig): + name: str + seq_len: int + dim_model: int + add_class_token: bool + + +@register_positional_embedding("learnable", LearnablePositionalEmbeddingConfig) +class LearnablePositionalEmbedding(PositionEmbedding): + def __init__( + self, seq_len: int, dim_model: int, add_class_token: bool = False, *_, **__ + ): + super().__init__() + + # 0.02 is BERT initialization + self.pos_emb = torch.nn.Parameter( + torch.randn(1, seq_len + int(add_class_token), dim_model) * 0.02 + ) + + self.class_token = ( + torch.nn.Parameter(torch.zeros(dim_model)) if add_class_token else None + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.class_token is not None: + # Prepend class token + clf_token = ( + torch.ones(x.shape[0], 1, self.pos_emb.shape[-1], device=x.device) + * self.class_token + ) + x = torch.cat([clf_token, x], dim=1) + + if x.ndim == 2: + x = x.unsqueeze(-1) + + return x + self.pos_emb diff --git a/pkgs/xformers/components/positional_embedding/rotary.py b/pkgs/xformers/components/positional_embedding/rotary.py new file mode 100644 index 0000000..551089b --- /dev/null +++ b/pkgs/xformers/components/positional_embedding/rotary.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
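The rotary embedding defined below is applied per attention head, after the query and key tensors have been split into heads (this is how multi_head_dispatch.py earlier in this patch calls it, passing dim_key_head as dim_model). A small usage sketch with illustrative shapes:

    import torch

    from xformers.components.positional_embedding import RotaryEmbedding

    rotary = RotaryEmbedding(dim_model=64)  # 64 = per-head dimension, illustrative
    q = torch.randn(2, 8, 128, 64)          # (batch, heads, seq_len, head_dim)
    k = torch.randn(2, 8, 128, 64)

    # Both tensors come back with unchanged shapes, rotated according to position.
    q_rot, k_rot = rotary(q=q, k=k)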
+ + +# CREDITS: This implementation is inspired by GPT-NeoX https://github.com/EleutherAI/gpt-neox +# NOTE: Almost the same right now, moving parts to Triton is the next step + +from typing import Tuple + +import torch + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +@torch.jit.script +def apply_rotary_pos_emb(x, cos, sin): + # NOTE: This could probably be moved to Triton + + # Handle a possible sequence length mismatch in between q and k + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + + .. warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim_model: int, *_, **__): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model)) + self.register_buffer("inv_freq", inv_freq) + + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=1): + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + if ( + seq_len != self._seq_len_cached + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): + self._seq_len_cached = seq_len + t = torch.arange( + x.shape[seq_dimension], device=x.device, dtype=torch.float32 + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) + self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + + return self._cos_cached, self._sin_cached + + def forward( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + k, seq_dimension=-2 + ) + + return ( + apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached), + apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached), + ) diff --git a/pkgs/xformers/components/positional_embedding/sine.py b/pkgs/xformers/components/positional_embedding/sine.py new file mode 100644 index 0000000..321920c --- /dev/null +++ b/pkgs/xformers/components/positional_embedding/sine.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# Silence Mypy errors in this file. 
+# type: ignore + +import math + +import torch + +from xformers.components.positional_embedding import ( + PositionEmbedding, + PositionEmbeddingConfig, + register_positional_embedding, +) + + +@register_positional_embedding("sine", PositionEmbeddingConfig) +class SinePositionalEmbedding(PositionEmbedding): + def __init__(self, dim_model: int, *args, **kwargs): + super().__init__() + self.dim_model = dim_model + + def forward(self, x: torch.Tensor) -> torch.Tensor: + seq_len = x.shape[1] + pos = ( + torch.arange(0, seq_len, device=x.device, dtype=torch.float32) + .unsqueeze(1) + .repeat(1, self.dim_model) + ) + dim = ( + torch.arange(0, self.dim_model, device=x.device, dtype=torch.float32) + .unsqueeze(0) + .repeat(seq_len, 1) + ) + div = torch.exp(-math.log(10000) * (2 * (dim // 2) / self.dim_model)) + pos *= div + pos[:, 0::2] = torch.sin(pos[:, 0::2]) + pos[:, 1::2] = torch.cos(pos[:, 1::2]) + + output = x.unsqueeze(-1) if x.ndim == 2 else x + + return output + pos.unsqueeze(0) diff --git a/pkgs/xformers/components/positional_embedding/vocab.py b/pkgs/xformers/components/positional_embedding/vocab.py new file mode 100644 index 0000000..dd18777 --- /dev/null +++ b/pkgs/xformers/components/positional_embedding/vocab.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn as nn + +from xformers.components.positional_embedding import ( + PositionEmbedding, + PositionEmbeddingConfig, + register_positional_embedding, +) + + +@dataclass +class VocabEmbeddingConfig(PositionEmbeddingConfig): + vocab_size: int + dropout: float + + +@register_positional_embedding("vocab", VocabEmbeddingConfig) +class VocabEmbedding(PositionEmbedding): + def __init__( + self, + dim_model: int, + seq_len: int, + vocab_size: int, + dropout: float = 0.0, + *args, + **kwargs + ): + super().__init__() + + self.vocab_size = vocab_size + self.dim_model = dim_model + + self.dropout = torch.nn.Dropout(p=dropout) + self.position_embeddings = nn.Embedding(seq_len, self.dim_model) + self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) + + self.position_ids: Optional[torch.Tensor] = None + + self.init_weights() + + def init_weights(self, gain: float = 1.0): + torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain) + torch.nn.init.normal_(self.word_embeddings.weight, std=0.02 * gain) + + def forward(self, x: torch.Tensor): + position_ids = torch.arange(x.shape[1], dtype=torch.long, device=x.device)[ + None, : + ].repeat(x.shape[0], 1) + + X_token = self.word_embeddings(x) + X_pos = self.position_embeddings(position_ids) + + X = X_token + X_pos + X = self.dropout(X) + + return X diff --git a/pkgs/xformers/components/residual.py b/pkgs/xformers/components/residual.py new file mode 100644 index 0000000..d9e4dde --- /dev/null +++ b/pkgs/xformers/components/residual.py @@ -0,0 +1,206 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
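The wrappers defined below (Residual, PreNorm, PostNorm) expect their inputs packed in a single list (see RequiresWrappedInputs); block_factory.py later in this patch composes them exactly this way. A hand-written pre-norm sketch, assuming Triton is unavailable so the plain nn.LayerNorm path is taken, with a Linear layer as a stand-in sublayer:

    import torch
    import torch.nn as nn

    from xformers.components.residual import NormalizationType, PreNorm, Residual

    sublayer = nn.Linear(512, 512)  # stand-in for an attention or feedforward block
    block = Residual(
        layer=PreNorm(512, sublayer, NormalizationType.LayerNorm, use_triton=False),
        scale=None,
    )

    x = torch.randn(2, 128, 512)
    y = block(inputs=[x])           # wrapped layers receive their inputs as a list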
+ + +from enum import Enum +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from xformers import _is_triton_available + +if _is_triton_available(): + from xformers.triton.layer_norm import FusedLayerNorm + +from collections import namedtuple + + +class ResidualNormStyle(str, Enum): + """Support different residual path and norm styles. + See "On Layer Normalization in the Transformer Architecture", + Xiong et al., https://arxiv.org/pdf/2002.04745v1.pdf + """ + + Pre = "pre" + Post = "post" + DeepNorm = "deepnorm" + + +class NormalizationType(str, Enum): + LayerNorm = "layernorm" + Skip = "skip" + # TODO: BatchNorm = "batchnorm" + # TODO: GroupNorm = "groupnorm" + + +def get_normalization_layer(normalization_type: NormalizationType): + class Skip(nn.Module): + def __init__(self, *_, **__) -> None: + super().__init__() + + def forward(self, x: torch.Tensor, **_): + return x + + return { + NormalizationType.LayerNorm: nn.LayerNorm, + NormalizationType.Skip: Skip, + }[normalization_type] + + +class RequiresWrappedInputs: + """Used to mark, through inheritance, + the fact that this class will require inputs to be passed as a single list""" + + pass + + +# CREDITS: the following is inspired by FastAI's Transformer implementation +class Residual(nn.Module, RequiresWrappedInputs): + """ + Object-oriented handling of the residual path + + This supports scaling of the residual path, as proposed by DeepNet_ + .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf + + .. Note: the wrapped layers must accept all the inputs as a single list + """ + + def __init__(self, layer: nn.Module, scale: Optional[float] = None): + super().__init__() + self.layer = layer + self.scale = scale + + # PreNorm and PostNorm require all the tensors to be passed as a list + self.wrap_inputs = isinstance(layer, RequiresWrappedInputs) + + def forward(self, inputs: List[torch.Tensor], **kwargs): + if self.scale is not None: + residue = inputs[0] * self.scale + else: + residue = inputs[0] + + if self.wrap_inputs: + return residue + self.layer(inputs=inputs, **kwargs) + + else: + return residue + self.layer(*inputs, **kwargs) + + +class PreNorm(nn.Module, RequiresWrappedInputs): + """Adds a normalization before computing attention + + ..Note: If a list of inputs is passed, all of them get normalized""" + + def __init__( + self, + d_norm: int, + sublayer: nn.Module, + normalization: NormalizationType, + use_triton: bool = True, + ): + + super().__init__() + if ( + _is_triton_available() + and use_triton + and normalization == NormalizationType.LayerNorm + ): + self.norm: Union[nn.LayerNorm, FusedLayerNorm] = FusedLayerNorm(d_norm) + else: + self.norm = get_normalization_layer(normalization)(d_norm) + + self.sublayer = sublayer + self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs) + + def forward(self, inputs: List[torch.Tensor], **kwargs): + assert len(inputs) > 0 + + # Perf improvement: if the inputs are all the same, only norm once + ids = [id(x) for x in inputs] + if ids.count(ids[0]) == len(ids): + # The same tensor is passed multiple times + x_norm = self.norm(inputs[0]) + inputs_normed = [x_norm for _ in inputs] + else: + # The inputs differ, norm them all + inputs_normed = [self.norm(x_) for x_ in inputs] + + if self.wrap_inputs: + return self.sublayer(inputs=inputs_normed, **kwargs) + else: + return self.sublayer(*inputs_normed, **kwargs) + + +class PostNorm(nn.Module, RequiresWrappedInputs): + """Adds LayerNorm after computing attention""" + + def __init__( + self, + 
d_norm: int, + sublayer: nn.Module, + normalization: NormalizationType, + use_triton: bool = True, + ): + super().__init__() + if ( + _is_triton_available() + and use_triton + and normalization == NormalizationType.LayerNorm + ): + self.norm: Union[nn.LayerNorm, FusedLayerNorm] = FusedLayerNorm(d_norm) + else: + self.norm = get_normalization_layer(normalization)(d_norm) + + self.sublayer = sublayer + self.wrap_inputs = isinstance(sublayer, RequiresWrappedInputs) + + def forward(self, inputs: List[torch.Tensor], **kwargs): + if self.wrap_inputs: + x = self.sublayer(inputs=inputs, **kwargs) + else: + x = self.sublayer(*inputs, **kwargs) + return self.norm(x) + + +DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"]) + + +def get_deepnorm_coefficients( + encoder_layers: int, decoder_layers: int +) -> Tuple[Optional[DeepNormCoefficients], Optional[DeepNormCoefficients]]: + """ + See DeepNet_. + + Returns alpha and beta depending on the number of encoder and decoder layers, + first tuple is for the encoder and second for the decoder + + .. _DeepNet: https://arxiv.org/pdf/2203.00555v1.pdf + """ + + N = encoder_layers + M = decoder_layers + + if decoder_layers == 0: + # Encoder only + return ( + DeepNormCoefficients(alpha=(2 * N) ** 0.25, beta=(8 * N) ** -0.25), + None, + ) + + elif encoder_layers == 0: + # Decoder only + return None, DeepNormCoefficients(alpha=(2 * M) ** 0.25, beta=(8 * M) ** -0.25) + else: + # Encoder/decoder + encoder_coeffs = DeepNormCoefficients( + alpha=0.81 * ((N**4) * M) ** 0.0625, beta=0.87 * ((N**4) * M) ** -0.0625 + ) + + decoder_coeffs = DeepNormCoefficients( + alpha=(3 * M) ** 0.25, beta=(12 * M) ** -0.25 + ) + + return (encoder_coeffs, decoder_coeffs) diff --git a/pkgs/xformers/components/reversible.py b/pkgs/xformers/components/reversible.py new file mode 100644 index 0000000..ae70ec0 --- /dev/null +++ b/pkgs/xformers/components/reversible.py @@ -0,0 +1,157 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List + +import torch +import torch.nn as nn +from torch.autograd.function import Function +from torch.utils.checkpoint import get_device_states, set_device_states + +from xformers.components import RequiresWrappedInputs + +# CREDITS: Code adapted from +# https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reversible.py +# https://github.com/RobinBruegger/RevTorch/blob/master/revtorch/revtorch.py, +# https://pytorch.org/docs/stable/_modules/torch/utils/checkpoint.html + + +# pyre-fixme[13]: `cpu_state` is not initialized in the constructor. 
+class Deterministic(nn.Module): + def __init__(self, net: nn.Module): + super().__init__() + self.net = net + self.cpu_state: torch.Tensor = torch.get_rng_state() + self.cuda_in_fwd: bool = False + self.gpu_devices: List[int] = [] + self.gpu_states: List[torch.Tensor] = [] + self.wrap_inputs = isinstance(net, RequiresWrappedInputs) + + def record_rng(self, *args): + self.cpu_state = torch.get_rng_state() + if torch.cuda._initialized: + self.cuda_in_fwd = True + self.gpu_devices, self.gpu_states = get_device_states(*args) + + def forward(self, *args, record_rng: bool = False, set_rng: bool = False, **kwargs): + if record_rng: + self.record_rng(*args) + + if not set_rng: + # Normal FW run + if self.wrap_inputs: + return self.net(inputs=args, **kwargs) + else: + return self.net(*args, **kwargs) + + else: # pragma: no cover # this is called in the backward pass, not picked up + # This is analogous to checkpointing, reset the original random state + rng_devices: List[int] = [] + if self.cuda_in_fwd: + rng_devices = self.gpu_devices + + with torch.random.fork_rng(devices=rng_devices, enabled=True): + torch.set_rng_state(self.cpu_state) + if self.cuda_in_fwd: + set_device_states(self.gpu_devices, self.gpu_states) + + if self.wrap_inputs: + return self.net(inputs=args, **kwargs) + else: + return self.net(*args, **kwargs) + + +class ReversibleBlock(nn.Module): + def __init__(self, f: nn.Module, g: nn.Module, split_dim: int = -1): + super().__init__() + self.f = Deterministic(f) + self.g = Deterministic(g) + self.split_dim = split_dim + + def forward(self, x: torch.Tensor, f_args={}, g_args={}): + x1, x2 = torch.chunk(x, 2, dim=-1) + y1, y2 = None, None + + with torch.no_grad(): + y1 = x1 + self.f(x2, record_rng=self.training, **f_args) + y2 = x2 + self.g(y1, record_rng=self.training, **g_args) + + return torch.cat([y1, y2], dim=self.split_dim) + + def backward_pass( + self, y: torch.Tensor, dy: torch.Tensor, f_args={}, g_args={} + ): # pragma: no cover # this is covered, but called directly from C++ + y1, y2 = torch.chunk(y, 2, dim=self.split_dim) + del y + + dy1, dy2 = torch.chunk(dy, 2, dim=self.split_dim) + del dy + + with torch.enable_grad(): + y1.requires_grad = True + gy1 = self.g(y1, set_rng=True, **g_args) + torch.autograd.backward(gy1, dy2) + + with torch.no_grad(): + x2 = y2 - gy1 + del y2, gy1 + + dx1 = dy1 + y1.grad + del dy1 + y1.grad = None + + with torch.enable_grad(): + x2.requires_grad = True + fx2 = self.f(x2, set_rng=True, **f_args) + torch.autograd.backward(fx2, dx1) + + with torch.no_grad(): + x1 = y1 - fx2 + del y1, fx2 + + dx2 = dy2 + x2.grad + del dy2 + x2.grad = None + + x = torch.cat([x1, x2.detach()], dim=self.split_dim) + dx = torch.cat([dx1, dx2], dim=self.split_dim) + + return x, dx + + +class _ReversibleFunction(Function): + @staticmethod + def forward(ctx, x, blocks, kwargs): + ctx.kwargs = kwargs + for block in blocks: + x = block(x, **kwargs) + ctx.y = x.detach() + ctx.blocks = blocks + return x + + @staticmethod + def backward( + ctx, dy + ): # pragma: no cover # this is covered, but called directly from C++ + y = ctx.y + kwargs = ctx.kwargs + for block in ctx.blocks[::-1]: + y, dy = block.backward_pass(y, dy, **kwargs) + return dy, None, None + + +class ReversibleSequence(nn.Module): + def __init__(self, blocks: nn.ModuleList): + super().__init__() + + # pyre-fixme[23]: Unable to unpack `torch.nn.Module` into 2 values. 
+ self.blocks = nn.ModuleList([ReversibleBlock(f, g) for f, g in blocks]) + + def forward(self, x, arg_route=(True, False), **kwargs): + f_args, g_args = map(lambda route: kwargs if route else {}, arg_route) + block_kwargs = {"f_args": f_args, "g_args": g_args} + + return _ReversibleFunction.apply(x, self.blocks, block_kwargs) diff --git a/pkgs/xformers/components/simplicial_embedding.py b/pkgs/xformers/components/simplicial_embedding.py new file mode 100644 index 0000000..a420905 --- /dev/null +++ b/pkgs/xformers/components/simplicial_embedding.py @@ -0,0 +1,73 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import asdict, dataclass +from typing import Optional, Type, TypeVar + +import torch + +from xformers import _is_triton_available + +Self = TypeVar("Self", bound="SimplicialEmbedding") + + +@dataclass +class SimplicialEmbeddingConfig: + L: int + temperature: float + + +class SimplicialEmbedding(torch.nn.Module): + """ + An implementation of the "Simplicial Embeddings"_, as proposed by Lavoie et. al + + Arguments: + - L: the number of embedding chunks + - temperature: optional scaling parameter for the softmax operation. + A small (<1.) temperature will lead to a sparse representation (up to one-hot), + while a large (>1.) temperature will make the vector more uniform + + _"Simplicial Embeddings": https://arxiv.org/pdf/2204.00616.pdf + """ + + def __init__(self, L: int, temperature: Optional[float] = None) -> None: + super().__init__() + self.L = L + self.temperature = temperature + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert ( + x.shape[-1] % self.L == 0 + ), f"The embedding dimension {x.shape[-1]} is not divisible by the chosen L parameter {self.L}" + + # Seperate the input tensor into V chunks + B, C, E = x.shape + V = E // self.L + + Vs = x.reshape(B, C, self.L, V) + + # Softmax normalize them, with the proposed temperature + # This is done over the last dimension, so only within Vs + if self.temperature is not None: + Vs /= self.temperature + + if _is_triton_available(): + from xformers.triton.softmax import softmax as triton_softmax + + Vs = triton_softmax( + Vs, mask=None, causal=False + ) # the softmax is on the last dimension + else: + Vs = torch.nn.functional.softmax(Vs, dim=-1) + + # Concatenate back and return + return Vs.reshape(B, C, E) + + @classmethod + def from_config(cls: Type[Self], config: SimplicialEmbeddingConfig) -> Self: + # Generate the class inputs from the config + fields = asdict(config) + + return cls(**fields) diff --git a/pkgs/xformers/cpp_lib.json b/pkgs/xformers/cpp_lib.json new file mode 100644 index 0000000..0266eba --- /dev/null +++ b/pkgs/xformers/cpp_lib.json @@ -0,0 +1 @@ +{"version": {"cuda": 1002, "torch": "2.1.1", "python": "3.10.12", "flash": "daily_2024-01-11_1043-13-gca11361"}, "env": {"TORCH_CUDA_ARCH_LIST": null, "XFORMERS_BUILD_TYPE": "release", "XFORMERS_ENABLE_DEBUG_ASSERTIONS": "0", "NVCC_FLAGS": null, "XFORMERS_PACKAGE_FROM": null}} \ No newline at end of file diff --git a/pkgs/xformers/factory/__init__.py b/pkgs/xformers/factory/__init__.py new file mode 100644 index 0000000..882ca9e --- /dev/null +++ b/pkgs/xformers/factory/__init__.py @@ -0,0 +1,11 @@ +from xformers.components import MultiHeadDispatchConfig # noqa +from xformers.components.attention import AttentionConfig # noqa +from xformers.components.feedforward import 
FeedforwardConfig # noqa +from xformers.components.positional_embedding import PositionEmbeddingConfig # noqa + +from .block_factory import xFormerDecoderBlock # noqa +from .block_factory import xFormerDecoderConfig # noqa +from .block_factory import xFormerEncoderBlock # noqa +from .block_factory import xFormerEncoderConfig # noqa +from .model_factory import xFormer, xFormerConfig # noqa +from .weight_init import xFormerWeightInit # noqa diff --git a/pkgs/xformers/factory/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/factory/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..4eac6d2 Binary files /dev/null and b/pkgs/xformers/factory/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/factory/__pycache__/block_configs.cpython-310.pyc b/pkgs/xformers/factory/__pycache__/block_configs.cpython-310.pyc new file mode 100644 index 0000000..f823e42 Binary files /dev/null and b/pkgs/xformers/factory/__pycache__/block_configs.cpython-310.pyc differ diff --git a/pkgs/xformers/factory/__pycache__/block_factory.cpython-310.pyc b/pkgs/xformers/factory/__pycache__/block_factory.cpython-310.pyc new file mode 100644 index 0000000..a0d0186 Binary files /dev/null and b/pkgs/xformers/factory/__pycache__/block_factory.cpython-310.pyc differ diff --git a/pkgs/xformers/factory/__pycache__/hydra_helper.cpython-310.pyc b/pkgs/xformers/factory/__pycache__/hydra_helper.cpython-310.pyc new file mode 100644 index 0000000..137535f Binary files /dev/null and b/pkgs/xformers/factory/__pycache__/hydra_helper.cpython-310.pyc differ diff --git a/pkgs/xformers/factory/__pycache__/model_factory.cpython-310.pyc b/pkgs/xformers/factory/__pycache__/model_factory.cpython-310.pyc new file mode 100644 index 0000000..4db3bb0 Binary files /dev/null and b/pkgs/xformers/factory/__pycache__/model_factory.cpython-310.pyc differ diff --git a/pkgs/xformers/factory/__pycache__/weight_init.cpython-310.pyc b/pkgs/xformers/factory/__pycache__/weight_init.cpython-310.pyc new file mode 100644 index 0000000..14e5118 Binary files /dev/null and b/pkgs/xformers/factory/__pycache__/weight_init.cpython-310.pyc differ diff --git a/pkgs/xformers/factory/block_configs.py b/pkgs/xformers/factory/block_configs.py new file mode 100644 index 0000000..a628fa1 --- /dev/null +++ b/pkgs/xformers/factory/block_configs.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
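The dataclasses below are normally built from nested dictionaries whose "name" keys select registered components. An illustrative encoder configuration with made-up sizes; the MLP feedforward and sine embedding are part of this patch, while the scaled_dot_product attention name is assumed to be registered in this vendored copy (the attention components are not shown here):

    from xformers.factory import xFormerEncoderConfig

    encoder_config = xFormerEncoderConfig(
        dim_model=512,
        position_encoding_config={"name": "sine", "seq_len": 1024},
        multi_head_config={
            "num_heads": 8,
            "residual_dropout": 0.0,
            # "scaled_dot_product" is assumed to be available in this vendored copy.
            "attention": {"name": "scaled_dot_product", "dropout": 0.0, "causal": False},
        },
        feedforward_config={
            "name": "MLP",
            "dropout": 0.0,
            "activation": "gelu",
            "hidden_layer_multiplier": 4,
        },
        residual_norm_style="pre",
    )

    # The config can then be handed to xFormerEncoderBlock.from_config(encoder_config),
    # defined in block_factory.py later in this patch.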
+ + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, Optional + +from xformers.components import NormalizationType, ResidualNormStyle +from xformers.components.feedforward import FEEDFORWARD_REGISTRY, FeedforwardConfig +from xformers.components.positional_embedding import ( + POSITION_EMBEDDING_REGISTRY, + PositionEmbeddingConfig, +) +from xformers.utils import generate_matching_config + + +class LayerPositionBitmask(int, Enum): + First = 0b01 + Last = 0b10 + Default = 0b11 + + +class LayerPosition: + """Bitmask to mark this layer as first, last, nothing or both""" + + def __init__(self): + self.bitmask = LayerPositionBitmask.Default + + def is_first(self): + return bool(self.bitmask & LayerPositionBitmask.First) + + def is_last(self): + return bool(self.bitmask & LayerPositionBitmask.Last) + + def mark_not_first(self): + self.bitmask &= ~LayerPositionBitmask.First + + def mark_not_last(self): + self.bitmask &= ~LayerPositionBitmask.Last + + +class BlockType(str, Enum): + Encoder = "encoder" + Decoder = "decoder" + + +@dataclass(init=False) # handle constructors explicitly to force type changes +class xFormerBlockConfig: + """ + The configuration structure to define a Transformer block. + This base class is applicable to both encoder and decoder definitions. + + This completely defines each of the blocks, for instance in terms of dimensions, + position encoding, pre or post layer norms or reversibility. + """ + + dim_model: int + feedforward_config: FeedforwardConfig + position_encoding_config: Optional[PositionEmbeddingConfig] + block_type: BlockType + residual_norm_style: ResidualNormStyle + normalization: NormalizationType + layer_position: LayerPosition + use_triton: bool + reversible: bool + num_layers: int + + def __init__( + self, + dim_model: int, + feedforward_config: Dict[str, Any], + position_encoding_config: Optional[Dict[str, Any]], + block_type: BlockType, + residual_norm_style: ResidualNormStyle = ResidualNormStyle("post"), + normalization: NormalizationType = NormalizationType.LayerNorm, + reversible: bool = False, + num_layers: int = 1, + layer_position: Optional[LayerPosition] = None, + ): + + self.dim_model = dim_model + self.block_type = block_type + self.residual_norm_style = residual_norm_style + self.reversible = reversible + self.num_layers = num_layers + self.normalization = normalization + + # Fill in possible gaps in the config for subparts of the block + self.feedforward_config = generate_matching_config( + feedforward_config, + FEEDFORWARD_REGISTRY[feedforward_config["name"]].config, + ) + + self.position_encoding_config = ( + generate_matching_config( + position_encoding_config, + POSITION_EMBEDDING_REGISTRY[position_encoding_config["name"]].config, + ) + if position_encoding_config is not None + else None + ) + + # Default is that this layer is the only one, so both first and last + if layer_position: + self.layer_position = layer_position + else: + self.layer_position = LayerPosition() + + +@dataclass(init=False) +class xFormerEncoderConfig(xFormerBlockConfig): + """ + The configuration structure for an encoder block + """ + + multi_head_config: Dict[str, Any] + use_triton: bool + simplicial_embeddings: Optional[Dict[str, Any]] + patch_embedding_config: Optional[Dict[str, Any]] + + def __init__( + self, + dim_model: int, + feedforward_config: Dict[str, Any], + multi_head_config: Dict[str, Any], + position_encoding_config: Optional[Dict[str, Any]] = None, + residual_norm_style: str = "post", + normalization: 
NormalizationType = NormalizationType.LayerNorm, + use_triton: bool = True, + simplicial_embeddings: Optional[Dict[str, Any]] = None, + patch_embedding_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + # Convenience, fill in duplicated fields + try: + if "dim_model" not in multi_head_config.keys(): + multi_head_config["dim_model"] = dim_model + + if "dim_model" not in feedforward_config.keys(): + feedforward_config["dim_model"] = dim_model + + if ( + position_encoding_config is not None + and "dim_model" not in position_encoding_config.keys() + ): + position_encoding_config["dim_model"] = dim_model + + if ( + patch_embedding_config is not None + and "out_channels" not in patch_embedding_config.keys() + ): + patch_embedding_config["out_channels"] = dim_model + + except AttributeError: + # A config instance was passed in, this is fine + pass + if "block_type" in kwargs: + assert kwargs["block_type"] == "encoder" + kwargs["block_type"] = BlockType("encoder") + super().__init__( + dim_model=dim_model, + feedforward_config=feedforward_config, + position_encoding_config=position_encoding_config, + residual_norm_style=ResidualNormStyle(residual_norm_style), + normalization=NormalizationType(normalization), + **kwargs, + ) + + self.multi_head_config = multi_head_config + self.use_triton = use_triton + self.simplicial_embeddings = simplicial_embeddings + self.patch_embedding_config = patch_embedding_config + + +@dataclass(init=False) +class xFormerDecoderConfig(xFormerBlockConfig): + """ + The configuration structure for a decoder block. + + This specifically defines the masked and cross attention mechanisms, + on top of the settings defining all blocks. + """ + + multi_head_config_masked: Dict[str, Any] # prior to encoder output + multi_head_config_cross: Dict[str, Any] # cross attention, takes encoder output + + def __init__( + self, + dim_model: int, + feedforward_config: Dict[str, Any], + multi_head_config_masked: Dict[str, Any], + multi_head_config_cross: Dict[str, Any], + position_encoding_config: Optional[Dict[str, Any]] = None, + residual_norm_style: str = "post", + normalization: NormalizationType = NormalizationType.LayerNorm, + use_triton: bool = True, + **kwargs, + ): + + # Convenience, fill in duplicated field + try: + if "dim_model" not in multi_head_config_masked.keys(): + multi_head_config_masked["dim_model"] = dim_model + + if "dim_model" not in multi_head_config_cross.keys(): + multi_head_config_cross["dim_model"] = dim_model + + if "dim_model" not in feedforward_config.keys(): + feedforward_config["dim_model"] = dim_model + + if ( + position_encoding_config is not None + and "dim_model" not in position_encoding_config.keys() + ): + position_encoding_config["dim_model"] = dim_model + except AttributeError: + # A config instance was passed in, this is fine + pass + if "block_type" in kwargs.keys(): + assert kwargs["block_type"] == "decoder" + kwargs["block_type"] = BlockType("decoder") + + super().__init__( + dim_model=dim_model, + feedforward_config=feedforward_config, + position_encoding_config=position_encoding_config, + residual_norm_style=ResidualNormStyle(residual_norm_style), + normalization=NormalizationType(normalization), + **kwargs, + ) + + self.multi_head_config_masked = multi_head_config_masked + self.multi_head_config_cross = multi_head_config_cross + self.use_triton = use_triton diff --git a/pkgs/xformers/factory/block_factory.py b/pkgs/xformers/factory/block_factory.py new file mode 100644 index 0000000..4868f9f --- /dev/null +++ 
b/pkgs/xformers/factory/block_factory.py @@ -0,0 +1,358 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import asdict +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from xformers._deprecation_warning import deprecated_function +from xformers.components import ( + PatchEmbeddingConfig, + PostNorm, + PreNorm, + Residual, + ResidualNormStyle, + build_multi_head_attention, + build_patch_embedding, +) +from xformers.components.attention import AttentionMask +from xformers.components.feedforward import build_feedforward +from xformers.components.positional_embedding import build_positional_embedding +from xformers.components.residual import get_deepnorm_coefficients +from xformers.components.simplicial_embedding import SimplicialEmbedding +from xformers.factory.block_configs import ( + NormalizationType, + xFormerDecoderConfig, + xFormerEncoderConfig, +) + +logger = logging.getLogger("xformers") + + +def _get_ln_factory( + d_model: int, + residual_norm_style: Optional[ResidualNormStyle], + use_triton: bool, + residual: bool, + normalization: NormalizationType = NormalizationType.LayerNorm, + residual_scale: float = 1.0, +): + """ + Handle all the supported residual path configurations. + + ..Note: we return the appropriate constructor, not an actual layer + """ + + def get_layer_wrapper( + d_model: int, + sublayer: nn.Module, + residual_norm_style: Optional[ResidualNormStyle], + residual: bool, + residual_scale: float, + ): + if residual: + if residual_norm_style == ResidualNormStyle.Pre: + return Residual( + layer=PreNorm(d_model, sublayer, normalization, use_triton), + scale=None, + ) + elif residual_norm_style == ResidualNormStyle.Post: + return PostNorm( + d_model, + Residual(layer=sublayer, scale=None), + normalization, + use_triton, + ) + elif residual_norm_style == ResidualNormStyle.DeepNorm: + return PostNorm( + d_model, + Residual(layer=sublayer, scale=residual_scale), + normalization, + use_triton=use_triton, + ) + else: + raise ValueError + + return ( + PreNorm(d_model, sublayer, normalization, use_triton) + if residual_norm_style == ResidualNormStyle.Pre + else PostNorm(d_model, sublayer, normalization, use_triton) + ) + + def ln_factory(sublayer: nn.Module): + return get_layer_wrapper( + d_model, sublayer, residual_norm_style, residual, residual_scale + ) + + return ln_factory + + +class xFormerEncoderBlock(torch.nn.Module): + r"""A vanilla Transformer Encoder block""" + + def __init__(self, config: xFormerEncoderConfig, **kwargs): + super().__init__() + deprecated_function(self) + + self.reversible_f = None + self.reversible_g = None + self.residual_norm_style = config.residual_norm_style + self.dim_model = config.dim_model + + # If this layer is the first one, and a pose encoding has been requested + if ( + config.position_encoding_config is not None + and config.layer_position.is_first() + ): + self.pose_encoding = build_positional_embedding( + asdict(config.position_encoding_config) + ) + + pos_encoding_dim = config.position_encoding_config.dim_model + mha_dim = config.multi_head_config["dim_model"] + + if pos_encoding_dim != mha_dim: + logger.warning( + f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." 
# noqa + ) + self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim) + else: + self.pose_encoding = None + + if config.residual_norm_style == ResidualNormStyle.DeepNorm: + # Just use the layer norm coefficient here, + # the init will be handled at the xformers level (knows about encoder and decoder blocks) + deep_norm_coefficients, _ = get_deepnorm_coefficients( + encoder_layers=config.num_layers, decoder_layers=0 + ) + assert deep_norm_coefficients is not None + residual_scale = deep_norm_coefficients.alpha + else: + residual_scale = 1.0 + + # mini helper, builds a normalization layer with the right Pre/Post config, residuals, and the right dimensions + ln_factory = _get_ln_factory( + config.dim_model, + config.residual_norm_style, + use_triton=config.use_triton, + residual=True, + residual_scale=residual_scale, + normalization=config.normalization, + ) + + mha = build_multi_head_attention(config.multi_head_config) + feedforward = build_feedforward(asdict(config.feedforward_config)) + + # Expose attention specific capabilities + self.supports_attention_mask = mha.attention.supports_attention_mask + self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions + self.causal = ( + mha.attention.causal if hasattr(mha.attention, "causal") else False + ) + + # Wrappers handle the different layer norm styles (pre- and post-) and the residual path + self.wrap_att = ln_factory(mha) + self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward) + if ( + config.residual_norm_style == ResidualNormStyle.Pre + and config.layer_position.is_last() + ): + self.wrap_ff = PostNorm( + config.dim_model, + self.wrap_ff, + normalization=config.normalization, + use_triton=config.use_triton, + ) + + # Simplicial embeddings are only used if specified, and on the last layer + self.simplicial_embedding: Optional[SimplicialEmbedding] = None + if config.simplicial_embeddings is not None and config.layer_position.is_last(): + self.simplicial_embedding = SimplicialEmbedding( + **config.simplicial_embeddings + ) + + # Optional patch embedding + self.patch_emb: Optional[nn.Module] = None + + if config.patch_embedding_config is not None: + self.patch_emb = build_patch_embedding( + PatchEmbeddingConfig(**config.patch_embedding_config) + ) + + @classmethod + def from_config(cls, config: xFormerEncoderConfig): + return cls(config) + + @staticmethod + def get_reversible_layer(config) -> Tuple[nn.Module, nn.Module]: + ln_factory = _get_ln_factory( + config.dim_model, + config.residual_norm_style, + residual=False, + use_triton=config.use_triton, + normalization=config.normalization, + ) + + mha = build_multi_head_attention(config.multi_head_config) + feedforward = build_feedforward(asdict(config.feedforward_config)) + + reversible_f = ln_factory(mha) + reversible_g = ln_factory(feedforward) + return reversible_f, reversible_g + + def forward( + self, + x: torch.Tensor, + att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + input_mask: Optional[torch.Tensor] = None, + ): + if self.patch_emb is not None: + x = self.patch_emb(x) + + if self.pose_encoding is not None: + x = self.pose_encoding(x) + + if hasattr(self, "embedding_projector"): + x = self.embedding_projector(x) + + # Handle the optional input masking, differs on Q, K, V + if input_mask is not None: + q = x + k = x * input_mask.unsqueeze(-1) + v = k + else: + q, k, v = x, x, x + + # Pre/Post norms and residual paths are already handled + x = self.wrap_att(inputs=[q, k, v], att_mask=att_mask) + x = self.wrap_ff(inputs=[x]) + + # 
Optional simplicial embeddings + if self.simplicial_embedding is not None: + x = self.simplicial_embedding(x) + + return x + + +class xFormerDecoderBlock(torch.nn.Module): + r"""A vanilla Transformer Decoder block + + ... note: this implementation is not (yet ?) reversible""" + + def __init__(self, config: xFormerDecoderConfig, **kwargs): + super().__init__() + deprecated_function(self) + + # If this layer is the first one, and a pose encoding as been requested + if ( + config.position_encoding_config is not None + and config.layer_position.is_first() + ): + self.pose_encoding = build_positional_embedding( + config.position_encoding_config + ) + + pos_encoding_dim = config.position_encoding_config.dim_model + mha_dim = config.multi_head_config_masked["dim_model"] + + if pos_encoding_dim != mha_dim: + + logger.warning( + f"The embedding dim and model dim do not match ({pos_encoding_dim} vs {mha_dim}), adding a projector layer." # noqa + ) + + self.embedding_projector = nn.Linear(pos_encoding_dim, mha_dim) + else: + self.pose_encoding = None + + if config.residual_norm_style == ResidualNormStyle.DeepNorm: + # Just use the layer norm coefficient here, + # the init will be handled at the xformers level (knows about encoder and decoder blocks) + _, deep_norm_coefficients = get_deepnorm_coefficients( + encoder_layers=0, decoder_layers=config.num_layers + ) + assert deep_norm_coefficients is not None + residual_scale = deep_norm_coefficients.alpha + else: + residual_scale = 1.0 + + # mini helper, builds a LayerNorm with the right Pre/Post config and the right dimensions + ln_factory = _get_ln_factory( + config.dim_model, + config.residual_norm_style, + use_triton=config.use_triton, + residual=True, + residual_scale=residual_scale, + normalization=config.normalization, + ) + + mha = build_multi_head_attention(config.multi_head_config_masked) + cross_mha = build_multi_head_attention(config.multi_head_config_cross) + feedforward = build_feedforward(config.feedforward_config) + + # Expose attention or feedforward specific capabilities + self.supports_attention_mask = mha.attention.supports_attention_mask + self.requires_same_k_q_dimensions = mha.attention.requires_same_k_q_dimensions + self.requires_squared_context_length = ( + feedforward.requires_squared_context + or mha.attention.requires_squared_context + ) + + self.causal_attention = ( + mha.attention.causal if hasattr(mha.attention, "causal") else False + ) + + # Wrappers handle the different layer norm styles (pre- and post-) and the residual path + self.wrap_att = ln_factory(mha) + self.wrap_cross = ln_factory(cross_mha) + self.wrap_ff: Union[Residual, PostNorm] = ln_factory(feedforward) + + if ( + config.residual_norm_style == ResidualNormStyle.Pre + and config.layer_position.is_last() + ): + self.wrap_ff = PostNorm( + config.dim_model, + self.wrap_ff, + normalization=NormalizationType.LayerNorm, + ) + + @classmethod + def from_config(cls, config: xFormerDecoderConfig): + return cls(config) + + def forward( + self, + target: torch.Tensor, + memory: torch.Tensor, + encoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + decoder_att_mask: Optional[Union[torch.Tensor, AttentionMask]] = None, + input_mask: Optional[torch.Tensor] = None, + ): + if self.pose_encoding is not None: + target = self.pose_encoding(target) + + if hasattr(self, "embedding_projector"): + target = self.embedding_projector(target) + + # Handle the optional input masking, differs on Q, K, V + if input_mask is not None: + target_q = target + target_k = target 
* input_mask.unsqueeze(-1) + target_v = target_k + else: + target_q, target_k, target_v = target, target, target + + x = self.wrap_att( + inputs=[target_q, target_k, target_v], att_mask=decoder_att_mask + ) + x = self.wrap_cross(inputs=[x, memory, memory], att_mask=encoder_att_mask) + x = self.wrap_ff(inputs=[x]) + + return x diff --git a/pkgs/xformers/factory/hydra_helper.py b/pkgs/xformers/factory/hydra_helper.py new file mode 100644 index 0000000..f94557e --- /dev/null +++ b/pkgs/xformers/factory/hydra_helper.py @@ -0,0 +1,36 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# register components configs into Hydra ConfigStore +# component config classes could be used to validate configs +import logging + +from hydra.core.config_store import ConfigStore +from omegaconf.errors import ValidationError + +from xformers.components.attention import ATTENTION_REGISTRY +from xformers.components.feedforward import FEEDFORWARD_REGISTRY +from xformers.components.positional_embedding import POSITION_EMBEDDING_REGISTRY + +logger = logging.getLogger("xformers") + + +def import_xformer_config_schema(): + """ + Best effort - OmegaConf supports limited typing, so we may fail to import + certain config classes. For example, pytorch typing are not supported. + """ + cs = ConfigStore.instance() + + for k, v in { + "ff": FEEDFORWARD_REGISTRY, + "pe": POSITION_EMBEDDING_REGISTRY, + "attention": ATTENTION_REGISTRY, + }.items(): + for kk in v.keys(): + try: + cs.store(name=f"{kk}_schema", node=v[kk].config, group=f"xformers/{k}") + except ValidationError as e: + logger.debug(f"Error registering {kk}_schema, error: {e}") diff --git a/pkgs/xformers/factory/model_factory.py b/pkgs/xformers/factory/model_factory.py new file mode 100644 index 0000000..d4be388 --- /dev/null +++ b/pkgs/xformers/factory/model_factory.py @@ -0,0 +1,313 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +import torch + +from xformers._deprecation_warning import deprecated_function +from xformers.components import reversible as rv +from xformers.components.residual import ResidualNormStyle, get_deepnorm_coefficients +from xformers.factory.block_configs import ( + xFormerBlockConfig, + xFormerDecoderConfig, + xFormerEncoderConfig, +) +from xformers.factory.block_factory import xFormerDecoderBlock, xFormerEncoderBlock +from xformers.factory.weight_init import get_weight_init_fn, xFormerWeightInit + +logger = logging.getLogger("xformers") + + +@dataclass(init=False) +class xFormerConfig: + """ + The configuration structure to define a full Transformer. + This can include a stack of encoder layers, and a stack of decoder layers. + + It is optionally possible to share the embedding weights in between + the encoder and decoder positional encoding, as proposed for instance by + `Using the Output Embedding to Improve Language Models`, Press et al. 
+ + A full config example is for instance as follows: + + :: + + xformer_config = [ + { + "reversible": False, # Turn on to test the effect of using reversible layers + "block_type": "encoder", + "num_layers": LAYERS, + "dim_model": EMB, + "residual_norm_style": "pre", + "position_encoding_config": { + "name": "vocab", + "seq_len": CONTEXT, + "vocab_size": VOCAB_SIZE, + }, + "multi_head_config": { + "num_heads": NUM_HEADS, + "residual_dropout": RES_DROP, + "use_rotary_embeddings": True, + "attention": { + "name": ATTENTION_MECHANISM_STR, + "dropout": ATTN_DROP, + "causal": True, + "seq_len": CONTEXT, + }, + }, + "feedforward_config": { + "name": "FusedMLP", # Use MLP if Triton is not available + "dropout": MLP_DROP, + "activation": "gelu", + "hidden_layer_multiplier": MLP_MULTIPLIER, + }, + } + ] + + + .. _`Using the Output Embedding to Improve Language Models`: https://arxiv.org/pdf/1608.05859.pdf + """ + + stack_configs: Union[List[xFormerBlockConfig], Dict[str, xFormerBlockConfig]] + tie_embedding_weights: bool = False + weight_init: xFormerWeightInit = xFormerWeightInit.ViT + + def __init__( + self, + stack_configs: Union[List[Dict[str, Any]], Dict[str, Dict[str, Any]]], + tie_embedding_weights: bool = False, + weight_init: xFormerWeightInit = xFormerWeightInit.ViT, + ): + # Type all the configurations. Possible typos are caught here + if isinstance(stack_configs, dict): + self.stack_configs = {} + for k, config in stack_configs.items(): + if config["block_type"] == "encoder": + self.stack_configs[k] = xFormerEncoderConfig(**config) + else: + self.stack_configs[k] = xFormerDecoderConfig(**config) + else: + self.stack_configs = [] + for config in stack_configs: + if config["block_type"] == "encoder": + self.stack_configs.append(xFormerEncoderConfig(**config)) + else: + self.stack_configs.append(xFormerDecoderConfig(**config)) + + self.tie_embedding_weights = tie_embedding_weights + self.weight_init = weight_init + deprecated_function(self) + + +class xFormer(torch.nn.Module): + def __init__( + self, + stack_configs: Union[ + xFormerBlockConfig, List[xFormerBlockConfig], Dict[str, xFormerBlockConfig] + ], + tie_embedding_weights: bool = False, + weight_init: xFormerWeightInit = xFormerWeightInit.ViT, + ): + """ + Given a serialized configuration, generate the corresponding model. 
+ This is only a helper and can easily be bypassed + """ + super().__init__() + deprecated_function(self) + + if isinstance(stack_configs, Dict): + stack_configs = list(stack_configs.values()) + + # Convenience, users can pass either a list of configs or a single one + if not isinstance(stack_configs, List): + stack_configs = [stack_configs] + + # Sanity checks, some config combinations do not make sense + self._verify_reversible(stack_configs) + self._verify_deepnorm(stack_configs) + + encoders: List[torch.nn.Module] = [] + decoders: List[torch.nn.Module] = [] + + self.reversible_encoder = False + self.rev_enc_pose_encoding = None + + # Unroll the configs and build the model + for config in stack_configs: + # Handle either Encoder or Decoder stacks + builder = ( + xFormerEncoderBlock.from_config + if isinstance(config, xFormerEncoderConfig) + else xFormerDecoderBlock.from_config + ) + recipient = ( + encoders if isinstance(config, xFormerEncoderConfig) else decoders + ) + + # Build up the stack + for i in range(config.num_layers): + # Label where this layer is in the stack + # (for instance useful for the positional encoding, or late layer norm) + if len(recipient) > 0: + config.layer_position.mark_not_first() + + if config != stack_configs[-1] or i < config.num_layers - 1: + config.layer_position.mark_not_last() + + block = builder(config) # type: ignore + + # If reversible: extract the reversible sub-parts, else append the block as-is + if config.reversible: + # WARNING: only one pose encoding is saved here (not Focal Transformer compatible for instance) + assert isinstance(config, xFormerEncoderConfig) + if block.pose_encoding is not None: + self.rev_enc_pose_encoding = block.pose_encoding + self.reversible_encoder = True + + f, g = xFormerEncoderBlock.get_reversible_layer(config) + recipient.append(torch.nn.ModuleList([f, g])) + else: + recipient.append(block) # type: ignore + + # Tie embedding weights, if requested and possible + assert ( + not tie_embedding_weights or not self.reversible_encoder + ), "Reversible layers and tied embeddings is not supported for now" + + if ( + tie_embedding_weights + and encoders + and encoders[0].pose_encoding + and decoders + and decoders[0].pose_encoding + and not config.reversible + ): + logger.info("Tying encoder and decoder embeddings, as requested") + encoders[0].pose_encoding = decoders[0].pose_encoding + + self.encoders: torch.nn.Module = ( + rv.ReversibleSequence(torch.nn.ModuleList(encoders)) + if self.reversible_encoder + else torch.nn.ModuleList(encoders) + ) + self.decoders = torch.nn.ModuleList(decoders) + + use_deepnorm = ( + stack_configs[0].residual_norm_style == ResidualNormStyle.DeepNorm + ) + + assert ( + not use_deepnorm or not self.reversible_encoder + ), "Reversible layers and deepnorm is not supported for now" + + self.init_weights(weight_init=weight_init, use_deep_norm=use_deepnorm) + + @classmethod + def from_config(cls, config: xFormerConfig): + return cls( + config.stack_configs, config.tie_embedding_weights, config.weight_init + ) + + def _verify_reversible(self, stack_configs: List[xFormerBlockConfig]): + reversible = [ + c.reversible + for c in filter(lambda x: x.block_type == "encoder", stack_configs) + ] + + assert all(reversible) or not any(reversible), ( + "All layers need to have the same reversibility setting. 
" + + f"Currently {reversible}" + ) + + def _verify_deepnorm(self, stack_configs: List[xFormerBlockConfig]): + deepnorm = [ + c.residual_norm_style == ResidualNormStyle.DeepNorm for c in stack_configs + ] + + assert all(deepnorm) or not any(deepnorm), ( + "All layers need to have the same deepnorm setting. " + + f"Currently {deepnorm}" + ) + + def init_weights(self, weight_init: xFormerWeightInit, use_deep_norm: bool): + # The deepnorm weight initialization method requires different gain factors for the encoder + # and decoder, depending on the general model structure (number of respective layers) + if use_deep_norm: + encoder_coefficients, decoder_coefficients = get_deepnorm_coefficients( + encoder_layers=len(self.encoders), decoder_layers=len(self.decoders) # type: ignore + ) + else: + encoder_coefficients, decoder_coefficients = None, None + + encoder_gain = ( + encoder_coefficients.beta if encoder_coefficients is not None else 1.0 + ) + decoder_gain = ( + decoder_coefficients.beta if decoder_coefficients is not None else 1.0 + ) + + # Pick the desired init function + init_fn = get_weight_init_fn(weight_init) + + # Initialize all the encoder weights + for name, module in self.encoders.named_children(): + init_fn(module=module, name=name, gain=encoder_gain) + + for name, module in self.decoders.named_children(): + init_fn(module=module, name=name, gain=decoder_gain) + + def forward( + self, + src: torch.Tensor, + tgt: Optional[torch.Tensor] = None, + encoder_input_mask: Optional[torch.Tensor] = None, + decoder_input_mask: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + + # Encode to latent space if encoder is present + if len(list(self.encoders.parameters())) > 0: + encoders = self.encoders + memory = src.clone() + if isinstance(encoders, torch.nn.ModuleList): + for encoder in encoders: + memory = encoder(memory, input_mask=encoder_input_mask) + else: + if self.rev_enc_pose_encoding: + memory = self.rev_enc_pose_encoding(src) + + # Reversible Encoder + x = torch.cat([memory, memory], dim=-1) + + # Apply the optional input masking + if encoder_input_mask is not None: + if x.dim() - encoder_input_mask.dim() > 1: + encoder_input_mask.unsqueeze(0) + x += encoder_input_mask.unsqueeze(-1) + + x = encoders(x) + memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0) + + if not self.decoders: + return memory + + # If decoder: either use the encoder ouput, or just decode, both options are possible + if len(self.decoders) > 0: + tgt = src.clone() if tgt is None else tgt + + for decoder in self.decoders: + tgt = decoder( + target=tgt, + # pyre-fixme[61]: `memory` is not always initialized here. + memory=memory, + input_mask=decoder_input_mask, + ) + + return tgt + + return None diff --git a/pkgs/xformers/factory/weight_init.py b/pkgs/xformers/factory/weight_init.py new file mode 100644 index 0000000..754d3be --- /dev/null +++ b/pkgs/xformers/factory/weight_init.py @@ -0,0 +1,293 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +# CREDITS: Reusing a lot of code from the Timm repo +# main difference is probably the handling of deepnorm init, and adapting to some xformers specificities +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + +import logging +import math +from enum import Enum +from typing import Callable + +import torch +import torch.nn as nn +from torch.nn.init import ( + _calculate_fan_in_and_fan_out, + _no_grad_trunc_normal_, + _no_grad_uniform_, +) + +logger = logging.getLogger("xformers") + + +_assert_if_not_initialized = False + + +class xFormerWeightInit(str, Enum): + Timm = "timm" + ViT = "vit" + Moco = "moco" + Small = "small" + + +def get_weight_init_fn(init_choice: xFormerWeightInit): + """ + Provide the xFormers factory with weight init routines. + + Supported initializations are: + - Small: follow the method outlined in `Transformer Without Tears`_ + - ViT: follow the initialization in the reference ViT_ codebase + - Timm: follow the initialization in the reference Timm_ codebase + - Moco: follow the initialization in the reference MocoV3_ codebase + + .. _ViT: https://github.com/google-research/vision_transformer + .. _Timm: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py + .. _MocoV3: https://github.com/facebookresearch/moco-v3 + """ + return { + xFormerWeightInit.Timm: _init_weights_vit_timm, + xFormerWeightInit.ViT: _init_weights_vit_jax, + xFormerWeightInit.Moco: _init_weights_vit_moco, + xFormerWeightInit.Small: _init_weights_small, + }[init_choice] + + +# Define pattern matches +def is_ffn(n): + return "feedforward" in n or ("wrap_ff" in n and not n.endswith("norm")) + + +def is_mha_input_projection(n): + return "q_proj" in n or "k_proj" in n or "v_proj" in n + + +# Define distribution helpers +def _small_init_(tensor: torch.Tensor, gain: float = 1.0) -> torch.Tensor: + r"""Fills the input `Tensor` with values according to the method + described in `Transformer Without Tears`_, using a uniform distribution. + + This is a variation of the Xavier init. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-a, a)` where + + .. math:: + a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + 4 * \text{fan\_out}}} + + Also known as Glorot initialization. + + Args: + tensor: an n-dimensional `torch.Tensor` + gain: an optional scaling factor + + .. 
_`Transformer Without Tears`: https://arxiv.org/abs/1910.05895 + + """ + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + std = gain * math.sqrt(2.0 / float(fan_in + 4 * fan_out)) + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + return _no_grad_uniform_(tensor, -a, a) + + +def _lecun_normal(tensor, gain=1.0): + fan_in, _ = _calculate_fan_in_and_fan_out(tensor) + denom = fan_in + variance = gain / denom + + # constant is stddev of standard normal truncated to (-2, 2) + _no_grad_trunc_normal_( + tensor, + mean=0.0, + std=math.sqrt(variance) / 0.87962566103423978, + a=-2.0, + b=2.0, + ) + + +# Helpers to keep all the functions typesafe, and handle corner cases and common behaviours in one place +def _maybe_init_tensor(module: nn.Module, attr: str, distribution_: Callable, **kwargs): + # Small helper to catch all the corner cases, while staying type safe + if hasattr(module, attr): + maybe_tensor = getattr(module, attr) + if maybe_tensor is not None and isinstance(maybe_tensor, torch.Tensor): + distribution_(maybe_tensor, **kwargs) + + +def _maybe_report_no_init(module, name): + if len(list(module.named_children())) == 0 and ( + hasattr(module, "weight") or hasattr(module, "bias") + ): + # Skip layer norm, this is ok + if isinstance(module, torch.nn.LayerNorm): + return + + # Skip nn.Embedding, we typically initialize it one level up, else Pytorch has a valid default + if isinstance(module, torch.nn.Embedding): + return + + # This is unexpected, warn about a possible unhandled weight + logger.warning( + f"Not initializing weights in {name}, this could be a mistake.\nModule {module}" + ) + + if _assert_if_not_initialized: + assert False, ( + f"Uninitialized weight found in {module}." + + " If you have a custom module, please provide a `init_weights()` method" + ) + + +# Define the different initialization schemes +def _init_weights_vit_jax( + module: nn.Module, + name: str = "", + head_bias: float = 0.0, + gain: float = 1.0, + deepnorm_style: bool = False, + **kwargs, +): + """ViT weight initialization, matching JAX (Flax) impl""" + + if is_ffn(name): + _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6) + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + + elif is_mha_input_projection(name) or isinstance(module, nn.Linear): + if deepnorm_style and ( + "q_proj" in name.split(".") or "k_proj" in name.split(".") + ): + gain = 1.0 + + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif isinstance(module, nn.Conv2d): + _maybe_init_tensor(module, "weight", _lecun_normal, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif hasattr(module, "init_weights"): + module.init_weights() # type: ignore + + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_vit_jax(child_module, f"{name}.{child_name}", head_bias, gain) + + +def _init_weights_vit_moco( + module: nn.Module, + name: str = "", + gain: float = 1.0, + **kwargs, +): + """ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed""" + + assert ( + "deepnorm_style" not in kwargs.keys() + ), "This initialization method does not support deepnorm" + + if is_ffn(name): + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + 
_maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif is_mha_input_projection(name) or isinstance(module, nn.Linear): + if isinstance(module.weight, torch.Tensor): + val = ( + math.sqrt(6.0 / float(module.weight.shape[0] + module.weight.shape[1])) + * gain + ) + _maybe_init_tensor(module, "weight", nn.init.uniform_, a=-val, b=val) + + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif hasattr(module, "init_weights"): + module.init_weights(gain=gain) # type: ignore + + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_vit_moco(child_module, child_name, gain) + + +def _init_weights_small( + module: nn.Module, + name: str = "", + head_bias: float = 0.0, + gain: float = 1.0, + deepnorm_style: bool = False, + **kwargs, +): + """Follow the `Transformer Without Tears`_ initialization for self-attention""" + + if is_ffn(name): + _maybe_init_tensor(module, "weight", torch.nn.init.xavier_uniform_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.normal_, std=1e-6) + + elif is_mha_input_projection(name) or isinstance(module, nn.Linear): + # "small init" only scales the attention layers init, not the FFN + if deepnorm_style and ( + "q_proj" in name.split(".") or "k_proj" in name.split(".") + ): + gain = 1.0 + + _maybe_init_tensor(module, "weight", _small_init_, gain=gain) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif isinstance(module, nn.Conv2d): + _maybe_init_tensor(module, "weight", _lecun_normal) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + elif hasattr(module, "init_weights"): + module.init_weights() # type: ignore + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_small(child_module, f"{name}.{child_name}", head_bias, gain) + + +def _init_weights_vit_timm( + module: nn.Module, + name: str = "", + gain: float = 1.0, + deepnorm_style: bool = False, + **kwargs, +): + """ + ViT weight initialization, original timm impl (for reproducibility). + + See DeepNet_ for all the DeepNorm specific codepaths + """ + + if isinstance(module, nn.Linear): + if deepnorm_style and ( + "q_proj" in name.split(".") or "k_proj" in name.split(".") + ): + gain = 1 + + std = 0.02 * gain + a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation + + _maybe_init_tensor( + module, "weight", _no_grad_trunc_normal_, mean=0.0, std=std, a=-a, b=a + ) + _maybe_init_tensor(module, "bias", nn.init.zeros_) + + elif hasattr(module, "init_weights"): + module.init_weights(gain=gain) # type: ignore + else: + _maybe_report_no_init(module, name) + + # Recurse over the children, if the weight init is being handled here + if not hasattr(module, "init_weights"): + for child_name, child_module in module.named_children(): + _init_weights_vit_timm(child_module, child_name, gain) diff --git a/pkgs/xformers/helpers/__init__.py b/pkgs/xformers/helpers/__init__.py new file mode 100644 index 0000000..11aa37d --- /dev/null +++ b/pkgs/xformers/helpers/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from .timm_sparse_attention import TimmSparseAttention # noqa diff --git a/pkgs/xformers/helpers/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/helpers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..244a9a5 Binary files /dev/null and b/pkgs/xformers/helpers/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/helpers/__pycache__/hierarchical_configs.cpython-310.pyc b/pkgs/xformers/helpers/__pycache__/hierarchical_configs.cpython-310.pyc new file mode 100644 index 0000000..8234f3c Binary files /dev/null and b/pkgs/xformers/helpers/__pycache__/hierarchical_configs.cpython-310.pyc differ diff --git a/pkgs/xformers/helpers/__pycache__/test_utils.cpython-310.pyc b/pkgs/xformers/helpers/__pycache__/test_utils.cpython-310.pyc new file mode 100644 index 0000000..6091627 Binary files /dev/null and b/pkgs/xformers/helpers/__pycache__/test_utils.cpython-310.pyc differ diff --git a/pkgs/xformers/helpers/__pycache__/timm_sparse_attention.cpython-310.pyc b/pkgs/xformers/helpers/__pycache__/timm_sparse_attention.cpython-310.pyc new file mode 100644 index 0000000..0163907 Binary files /dev/null and b/pkgs/xformers/helpers/__pycache__/timm_sparse_attention.cpython-310.pyc differ diff --git a/pkgs/xformers/helpers/hierarchical_configs.py b/pkgs/xformers/helpers/hierarchical_configs.py new file mode 100644 index 0000000..60b9564 --- /dev/null +++ b/pkgs/xformers/helpers/hierarchical_configs.py @@ -0,0 +1,124 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +from xformers._deprecation_warning import deprecated_function +from xformers.components.residual import ResidualNormStyle + + +@dataclass +class BasicLayerConfig: + embedding: int + attention_mechanism: str + patch_size: int + stride: int + padding: int + seq_len: int + feedforward: str + normalization: str = "layernorm" + repeat_layer: int = 1 + + +def get_hierarchical_configuration( + layer_base_configs: List[BasicLayerConfig], + residual_norm_style: ResidualNormStyle = ResidualNormStyle.Pre, + use_rotary_embeddings: bool = True, + mlp_multiplier: int = 4, + in_channels: int = 3, + dim_head: Optional[int] = None, +): + """ + A small helper to generate hierarchical xformers configurations, + which correspond for instance to poolformer or swin architectures. 
+ + Contrary to more "classical" Transformer architectures, which conserve the sequence/context + length across layers, hierarchical Transformers trade the sequence length for the embedding dimension + """ + deprecated_function(get_hierarchical_configuration) + + base_config: Dict[str, Any] = { + "block_type": "encoder", + "dim_model": 0, + "use_triton": False, + "residual_norm_style": str(residual_norm_style), + "multi_head_config": { + "num_heads": 1, + "use_rotary_embeddings": use_rotary_embeddings, + "attention": { + "name": "TBD", + }, + }, + "feedforward_config": { + "name": "TBD", + "activation": "gelu", + "hidden_layer_multiplier": mlp_multiplier, + "dropout": 0.0, + }, + "position_encoding_config": { + "name": "learnable", + "seq_len": 0, + "add_class_token": False, + }, + "patch_embedding_config": { + "in_channels": in_channels, + "kernel_size": 0, + "stride": 0, + "padding": 0, + }, + } + + xformers_config = [] + in_channels = in_channels + + for layer_base_config in layer_base_configs: + lc = copy.deepcopy(base_config) + + lc["normalization"] = layer_base_config.normalization + + # Fill in the changing model dimensions + lc["dim_model"] = layer_base_config.embedding + + # Update the patches + lc["patch_embedding_config"] = { + "in_channels": in_channels, + "kernel_size": layer_base_config.patch_size, + "stride": layer_base_config.stride, + "padding": layer_base_config.padding, + } + + # Update the number of channels for the next layer + in_channels = lc["dim_model"] * 1 + + lc["position_encoding_config"]["seq_len"] = layer_base_config.seq_len + + # Fill in the number of heads (defaults to 1) + if dim_head is not None: + lc["multi_head_config"]["num_heads"] = ( + layer_base_config.embedding // dim_head + ) + assert layer_base_config.embedding % dim_head == 0 + + # Fill in the attention mechanism + lc["multi_head_config"]["attention"][ + "name" + ] = layer_base_config.attention_mechanism + + # FIll in the feedforward + lc["feedforward_config"]["name"] = layer_base_config.feedforward + + print(lc) + xformers_config.append(lc) + + # Handle repeated layers (without the patch embeddings) + if layer_base_config.repeat_layer > 1: + lc_repeat = copy.deepcopy(lc) + lc_repeat.pop("patch_embedding_config") + xformers_config += [lc_repeat] * (layer_base_config.repeat_layer - 1) + + return xformers_config diff --git a/pkgs/xformers/helpers/test_utils.py b/pkgs/xformers/helpers/test_utils.py new file mode 100644 index 0000000..6438e3c --- /dev/null +++ b/pkgs/xformers/helpers/test_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
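# --- Editor's sketch (not part of the patch above): driving the hierarchical
# config helper defined just above, in the spirit of the poolformer/swin stacks
# it mentions. The "pooling" attention and "MLP" feedforward names, and the
# image size, are illustrative assumptions, not values taken from this diff.
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)

image_size = 32
base_configs = [
    # Each stage trades sequence length (stride 2) for a larger embedding dim.
    BasicLayerConfig(embedding=64, attention_mechanism="pooling", patch_size=3,
                     stride=2, padding=1, seq_len=(image_size // 2) ** 2,
                     feedforward="MLP"),
    BasicLayerConfig(embedding=128, attention_mechanism="pooling", patch_size=3,
                     stride=2, padding=1, seq_len=(image_size // 4) ** 2,
                     feedforward="MLP", repeat_layer=2),
]

# Returns a list of per-layer dicts that can be handed to xFormerConfig /
# xFormer.from_config from the model factory above.
xformers_config = get_hierarchical_configuration(
    base_configs, use_rotary_embeddings=False, mlp_multiplier=4, in_channels=3
)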
+ + +import sys +import tempfile + +import torch + +is_windows = False +if sys.platform == "win32": # pytorch on windows uses gloo not ncll + is_windows = True + + +def init_torch_distributed_local(): + if torch.distributed.is_initialized(): + return + + init_url = "file://" + tempfile.mkstemp()[1] + backend = ( + torch.distributed.Backend.NCCL + if torch.cuda.is_available() and not is_windows + else torch.distributed.Backend.GLOO + ) + torch.distributed.init_process_group( + backend=backend, + rank=0, + world_size=1, + init_method=init_url, + ) diff --git a/pkgs/xformers/helpers/timm_sparse_attention.py b/pkgs/xformers/helpers/timm_sparse_attention.py new file mode 100644 index 0000000..d92ef70 --- /dev/null +++ b/pkgs/xformers/helpers/timm_sparse_attention.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from xformers.components.attention.core import scaled_dot_product_attention + + +class TimmSparseAttention(torch.nn.Module): + """ + Almost drop-in replacement for timm attention + but using the sparsity-aware scaled_dot_product_attention from xformers + """ + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0.0, + proj_drop=0.0, + attn_mask=None, + ): + super().__init__() + self.num_heads = num_heads + + self.qkv = torch.nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = torch.nn.Dropout(attn_drop) + self.proj = torch.nn.Linear(dim, dim) + self.proj_drop = torch.nn.Dropout(proj_drop) + self.attn_mask = attn_mask + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + qkv = qkv.flatten(1, 2) + + q, k, v = qkv.unbind() + + x = scaled_dot_product_attention( + q, k, v, self.attn_mask, dropout=self.attn_drop + ) + x = x.reshape(B, self.num_heads, N, C // self.num_heads) + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/pkgs/xformers/info.py b/pkgs/xformers/info.py new file mode 100644 index 0000000..e67c939 --- /dev/null +++ b/pkgs/xformers/info.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Dict + +import torch + +from . 
import __version__, _cpp_lib, _is_opensource, _is_triton_available, ops +from .ops.common import OPERATORS_REGISTRY + + +def get_features_status() -> Dict[str, str]: + features = {} + for op in OPERATORS_REGISTRY: + status_str = "available" if op.is_available() else "unavailable" + features[f"{op.OPERATOR_CATEGORY}.{op.NAME}"] = status_str + for k, v in ops.swiglu_op._info().items(): + features[f"swiglu.{k}"] = v + features["is_triton_available"] = str(_is_triton_available()) + return features + + +def print_info(): + features = get_features_status() + print(f"xFormers {__version__}") + features["pytorch.version"] = torch.__version__ + if torch.cuda.is_available(): + features["pytorch.cuda"] = "available" + device = torch.cuda.current_device() + cap = torch.cuda.get_device_capability(device) + features["gpu.compute_capability"] = ".".join(str(ver) for ver in cap) + features["gpu.name"] = torch.cuda.get_device_name(device) + else: + features["pytorch.cuda"] = "not available" + + build_info = _cpp_lib._build_metadata + if build_info is None and isinstance( + _cpp_lib._cpp_library_load_exception, _cpp_lib.xFormersInvalidLibException + ): + build_info = _cpp_lib._cpp_library_load_exception.build_info + if build_info is not None: + features["build.info"] = "available" + features["build.cuda_version"] = build_info.cuda_version + features["build.python_version"] = build_info.python_version + features["build.torch_version"] = build_info.torch_version + for k, v in build_info.build_env.items(): + features[f"build.env.{k}"] = v + else: + features["build.info"] = "none" + + try: + features["build.nvcc_version"] = ".".join( + str(v) for v in torch.ops.xformers._nvcc_build_version() + ) + except (RuntimeError, AttributeError): + pass + + if _is_opensource: + features["source.privacy"] = "open source" + else: + features["source.privacy"] = "fairinternal" + + for name, status in features.items(): + print("{:<50} {}".format(f"{name}:", status)) + + +if __name__ == "__main__": + print_info() diff --git a/pkgs/xformers/ops/__init__.py b/pkgs/xformers/ops/__init__.py new file mode 100644 index 0000000..93c23ac --- /dev/null +++ b/pkgs/xformers/ops/__init__.py @@ -0,0 +1,97 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
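# --- Editor's sketch (not part of the patch above): the two masking conventions
# accepted by masked_matmul(), which is defined in ops/__init__.py just below.
import torch

from xformers.ops import masked_matmul

q = torch.randn(2, 4, 8)  # [batch, queries, head_dim]
k = torch.randn(2, 4, 8)  # [batch, keys, head_dim]

# Boolean mask: False entries are ignored (filled with -inf before any softmax).
bool_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
att = masked_matmul(q, k.transpose(-2, -1), bool_mask)

# Floating-point mask: treated as an additive bias on the raw attention scores.
add_mask = torch.zeros(4, 4).masked_fill(~bool_mask, float("-inf"))
att_additive = masked_matmul(q, k.transpose(-2, -1), add_mask.unsqueeze(0))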
+ +import torch + +from .fmha import ( + AttentionBias, + AttentionOp, + AttentionOpBase, + AttentionOpDispatch, + LowerTriangularMask, + MemoryEfficientAttentionCutlassFwdFlashBwOp, + MemoryEfficientAttentionCutlassOp, + MemoryEfficientAttentionFlashAttentionOp, + MemoryEfficientAttentionOp, + MemoryEfficientAttentionTritonFwdFlashBwOp, + TritonFlashAttentionOp, + memory_efficient_attention, + memory_efficient_attention_backward, + memory_efficient_attention_forward, + memory_efficient_attention_forward_requires_grad, +) +from .indexing import index_select_cat, scaled_index_add +from .rmsnorm import RMSNorm +from .rope_padded import rope_padded +from .swiglu_op import ( + SwiGLU, + SwiGLUEagerOp, + SwiGLUFusedOp, + SwiGLUOp, + SwiGLUOpDispatch, + SwiGLUPackedFusedOp, + swiglu, +) +from .unbind import get_stack_strides, stack_or_none, unbind + +# BW compatibility +AttentionMask = AttentionBias + + +def masked_matmul(a, b, mask=None): + if torch.overrides.has_torch_function((a, b, mask)): + return torch.overrides.handle_torch_function( + masked_matmul, (a, b, mask), a, b, mask + ) + + att = a @ b + + if mask is None: + return att + + if mask.dtype == torch.bool: + if mask.ndim == 2: + mask = mask.unsqueeze(0).expand(att.shape[0], -1, -1) + # mask is presumed false == ignore + att[~mask] = float("-inf") + else: + # mask is presumed additive + att += mask + return att + + +__all__ = [ + "memory_efficient_attention", + "AttentionBias", + "AttentionMask", + "AttentionOp", + "AttentionOpBase", + "AttentionOpDispatch", + "LowerTriangularMask", + "MemoryEfficientAttentionCutlassFwdFlashBwOp", + "MemoryEfficientAttentionCutlassOp", + "MemoryEfficientAttentionFlashAttentionOp", + "MemoryEfficientAttentionOp", + "MemoryEfficientAttentionTritonFwdFlashBwOp", + "memory_efficient_attention_backward", + "memory_efficient_attention_forward", + "memory_efficient_attention_forward_requires_grad", + "RMSNorm", + "SwiGLU", + "SwiGLUEagerOp", + "SwiGLUFusedOp", + "SwiGLUOp", + "SwiGLUOpDispatch", + "SwiGLUPackedFusedOp", + "swiglu", + "TritonFlashAttentionOp", + "unbind", + "stack_or_none", + "get_stack_strides", + "masked_matmul", + "scaled_index_add", + "index_select_cat", + "rope_padded", +] diff --git a/pkgs/xformers/ops/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..7bfd4f3 Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/__pycache__/common.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000..8d1844b Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/common.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/__pycache__/indexing.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/indexing.cpython-310.pyc new file mode 100644 index 0000000..f6ac919 Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/indexing.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/__pycache__/rmsnorm.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/rmsnorm.cpython-310.pyc new file mode 100644 index 0000000..2037d79 Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/rmsnorm.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/__pycache__/rope_padded.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/rope_padded.cpython-310.pyc new file mode 100644 index 0000000..692f605 Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/rope_padded.cpython-310.pyc differ diff --git 
a/pkgs/xformers/ops/__pycache__/swiglu_op.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/swiglu_op.cpython-310.pyc new file mode 100644 index 0000000..0365c9a Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/swiglu_op.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/__pycache__/unbind.cpython-310.pyc b/pkgs/xformers/ops/__pycache__/unbind.cpython-310.pyc new file mode 100644 index 0000000..6066611 Binary files /dev/null and b/pkgs/xformers/ops/__pycache__/unbind.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/common.py b/pkgs/xformers/ops/common.py new file mode 100644 index 0000000..a32ec3f --- /dev/null +++ b/pkgs/xformers/ops/common.py @@ -0,0 +1,133 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import os +from typing import Any, Dict, List, Type, TypeVar + +import torch +from torch.torch_version import TorchVersion + + +def get_operator(library: str, name: str): + def no_such_operator(*args, **kwargs): + raise RuntimeError( + f"No such operator {library}::{name} - did you forget to build xformers with `python setup.py develop`?" + ) + + try: + return getattr(getattr(torch.ops, library), name) + except (RuntimeError, AttributeError): + return no_such_operator + + +def get_xformers_operator(name: str): + return get_operator("xformers", name) + + +class BaseOperator: + OPERATOR: Any + NAME: str + OPERATOR_CATEGORY: str + + @classmethod + def is_available(cls) -> bool: + if cls.OPERATOR is None or cls.OPERATOR.__name__ == "no_such_operator": + return False + return True + + @classmethod + def operator_flop(cls, *inputs) -> int: + """Calculate number of FLOP given inputs to `OPERATOR`""" + return -1 + + +OPERATORS_REGISTRY: List[Type[BaseOperator]] = [] +FUNC_TO_XFORMERS_OPERATOR: Dict[Any, Type[BaseOperator]] = {} + +ClsT = TypeVar("ClsT") + + +def register_operator(cls: ClsT) -> ClsT: + global OPERATORS_REGISTRY, FUNC_TO_XFORMERS_OPERATOR + OPERATORS_REGISTRY.append(cls) # type: ignore + FUNC_TO_XFORMERS_OPERATOR[cls.OPERATOR] = cls # type: ignore + return cls + + +# post-2.0, avoids a warning +# (`torch.Tensor.storage` will also be deleted in the future) +_GET_TENSOR_STORAGE = getattr(torch.Tensor, "untyped_storage", None) +if _GET_TENSOR_STORAGE is None: # pre-2.0, `untyped_storage` didn't exist + _GET_TENSOR_STORAGE = torch.Tensor.storage + + +def _get_storage_base(x: torch.Tensor) -> int: + return _GET_TENSOR_STORAGE(x).data_ptr() # type: ignore + + +def make_pytorch_cuda_operator(fn: ClsT) -> ClsT: + from .. 
import get_python_lib + + def render_arg_type(annotation) -> str: + if annotation is torch.Tensor: + return "Tensor" + if annotation is bool: + return "bool" + if annotation is int: + return "int" + if annotation is List[int]: + return "int[]" + if annotation is List[torch.Tensor]: + return "Tensor[]" + assert False, f"Unable to parse annotation: `{annotation}`" + + sign = inspect.signature(fn) # type: ignore + arguments = [ + f"{render_arg_type(arg.annotation)} {arg.name}" + for arg in sign.parameters.values() + ] + op_name = fn.__name__ # type: ignore + definition = f"{op_name}({', '.join(arguments)}) -> {render_arg_type(sign.return_annotation)}" + + xformers_lib = get_python_lib() + xformers_lib.define(definition) + xformers_lib.impl(op_name, fn, "CUDA") + dispatcher_impl = getattr(getattr(torch.ops, xformers_lib.ns), op_name) + + def wrapper(*args, **kwargs): + return dispatcher_impl(*args, **kwargs) + + return wrapper # type: ignore + + +def _has_a_version_of_triton(): + if os.environ.get("XFORMERS_FORCE_DISABLE_TRITON", "0") == "1": + return False + if not torch.cuda.is_available(): + return False + try: + import triton # noqa: F401 + except ImportError: + return False + return True + + +def _has_triton2(): + if not _has_a_version_of_triton(): + return False + import triton + + tv = TorchVersion(triton.__version__) + return tv >= (2, 1) or tv == (2, 0) + + +def _has_triton21(): + if not _has_a_version_of_triton(): + return False + import triton + + tv = TorchVersion(triton.__version__) + return tv >= (2, 1) diff --git a/pkgs/xformers/ops/fmha/__init__.py b/pkgs/xformers/ops/fmha/__init__.py new file mode 100644 index 0000000..19ff7c1 --- /dev/null +++ b/pkgs/xformers/ops/fmha/__init__.py @@ -0,0 +1,474 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional, Sequence, Tuple, Type, Union +import os +import torch + +from . 
import cutlass, decoder, flash, small_k, triton, triton_splitk +from .attn_bias import AttentionBias, BlockDiagonalMask, LowerTriangularMask +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + AttentionOp, + AttentionOpBase, + AttentionOpDispatch, + Context, + Gradients, + Inputs, + bmk2bmhk, +) +from .dispatch import _dispatch_bw, _dispatch_fw, _ensure_op_supports_or_raise + +MemoryEfficientAttentionCutlassOp = (cutlass.FwOp, cutlass.BwOp) +MemoryEfficientAttentionCutlassFwdFlashBwOp = (cutlass.FwOp, flash.BwOp) +MemoryEfficientAttentionDecoderOp = (decoder.FwOp, cutlass.BwOp) +MemoryEfficientAttentionTritonFwdFlashBwOp = (triton.FwOp, flash.BwOp) +MemoryEfficientAttentionFlashAttentionOp = (flash.FwOp, flash.BwOp) +MemoryEfficientAttentionOp = (small_k.FwOp, small_k.BwOp) +TritonFlashAttentionOp = (triton.FwOp, triton.BwOp) + + +class _fMHA(torch.autograd.Function): + @staticmethod + # type: ignore + def forward(ctx, op: AttentionOp, *args: Any) -> Any: + inp = Inputs(*args) + op_fw = op[0] if op is not None else None + op_bw = op[1] if op is not None else None + + out, op_ctx = _memory_efficient_attention_forward_requires_grad( + inp=inp, op=op_fw + ) + + # Saving attn_bias is a bit complicated, as the + # torch part should go in `save_for_backward` + if isinstance(inp.attn_bias, torch.Tensor): + attn_bias_tensor = inp.attn_bias + attn_bias_ctx = None + else: + attn_bias_tensor = None + attn_bias_ctx = inp.attn_bias + + ctx.save_for_backward( + inp.query, + inp.key, + inp.value, + op_ctx.q_padded, + op_ctx.k_padded, + op_ctx.v_padded, + op_ctx.o_padded, + op_ctx.out, + op_ctx.lse, + ) + ctx.rng_state = op_ctx.rng_state + ctx.attn_bias_tensor = attn_bias_tensor + if op_ctx.op_bw is not None: + if op_bw is not None and op_bw is not op_ctx.op_bw: + raise ValueError( + f"Specified op_bw={op_bw.NAME}, but forward op " + f"can only run with op_bw={op_ctx.op_bw.NAME}. Please set op_bw=None." 
+ ) + op_bw = op_ctx.op_bw + ctx.op_fw = op_fw + ctx.op_bw = op_bw + ctx.p = inp.p + ctx.use_alibi = inp.use_alibi + ctx.alibi_mode = inp.alibi_mode + ctx.imp_mode = inp.imp_mode + + ctx.scale = inp.scale + ctx.attn_bias_ctx = attn_bias_ctx + ctx.n_args = len(args) + return out + + @staticmethod + def deserialize_bias( + attn_bias_ctx, attn_bias_tensor: Optional[torch.Tensor] + ) -> Any: + if attn_bias_tensor is None: + return attn_bias_ctx + return attn_bias_tensor + + @classmethod + @torch.autograd.function.once_differentiable + def backward(cls, ctx, grad): + # Re-create context + query, key, value, q_padded, k_padded, v_padded, o_padded, out, lse = ctx.saved_tensors + attn_bias_tensor = ctx.attn_bias_tensor + rng_state = ctx.rng_state + inp = Inputs( + query=query, + key=key, + value=value, + attn_bias=cls.deserialize_bias(ctx.attn_bias_ctx, attn_bias_tensor), + p=ctx.p, + scale=ctx.scale, + use_alibi=ctx.use_alibi, + alibi_mode=ctx.alibi_mode, + imp_mode = ctx.imp_mode, + ) + op_ctx = Context( + lse=lse, + out=out, + q_padded=q_padded, + k_padded=k_padded, + v_padded=v_padded, + o_padded=o_padded, + rng_state=rng_state, + ) + grads = _memory_efficient_attention_backward( + ctx=op_ctx, inp=inp, grad=grad, op=ctx.op_bw + ) + return (None, grads.dq, grads.dk, grads.dv, grads.db) + (None,) * ( + ctx.n_args - 2 + ) + + +def memory_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + use_alibi: bool = False, + alibi_mode: int = 1, + imp_mode: int = 0, + *, + op: Optional[AttentionOp] = None, +) -> torch.Tensor: + """Implements the memory-efficient attention mechanism following + `"Self-Attention Does Not Need O(n^2) Memory" `_. + + :Inputs shape: + + - Input tensors must be in format ``[B, M, H, K]``, where B is the batch size, M \ + the sequence length, H the number of heads, and K the embeding size per head + + - If inputs have dimension 3, it is assumed that the dimensions are ``[B, M, K]`` and ``H=1`` + + - Inputs can also be of dimension 5 with GQA - see note below + + - Inputs can be non-contiguous - we only require the last dimension's stride to be 1 + + + :Equivalent pytorch code: + + .. code-block:: python + + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = query @ key.transpose(-2, -1) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + attn = F.dropout(attn, p) + return attn @ value + + :Examples: + + .. code-block:: python + + import xformers.ops as xops + + # Compute regular attention + y = xops.memory_efficient_attention(q, k, v) + + # With a dropout of 0.2 + y = xops.memory_efficient_attention(q, k, v, p=0.2) + + # Causal attention + y = xops.memory_efficient_attention( + q, k, v, + attn_bias=xops.LowerTriangularMask() + ) + + :Supported hardware: + + NVIDIA GPUs with compute capability above 6.0 (P100+), datatype ``f16``, ``bf16`` and ``f32``. + + :EXPERIMENTAL: Using with Multi Query Attention (MQA) and Grouped Query Attention (GQA): + + MQA/GQA is an experimental feature supported only for the forward pass. + If you have 16 heads in query, and 2 in key/value, you can provide 5-dim tensors + in the ``[B, M, G, H, K]`` format, where ``G`` is the number of head groups (here 2), and + ``H`` is the number of heads per group (8 in the example). 
+ + Please note that xFormers will not automatically broadcast the inputs, so you will need + to broadcast it manually before calling `memory_efficient_attention`. + + :GQA/MQA example: + + .. code-block:: python + + import torch + import xformers.ops as xops + + B, M, K = 3, 32, 128 + kwargs = dict(device="cuda", dtype=torch.float16) + q = torch.randn([B, M, 8, K], **kwargs) + k = torch.randn([B, M, 2, K], **kwargs) + v = torch.randn([B, M, 2, K], **kwargs) + out_gqa = xops.memory_efficient_attention( + q.reshape([B, M, 2, 4, K]), + k.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]), + v.reshape([B, M, 2, 1, K]).expand([B, M, 2, 4, K]), + ) + + Raises: + NotImplementedError: if there is no operator available to compute the MHA + ValueError: if inputs are invalid + + :parameter query: Tensor of shape ``[B, Mq, H, K]`` + :parameter key: Tensor of shape ``[B, Mkv, H, K]`` + :parameter value: Tensor of shape ``[B, Mkv, H, Kv]`` + :parameter attn_bias: Bias to apply to the attention matrix - defaults to no masking. \ + For common biases implemented efficiently in xFormers, see :attr:`xformers.ops.fmha.attn_bias.AttentionBias`. \ + This can also be a :attr:`torch.Tensor` for an arbitrary mask (slower). + :parameter p: Dropout probability. Disabled if set to ``0.0`` + :parameter scale: Scaling factor for ``Q @ K.transpose()``. If set to ``None``, the default \ + scale (q.shape[-1]**-0.5) will be used. + :parameter op: The operators to use - see :attr:`xformers.ops.AttentionOpBase`. \ + If set to ``None`` (recommended), xFormers \ + will dispatch to the best available operator, depending on the inputs \ + and options. + :return: multi-head attention Tensor with shape ``[B, Mq, H, Kv]`` + """ + return _memory_efficient_attention( + Inputs( + query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode + ), + op=op, + ) + + +def memory_efficient_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + use_alibi: bool = False, + alibi_mode: int = 1, + imp_mode: int = 0, + *, + op: Optional[Type[AttentionFwOpBase]] = None, +) -> torch.Tensor: + """ + Calculates the forward pass of :attr:`xformers.ops.memory_efficient_attention`. + """ + return _memory_efficient_attention_forward( + Inputs( + query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode + ), + op=op, + ) + + +def memory_efficient_attention_forward_requires_grad( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + use_alibi: bool = False, + alibi_mode: int = 1, + imp_mode: int = 0, + *, + op: Optional[Type[AttentionFwOpBase]] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Returns a tuple (output, lse), where `lse` can be used to compute the backward pass later. + See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments + See :attr:`xformers.ops.memory_efficient_attention_backward` for running the backward pass + """ + if p != 0.0: + raise NotImplementedError( + "dropout is not supported on the non-autograd API." 
+ " If you want to use dropout, please call `memory_efficient_attention` directly" + ) + out, ctx = _memory_efficient_attention_forward_requires_grad( + Inputs( + query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode + ), + op=op, + ) + return out, ctx.lse + + +def memory_efficient_attention_backward( + grad: torch.Tensor, + output: torch.Tensor, + lse: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + use_alibi: bool = False, + alibi_mode: int = 1, + imp_mode: int = 0, + *, + op: Optional[Type[AttentionBwOpBase]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Computes the gradient of the attention. + Returns a tuple (dq, dk, dv) + See :attr:`xformers.ops.memory_efficient_attention` for an explanation of the arguments. + `lse` is the tensor returned by :attr:`xformers.ops.memory_efficient_attention_forward_requires_grad` + """ + if p != 0.0: + raise NotImplementedError( + "dropout is not supported on the non-autograd API." + " If you want to use dropout, please call `memory_efficient_attention` directly" + ) + gradients = _memory_efficient_attention_backward( + Context(out=output, lse=lse), + Inputs( + query=query, key=key, value=value, p=p, attn_bias=attn_bias, scale=scale, use_alibi=use_alibi, alibi_mode=alibi_mode, imp_mode=imp_mode + ), + grad, + op=op, + ) + return (gradients.dq, gradients.dk, gradients.dv) + + +def _memory_efficient_attention( + inp: Inputs, op: Optional[AttentionOp] = None +) -> torch.Tensor: + # fast-path that doesn't require computing the logsumexp for backward computation + if all(x.requires_grad is False for x in [inp.query, inp.key, inp.value]): + return _memory_efficient_attention_forward( + inp, op=op[0] if op is not None else None + ) + + output_shape = inp.normalize_bmhk() + return _fMHA.apply( + op, inp.query, inp.key, inp.value, inp.attn_bias, inp.p, inp.scale, inp.use_alibi, inp.alibi_mode, inp.imp_mode + ).reshape(output_shape) + + +def _memory_efficient_attention_forward( + inp: Inputs, op: Optional[Type[AttentionFwOpBase]] +) -> torch.Tensor: + inp.validate_inputs() + output_shape = inp.normalize_bmhk() + if op is None: + op = _dispatch_fw(inp, False) + else: + _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) + + out, *_ = op.apply(inp, needs_gradient=False) + return out.reshape(output_shape) + + +def _memory_efficient_attention_forward_requires_grad( + inp: Inputs, op: Optional[Type[AttentionFwOpBase]] +) -> Tuple[torch.Tensor, Context]: + inp.validate_inputs() + output_shape = inp.normalize_bmhk() + if op is None: + op = _dispatch_fw(inp, True) + else: + _ensure_op_supports_or_raise(ValueError, "memory_efficient_attention", op, inp) + out = op.apply(inp, needs_gradient=True) + assert out[1] is not None + return (out[0].reshape(output_shape), out[1]) + + +def _memory_efficient_attention_backward( + ctx: Context, inp: Inputs, grad: torch.Tensor, op: Optional[Type[AttentionBwOpBase]] +) -> Gradients: + """Warning: grad/ctx.out is potentially in BMK format""" + inp.validate_inputs() + if grad.ndim != inp.query.ndim or grad.ndim != ctx.out.ndim: + raise ValueError( + "All tensors should be either in BMK (ndim=3) or BMHK (ndim=4) format. 
\n" + f"grad.shape : {grad.shape} \n" + f"out.shape : {ctx.out.shape} \n" + f"query.shape: {inp.query.shape}" + ) + shape_dq, shape_dk, shape_dv = tuple( + x.shape for x in (inp.query, inp.key, inp.value) + ) + inp.normalize_bmhk() + # LSE has shape [B, H, M] while query has shape [B, M, H, K] + if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0': + if ( + ctx.lse.ndim != 3 + # Dim 0 + or ( + not isinstance(inp.attn_bias, BlockDiagonalMask) + and ctx.lse.shape[0] != inp.query.shape[0] + ) + or ( + isinstance(inp.attn_bias, BlockDiagonalMask) + and ctx.lse.shape[0] != inp.attn_bias.q_seqinfo.seqstart.shape[0] - 1 + ) + # Dim 1 + or ctx.lse.shape[1] != inp.query.shape[2] + # Dim 2 + or ( + not isinstance(inp.attn_bias, BlockDiagonalMask) + and ctx.lse.shape[2] < inp.query.shape[1] + ) + ): + raise ValueError( + "Input tensors have incompatible shapes." + f"lse.shape : {ctx.lse.shape} \n" + f"query.shape : {inp.query.shape}" + ) + grad = bmk2bmhk(grad, 1) + ctx.out = bmk2bmhk(ctx.out, 1) + + if op is None: + op = _dispatch_bw(inp) + else: + _ensure_op_supports_or_raise( + ValueError, "memory_efficient_attention_backward", op, inp + ) + + grads = op.apply(ctx, inp, grad) + grads.dq = grads.dq.reshape(shape_dq) + grads.dk = grads.dk.reshape(shape_dk) + grads.dv = grads.dv.reshape(shape_dv) + return grads + + +ALL_FW_OPS: Sequence[Type[AttentionFwOpBase]] = [ + cutlass.FwOp, + flash.FwOp, + triton.FwOp, + small_k.FwOp, + triton_splitk.FwOp, +] + +ALL_BW_OPS: Sequence[Type[AttentionBwOpBase]] = [ + cutlass.BwOp, + flash.BwOp, + triton.BwOp, + small_k.BwOp, +] + +__all__ = [ + "AttentionBias", + "AttentionOp", + "AttentionOpBase", + "AttentionOpDispatch", + "LowerTriangularMask", + "MemoryEfficientAttentionCutlassFwdFlashBwOp", + "MemoryEfficientAttentionTritonFwdFlashBwOp", + "MemoryEfficientAttentionCutlassOp", + "MemoryEfficientAttentionFlashAttentionOp", + "MemoryEfficientAttentionOp", + "TritonFlashAttentionOp", + "memory_efficient_attention", + "ALL_FW_OPS", + "ALL_BW_OPS", +] diff --git a/pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..4eba824 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc new file mode 100644 index 0000000..8260274 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/attn_bias.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000..333d087 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/common.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc new file mode 100644 index 0000000..782e7b0 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/cutlass.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000..00327e1 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/decoder.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc new file mode 100644 index 
0000000..a5f5eeb Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/dispatch.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc new file mode 100644 index 0000000..fe9cb87 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/flash.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc new file mode 100644 index 0000000..981c109 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/small_k.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc new file mode 100644 index 0000000..ef5cff7 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/triton.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc b/pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc new file mode 100644 index 0000000..0218fd1 Binary files /dev/null and b/pkgs/xformers/ops/fmha/__pycache__/triton_splitk.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/fmha/attn_bias.py b/pkgs/xformers/ops/fmha/attn_bias.py new file mode 100644 index 0000000..df6b882 --- /dev/null +++ b/pkgs/xformers/ops/fmha/attn_bias.py @@ -0,0 +1,779 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import math +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union + +import torch + + +class AttentionBias: + """Base class for a custom bias that can be applied \ + as the attn_bias argument in + :attr:`xformers.ops.memory_efficient_attention`. + + That function has the ability to add a tensor, the + attention bias, to the QK^T matrix before it is used + in the softmax part of the attention calculation. + The attention bias tensor with shape + (B or 1, n_queries, number of keys) + can be given as the attn_bias input. + The most common use case is for an attention bias is + to contain only zeros and negative infinities, which forms + a mask so that some queries only attend to some keys. + + Children of this class define alternative things which can + be used as the attn_bias input to define an attention bias which + forms such a mask, for some common cases. + + When using an :attr:`xformers.ops.AttentionBias` + instead of a :attr:`torch.Tensor`, the mask matrix does + not need to be materialized, and can be + hardcoded into some kernels for better performance. + + See: + + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMask` + - :attr:`xformers.ops.fmha.attn_bias.LowerTriangularMaskWithTensorBias` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask` + - :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask` + + """ + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """ + Materializes the bias as a `torch.Tensor`. This is very slow + and we don't attempt to make it fast. Only use for debugging/testing. 
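For intuition, the sketch below (plain PyTorch, illustrative sizes only) builds the same kind of causal bias that ``materialize`` returns for a ``LowerTriangularMask`` and adds it to the scaled QK^T scores, which is exactly where an attention bias enters the computation:

.. code-block:: python

    import torch

    B, H, Mq, Mkv, K = 1, 2, 4, 4, 8   # made-up sizes
    q = torch.randn(B, H, Mq, K)
    k = torch.randn(B, H, Mkv, K)

    # Same recipe as LowerTriangularMask.materialize: -inf strictly above the diagonal.
    bias = torch.triu(torch.full((Mq, Mkv), float("-inf")), diagonal=1)

    scores = q @ k.transpose(-2, -1) * K ** -0.5
    attn = (scores + bias).softmax(dim=-1)   # each query only attends to keys at or before it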
+ + Shape should be like `[*, q_seqlen, k_seqlen]` + """ + raise NotImplementedError() + + +class LowerTriangularMask(AttentionBias): + """ + A lower-triangular (aka causal) mask + + A query Q cannot attend to a key which is farther from the + initial key than Q is from the initial query. + """ + + def __init__(self, *tensor_args, **tensor_kwargs) -> None: + # NOTE: Unused arguments, we keep them for backward compatibility + super().__init__() + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + return torch.triu(tensor, diagonal=1).to(dtype) # type: ignore + + def add_bias(self, bias: torch.Tensor) -> "LowerTriangularMaskWithTensorBias": + return LowerTriangularMaskWithTensorBias(bias) + + +class LowerTriangularMaskWithTensorBias(LowerTriangularMask): + """A lower-triangular (aka causal) mask with an additive bias""" + + def __init__(self, bias: torch.Tensor) -> None: + self._bias = bias + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return super().materialize(shape, dtype=dtype, device=device) + self._bias + + +@dataclass +class _SeqLenInfo: + """ + (Internal) Represents the division of a dimension into blocks. + + For example, to represents a dimension of length 7 divided into + three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`. + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 2, 5, 7] + seqstart: torch.IntTensor([0, 2, 5, 7]) + """ + + seqstart: torch.Tensor + max_seqlen: int + min_seqlen: int + seqstart_py: List[int] + + def to(self, device: torch.device) -> None: + self.seqstart = self.seqstart.to(device, non_blocking=True) + + def intervals(self) -> Iterable[Tuple[int, int]]: + yield from zip(self.seqstart_py, self.seqstart_py[1:]) + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + """ + assert not isinstance(seqlens, torch.Tensor) + seqstart_py = [0] + max_seqlen = -1 + min_seqlen = -1 + for seqlen in seqlens: + min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen + max_seqlen = max(max_seqlen, seqlen) + seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen) + seqstart = torch.tensor(seqstart_py, dtype=torch.int32) + return cls( + max_seqlen=max_seqlen, + min_seqlen=min_seqlen, + seqstart=seqstart, + seqstart_py=seqstart_py, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1: + raise ValueError( + f"Invalid `torch.Tensor` of shape {x.shape}, expected format " + f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n" + f" seqstart: {self.seqstart_py}" + ) + if batch_sizes is None: + batch_sizes = [1] * (len(self.seqstart_py) - 1) + split_chunks = [] + it = 0 + for batch_size in batch_sizes: + split_chunks.append( + self.seqstart_py[it + batch_size] - self.seqstart_py[it] + ) + it += batch_size + return [ + tensor.reshape([bs, -1, *tensor.shape[2:]]) + for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1)) + ] + + +@dataclass +class _PaddedSeqLenInfo(_SeqLenInfo): + """ + (Internal) Represents the 
division of a dimension into blocks which are + padded out to the same total length. + + For example, to represent a dimension of length 12 with space for + three blocks of length 4, but where the occupied lengths are + 2, 3 and 2, use `from_seqlens_padded([2, 3, 2], 4)`. + + The layout along the dimension is + + 0 ─► block 0 + block 0 + + + 4 ─► block 1 + block 1 + block 1 + + 8 ─► block 2 + block 2 + + + 12 ─► + + The members will be: + max_seqlen: 3 + min_seqlen: 2 + seqstart_py: [0, 4, 8, 12] + seqstart: torch.IntTensor([0, 4, 8, 12]) + seqlen_py: [2, 3, 2] + seqlen: torch.IntTensor([2, 3, 2]) + padding: 4 + """ + + seqlen: torch.Tensor + seqlen_py: Sequence[int] + padding: int + # From parent: seqstart[i] contains the start position + # of the i-th sequence + # seqstart: torch.Tensor + + def __post_init__(self) -> None: + assert len(self.seqstart_py) == len(self.seqlen_py) + 1 + + def to(self, device: torch.device) -> None: + self.seqlen = self.seqlen.to(device, non_blocking=True) + super().to(device) + + def intervals(self) -> Iterable[Tuple[int, int]]: + for (start, _), length in zip(super().intervals(), self.seqlen_py): + yield start, start + length + + @classmethod + def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo": + raise RuntimeError( + "Use either `_SeqLenInfo.from_seqlens` or `_PaddedSeqLenInfo.from_seqlens_padded`" + ) + + @classmethod + def from_seqlens_padded( + cls, seqlens: Sequence[int], padding: int + ) -> "_PaddedSeqLenInfo": + """ + Input tensors are assumed to be in shape [B, M, *] + seqstart = padding * torch.arange(batch_size) + """ + assert not isinstance(seqlens, torch.Tensor) + assert all(seqlen <= padding for seqlen in seqlens) + seqstart_py = list(range(0, len(seqlens) * padding + 1, padding)) + return cls( + seqlen=torch.tensor(seqlens, dtype=torch.int32), + seqlen_py=seqlens, + max_seqlen=max(seqlens), + min_seqlen=min(seqlens), + seqstart=torch.tensor(seqstart_py, dtype=torch.int32), + seqstart_py=seqstart_py, + padding=padding, + ) + + def split( + self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None + ) -> List[torch.Tensor]: + raise NotImplementedError("_PaddedSeqLenInfo.split") + + +@dataclass +class BlockDiagonalMask(AttentionBias): + """ + A block-diagonal mask that can be passed as ``attn_bias`` + argument to :attr:`xformers.ops.memory_efficient_attention`. + + Queries and Keys are each divided into the same number of blocks. + Queries in block i only attend to keys in block i. + + .. figure:: /_static/block_diag_bias.png + + This bias can be used to handle a batch of sequences of + different lengths, via :attr:`BlockDiagonalMask.from_tensor_list` + + :Example: + + .. 
code-block:: python + + import torch + from xformers.ops import fmha + + K = 16 + dtype = torch.float16 + device = "cuda" + list_x = [ + torch.randn([1, 3, 1, K], dtype=dtype, device=device), + torch.randn([1, 6, 1, K], dtype=dtype, device=device), + torch.randn([1, 2, 1, K], dtype=dtype, device=device), + ] + attn_bias, x = fmha.BlockDiagonalMask.from_tensor_list(list_x) + linear = torch.nn.Linear(K, K * 3).to(device=device, dtype=dtype) + + q, k, v = linear(x).reshape([1, -1, 1, 3, K]).unbind(-2) + out = fmha.memory_efficient_attention(q, k, v, attn_bias=attn_bias) + list_out = attn_bias.split(out) + print(list_out[0].shape) # [1, 3, 1, K] + assert tuple(list_out[0].shape) == (1, 3, 1, K) + + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _SeqLenInfo + _batch_sizes: Optional[Sequence[int]] = None + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return torch.zeros( + shape, + dtype=dtype, + device=device, + ) + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + assert shape[-1] == self.k_seqinfo.seqstart_py[-1], ( + shape[-1], + self.k_seqinfo.seqstart_py[-1], + ) + assert shape[-2] == self.q_seqinfo.seqstart_py[-1], ( + shape[-2], + self.q_seqinfo.seqstart_py[-1], + ) + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_seqlen: Optional[Sequence[int]] = None, + ) -> "BlockDiagonalMask": + """Creates a :attr:`BlockDiagonalMask` from a list of tensors lengths for query and key/value. + + Args: + q_seqlen (Union[Sequence[int], torch.Tensor]): List or tensor of sequence lengths for query tensors + kv_seqlen (Union[Sequence[int], torch.Tensor], optional): List or tensor of sequence lengths for key/value. + (Defaults to ``q_seqlen``.) + Returns: + BlockDiagonalMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + if kv_seqlen is None or q_seqlen == kv_seqlen: + k_seqinfo = q_seqinfo + else: + k_seqinfo = _SeqLenInfo.from_seqlens(kv_seqlen) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + @classmethod + def from_tensor_list( + cls, + tensors: Sequence[torch.Tensor], + ) -> Tuple["BlockDiagonalMask", torch.Tensor]: + """Creates a :attr:`BlockDiagonalMask` from a list of tensors, and returns the tensors + concatenated on the sequence length dimension + + .. figure:: /_static/block_diag_cat_split.png + + See also :attr:`BlockDiagonalMask.split` to split the returned + :attr:`torch.Tensor` back to a list of tensors of varying sequence length + + Args: + tensors (Sequence[torch.Tensor]): A list of tensors of shape ``[B, M_i, *]``. + All tensors should have the same dimension and the same batch size ``B``, but + they can have different sequence length ``M``. 
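The start offsets that ``_SeqLenInfo.from_seqlens`` derives for such a list of variable-length inputs are just a running sum of the per-block lengths; a standalone sketch using the same lengths as the example above:

.. code-block:: python

    seqlens = [3, 6, 2]                       # per-sequence lengths, as in list_x above
    seqstart = [0]
    for n in seqlens:
        seqstart.append(seqstart[-1] + n)

    print(seqstart)                           # [0, 3, 9, 11]
    print(list(zip(seqstart, seqstart[1:])))  # block intervals: (0, 3), (3, 9), (9, 11)
    print(max(seqlens), min(seqlens))         # max_seqlen / min_seqlen: 6 2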
+ + Returns: + Tuple[BlockDiagonalMask, torch.Tensor]: The corresponding bias for the attention + along with `tensors` concatenated on the sequence length dimension, with shape ``[1, sum_i{M_i}, *]`` + """ + batch_sizes = [tensor.shape[0] for tensor in tensors] + seqlens = [] + for x in tensors: + for _ in range(x.shape[0]): + seqlens.append(x.shape[1]) + block_diag = cls.from_seqlens(seqlens) + block_diag._batch_sizes = batch_sizes + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in tensors) + concat_tensors = torch.cat(tensors_bs1, dim=1) + return block_diag, concat_tensors + + @classmethod + def from_tensor_lists_qkv( + cls, + tensors_q: Sequence[torch.Tensor], + tensors_k: Sequence[torch.Tensor], + tensors_v: Optional[Sequence[torch.Tensor]] = None, + ) -> Tuple["BlockDiagonalMask", torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert len(tensors_q) == len(tensors_k) + assert tensors_v is None or len(tensors_v) == len(tensors_q) + batch_sizes = [tensor.shape[0] for tensor in tensors_q] + q_seqlens, kv_seqlens = [], [] + for i, (q, k) in enumerate(zip(tensors_q, tensors_k)): + assert q.shape[0] == k.shape[0] + q_seqlens += [q.shape[1]] * q.shape[0] + kv_seqlens += [k.shape[1]] * k.shape[0] + assert tensors_v is None or tensors_v[i].shape[:2] == k.shape[:2] + block_diag = cls.from_seqlens(q_seqlens, kv_seqlens) + block_diag._batch_sizes = batch_sizes + return ( + block_diag, + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_q], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_k], dim=1), + torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in tensors_v], dim=1) + if tensors_v is not None + else None, + ) + + def split_queries(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def split_kv(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + return self.k_seqinfo.split(tensor, self._batch_sizes) + + def split(self, tensor: torch.Tensor) -> Sequence[torch.Tensor]: + """The inverse operation of :attr:`BlockDiagonalCausalMask.from_tensor_list` + + Args: + tensor (torch.Tensor): Tensor of tokens of shape ``[1, sum_i{M_i}, *]`` + + Returns: + Sequence[torch.Tensor]: A list of tokens with possibly different sequence lengths + """ + assert self.q_seqinfo is self.k_seqinfo + return self.q_seqinfo.split(tensor, self._batch_sizes) + + def make_causal(self) -> "BlockDiagonalCausalMask": + """Makes each block causal""" + return BlockDiagonalCausalMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_causal_from_bottomright(self) -> "BlockDiagonalCausalFromBottomRightMask": + """Makes each block causal with a possible non-causal prefix""" + return BlockDiagonalCausalFromBottomRightMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + ) + + def make_local_attention( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionMask": + """Experimental: Makes each block causal with local attention""" + return BlockDiagonalCausalLocalAttentionMask( + q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + def make_local_attention_from_bottomright( + self, window_size: int + ) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask": + """Experimental: Makes each block causal with local attention, start from bottom right""" + return BlockDiagonalCausalLocalAttentionFromBottomRightMask( + 
q_seqinfo=self.q_seqinfo, + k_seqinfo=self.k_seqinfo, + _batch_sizes=self._batch_sizes, + _window_size=window_size, + ) + + +@dataclass +class BlockDiagonalCausalMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + + Queries and Keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which is farther from the initial key in block i than Q + is from the initial query in block i. + """ + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + return LowerTriangularMask().materialize( + shape, + dtype=dtype, + device=device, + ) + + +@dataclass +class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalMask`, except that each block is causal. + This mask allows for a non-causal prefix + NOTE: Each block should have `num_keys >= num_queries` otherwise the forward pass is not + defined (softmax of vector of `-inf` in the attention) + + Queries and keys are each divided into the same number of blocks. + A query Q in block i cannot attend to a key which is not in block i, + nor one which nearer the final key in block i than Q is to the + final query in block i. + """ + + def __post_init__(self) -> None: + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + num_queries = q_end - q_start + num_keys = k_end - k_start + if num_keys < num_queries: + raise ValueError( + f"Block #{i} has num_keys={num_keys} and num_queries={num_queries}." + " Expected `num_keys >= num_queries`" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + num_queries, num_keys = shape[-2:] + return torch.triu(tensor, diagonal=num_keys - num_queries + 1).to(dtype) # type: ignore + + +@dataclass +class BlockDiagonalCausalWithOffsetPaddedKeysMask(AttentionBias): + """ + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`, + except an offset on causality is allowed for each block and we support padding for k/v + + The keys and values are divided into blocks which are padded out to + the same total length. + For example, if there is space for 12 keys, for three blocks of + max length 4, but we only want to use the first 2, 3 and 2 + of each block, use `kv_padding=4` and `kv_seqlens=[2, 3, 2]`. + The queries are divided into blocks, without padding, of lengths given by + q_seqlen. + + A query Q in block i cannot attend to a key which is not in block i, + nor one which is not in use (i.e. in the padded area), + nor one which is nearer to the final key in block i + than Q is to the final query in block i. + """ + + q_seqinfo: _SeqLenInfo + k_seqinfo: _PaddedSeqLenInfo + causal_diagonal: Any = None # unused. Exists for BC only. 
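A short sketch of the padded key/value layout this mask describes, mirroring ``_PaddedSeqLenInfo.from_seqlens_padded`` with the ``kv_padding=4``, ``kv_seqlens=[2, 3, 2]`` numbers used in the docstrings above:

.. code-block:: python

    kv_seqlens, kv_padding = [2, 3, 2], 4

    # Every block starts at a multiple of the padding ...
    seqstart = list(range(0, len(kv_seqlens) * kv_padding + 1, kv_padding))
    print(seqstart)    # [0, 4, 8, 12]

    # ... but only the first `seqlen` slots of each block hold keys that are in use.
    intervals = [(start, start + n) for start, n in zip(seqstart, kv_seqlens)]
    print(intervals)   # [(0, 2), (4, 7), (8, 10)]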
+ + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=float("-inf"), + device=device, + ) + num_queries, num_keys = shape[-2:] + return torch.triu(tensor, diagonal=1 + num_keys - num_queries).to(dtype) # type: ignore + + def materialize( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + """Materialize the attention bias - for debugging & testing""" + if shape[-1] != self.k_seqinfo.seqstart_py[-1]: + raise ValueError("k shapes wrong") + if shape[-2] != self.q_seqinfo.seqstart_py[-1]: + raise ValueError("q shapes wrong") + mask = torch.empty(shape[-2:], dtype=dtype, device=device) + mask.fill_(-math.inf) + for i, ((q_start, q_end), (k_start, k_end)) in enumerate( + zip( + self.q_seqinfo.intervals(), + self.k_seqinfo.intervals(), + ) + ): + mask[q_start:q_end, k_start:k_end] = self._create_block_mask( + (q_end - q_start, k_end - k_start), + dtype=dtype, + device=device, + ) + for _ in range(len(shape) - 2): + mask = mask.unsqueeze(0) + return mask.expand(shape) + + @classmethod + def from_seqlens( + cls, + q_seqlen: Sequence[int], + kv_padding: int, + kv_seqlen: Sequence[int], + causal_diagonal: Any = None, + ) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask": + """Creates a :attr:`BlockDiagonalCausalWithOffsetPaddedKeysMask` from a list of tensor + lengths for query and key/value. + + Args: + q_seqlen (Sequence[int]): List or tensor of sequence lengths for query tensors + kv_padding (int): Padding for k/v - also an upperbound on each individual key length + kv_seqlen (Sequence[int]): List or tensor of sequence lengths for key/value. + causal_diagonal: unused, for BC only + Returns: + BlockDiagonalCausalWithOffsetPaddedKeysMask + """ + assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), ( + q_seqlen, + kv_seqlen, + ) + q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen) + k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding) + return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo) + + +@dataclass +class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. 
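The banded pattern described here comes from a tril/triu pair followed by a log (log(1) = 0 inside the band, log(0) = -inf outside), mirroring the ``_create_block_mask`` implementation that follows; a minimal sketch with made-up sizes:

.. code-block:: python

    import torch

    num_queries, num_keys, window_size = 5, 5, 3
    ones = torch.ones(num_queries, num_keys)

    band = torch.tril(ones, diagonal=0)                   # causal part
    band = torch.triu(band, diagonal=-(window_size - 1))  # keep only the last `window_size` keys
    mask = torch.log(band)                                # 1 -> 0.0, 0 -> -inf

    # Row i attends to keys max(0, i - window_size + 1) .. i
    print(mask)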
+ """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q - self._window_size >= k: + raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + + num_queries, num_keys = shape[-2:] + mask = torch.tril(tensor, diagonal=0).to(dtype) # type: ignore + if self._window_size is not None and self._window_size > 0: + mask = torch.triu(mask, diagonal=-self._window_size + 1) + mask = torch.log(mask) + return mask.to(dtype) + + +@dataclass +class BlockDiagonalCausalLocalAttentionFromBottomRightMask( + BlockDiagonalCausalFromBottomRightMask +): + """ + (Experimental feature) + Same as :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask`. + This makes the mask "local" and the attention pattern banded. + + Query i only attends to keys in its block and cannot attend keys further than "window_size" + from it. + """ + + _window_size: int = 0 # forced due to inheritance and default arguments + + def __post_init__(self): + super().__post_init__() + if self._window_size <= 0: + raise ValueError( + f"Expected `window_size > 0`, but window_size={self._window_size}" + ) + q_seqlen = [ + y - x + for x, y in zip( + self.q_seqinfo.seqstart_py[:-1], self.q_seqinfo.seqstart_py[1:] + ) + ] + kv_seqlen = [ + y - x + for x, y in zip( + self.k_seqinfo.seqstart_py[:-1], self.k_seqinfo.seqstart_py[1:] + ) + ] + for q, k in zip(q_seqlen, kv_seqlen): + if q + (q - k) - self._window_size >= k: + raise RuntimeError( + f"No keys are attended in q_seqlen {q} k_seqlen {k} with sliding window {self._window_size}" + ) + materialized = self.materialize((sum(q_seqlen), sum(kv_seqlen))) + if torch.max(materialized, dim=1).values.min() == -float("inf"): + raise RuntimeError("FUCKING FUCK FUCK") + + def _create_block_mask( + self, + shape: Tuple[int, ...], + dtype: torch.dtype = torch.float32, + device: Union[str, torch.device] = "cpu", + ) -> torch.Tensor: + create_as = dtype if dtype is not torch.bfloat16 else torch.float32 + tensor = torch.full( # type: ignore + shape, + dtype=create_as, + fill_value=1, + device=device, + ) + num_queries, num_keys = shape[-2:] + mask = torch.tril(tensor, diagonal=num_keys - num_queries).to(dtype) # type: ignore + if self._window_size is not None: + mask = torch.triu( + mask, diagonal=num_keys - num_queries - self._window_size + 1 + ) + mask = torch.log(mask) + return mask.to(dtype) diff --git a/pkgs/xformers/ops/fmha/common.py b/pkgs/xformers/ops/fmha/common.py new file mode 100644 index 0000000..df1cff3 --- /dev/null +++ b/pkgs/xformers/ops/fmha/common.py @@ -0,0 +1,542 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
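One more note on the masks defined above before the shared plumbing in common.py: the bottom-right-aligned variants differ from the plain causal mask only in the diagonal offset passed to ``torch.triu``; an illustrative comparison:

.. code-block:: python

    import torch

    num_queries, num_keys = 2, 4
    neg_inf = torch.full((num_queries, num_keys), float("-inf"))

    top_left = torch.triu(neg_inf, diagonal=1)                              # LowerTriangularMask
    bottom_right = torch.triu(neg_inf, diagonal=num_keys - num_queries + 1)

    print(top_left)      # [[0, -inf, -inf, -inf],
                         #  [0,    0, -inf, -inf]]
    print(bottom_right)  # [[0,    0,    0, -inf],
                         #  [0,    0,    0,    0]]  <- last query sees every key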
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass +from typing import Any, List, Mapping, Optional, Set, Tuple, Type, Union + +import torch + +from ..._cpp_lib import _built_with_cuda +from ..common import BaseOperator +from .attn_bias import ( + AttentionBias, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) + + +def _is_bias_type_supported_in_BMK(attn_bias_type: Any) -> bool: + # NoneType + if isinstance(None, attn_bias_type): + return True + if attn_bias_type in [LowerTriangularMask, torch.Tensor]: + return True + return False + + +@dataclass +class Inputs: + """ + Stores inputs to the `memory_efficient_attention` operators + """ + + query: torch.Tensor + key: torch.Tensor + value: torch.Tensor + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None + p: float = 0.0 + scale: Optional[float] = None + use_alibi: bool = False + alibi_mode: int = 1 + imp_mode: int = 0 + + @property + def device(self) -> torch.device: + return self.query.device + + @property + def scale_float(self) -> float: + return self.query.shape[-1] ** (-0.5) if self.scale is None else self.scale + + def get_qkv_in_bmghk(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self.query.ndim == 5: + return self.query, self.key, self.value + if self.query.ndim == 4: + return ( + self.query.unsqueeze(2), + self.key.unsqueeze(2), + self.value.unsqueeze(2), + ) + if self.value.ndim == 3: + return ( + self.query[:, :, None, None], + self.key[:, :, None, None], + self.value[:, :, None, None], + ) + assert False + + def normalize_bmhk(self) -> Tuple[int, ...]: + if self.query.ndim not in [3, 4, 5]: + raise ValueError( + f"Invalid shape for query: {self.query.shape}. " + "Expected shape [batch, seqlen, head_groups, num_heads_per_group, K]" + ", [batch, seqlen, num_heads, K], or [batch, seqlen, K]." + ) + if self.value.dtype == torch.int32: + # Quantized K/V case, in which the last dims of Q and K are different. + # NB we currently don't have any implementations for quantized KV with + # SUPPORTS_DIFFERENT_VALUE_EMBED. 
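Two conventions used throughout this file, shown with illustrative sizes: ``normalize_bmhk`` upgrades legacy BMK inputs to BMHK by inserting a singleton head dimension, and the default softmax scale is ``q.shape[-1] ** -0.5``:

.. code-block:: python

    import torch

    B, M, K = 2, 5, 16
    q_bmk = torch.randn(B, M, K)     # legacy BMK layout

    q_bmhk = q_bmk.unsqueeze(2)      # -> [B, M, 1, K], as normalize_bmhk does
    assert q_bmhk.shape == (B, M, 1, K)

    default_scale = K ** -0.5        # scale_float when `scale` is left as None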
+ output_shape = tuple(self.query.shape) + else: + output_shape = (self.query.shape[:-1]) + (self.value.shape[-1],) + # Convert from legacy format + if self.query.ndim == 3: + self.query = self.query.unsqueeze(2) + self.key = self.key.unsqueeze(2) + self.value = self.value.unsqueeze(2) + if isinstance(self.attn_bias, torch.Tensor): + if self.attn_bias.ndim != 3: + raise ValueError( + f"Expected BMK format for attn_bias, but got {self.attn_bias.shape}" + ) + self.attn_bias = self.attn_bias.unsqueeze(1) + return output_shape + + def validate_inputs(self) -> None: + qkv = (self.query, self.key, self.value) + if self.query.ndim not in (3, 4, 5) or any( + x.ndim != self.query.ndim for x in qkv + ): + raise ValueError( + f"Query/Key/Value should all have BMGHK, BMHK, or BMK shape.\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if any(x.device != self.query.device for x in qkv): + raise ValueError("Query/Key/Value should all be on the same device") + quantized_dtypes = self.key.dtype == self.value.dtype == torch.int32 + non_quantized_dtypes = all(x.dtype == self.query.dtype for x in qkv) + if not (quantized_dtypes or non_quantized_dtypes): + raise ValueError( + "Query/Key/Value should either all have the same dtype, or " + "(in the quantized case) Key/Value should have dtype torch.int32\n" + f" query.dtype: {self.query.dtype}\n" + f" key.dtype : {self.key.dtype}\n" + f" value.dtype: {self.value.dtype}" + ) + # Biases with tensors attached are meant to be in BMHK format + # This would require to permute biases/gradients which can be expensive, + # so let's just forbid it - BMK is a legacy format anyway + if self.query.ndim == 3 and not _is_bias_type_supported_in_BMK( + type(self.attn_bias) + ): + raise ValueError( + f"Please provide inputs in BMHK format rather " + f"than BMK when using bias type `{type(self.attn_bias).__name__}`" + ) + attn_bias_t: Optional[torch.Tensor] = None + if isinstance(self.attn_bias, torch.Tensor): + attn_bias_t = self.attn_bias + if isinstance(self.attn_bias, LowerTriangularMaskWithTensorBias): + attn_bias_t = self.attn_bias._bias + if self.query.ndim == 4 and attn_bias_t is not None: + expected_shape = ( + self.query.shape[0], + self.query.shape[2], + self.query.shape[1], + self.key.shape[1], + ) + if attn_bias_t.shape != expected_shape: + raise ValueError( + f"Invalid shape for attention bias: {attn_bias_t.shape} (expected {expected_shape})\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if isinstance(self.attn_bias, BlockDiagonalMask): + if any(x.shape[0] != 1 for x in qkv): + raise ValueError( + f"Expected batch_size=1 when using block-diagonal bias\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}" + ) + if self.p < 0.0 or self.p > 1.0: + raise ValueError(f"Invalid dropout probability: p={self.p}") + # Check that shapes match between inputs + B, Mq = self.query.shape[:2] + K = self.query.shape[-1] + B, Mkv = self.key.shape[:2] + Kv = self.value.shape[-1] + + valid_shapes = True + if self.query.ndim == 3: # BMK + valid_shapes = ( + self.query.shape == (B, Mq, K) + and self.key.shape == (B, Mkv, K) + and self.value.shape == (B, Mkv, Kv) + ) + H = self.query.shape[-2] + if self.query.ndim == 4: # BMHK + quantized_kv_cache = self.value.dtype == torch.int32 + key_embed_dim = Kv if quantized_kv_cache else K + valid_shapes = ( + self.query.shape == (B, Mq, 
H, K) + and self.key.shape == (B, Mkv, H, key_embed_dim) + and self.value.shape == (B, Mkv, H, Kv) + ) + G = self.query.shape[2] + if self.query.ndim == 5: # BMNHK + valid_shapes = ( + self.query.shape == (B, Mq, G, H, K) + and self.key.shape == (B, Mkv, G, H, K) + and self.value.shape == (B, Mkv, G, H, Kv) + ) + if not valid_shapes: + raise ValueError( + f"Incompatible shapes for attention inputs:\n" + f" query.shape: {self.query.shape}\n" + f" key.shape : {self.key.shape}\n" + f" value.shape: {self.value.shape}\n" + "HINT: We don't support broadcasting, please use `expand` " + "yourself before calling `memory_efficient_attention` if you need to" + ) + + +@dataclass +class Context: + lse: torch.Tensor + out: torch.Tensor + q_padded: Optional[torch.Tensor] = None + k_padded: Optional[torch.Tensor] = None + v_padded: Optional[torch.Tensor] = None + o_padded: Optional[torch.Tensor] = None + op_bw: Optional[Type["AttentionBwOpBase"]] = None + rng_state: Optional[torch.Tensor] = None + + def get_padded_lse(self, pad_to: int, force_pad_inf: bool = False) -> torch.Tensor: + pad_amount = (pad_to - (self.lse.shape[2] % pad_to)) % pad_to + lse = self.lse + if pad_amount > 0: + if force_pad_inf: + lse = lse[:, :, : self.out.shape[1]] + pad_amount = (pad_to - (lse.shape[2] % pad_to)) % pad_to + lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf) + elif force_pad_inf and self.out.shape[1] != lse.shape[2]: + lse[:, :, self.out.shape[1] :].fill_(math.inf) + return lse + + +@dataclass +class Gradients: + dq: torch.Tensor + dk: torch.Tensor + dv: torch.Tensor + # bias gradient. None if there is no tensor bias or if it doesn't require grad + db: Optional[torch.Tensor] = None + + +class AttentionOpBase(BaseOperator): + """Base class for any attention operator in xFormers + + See: + + - :attr:`xformers.ops.fmha.cutlass.FwOp` + - :attr:`xformers.ops.fmha.cutlass.BwOp` + - :attr:`xformers.ops.fmha.flash.FwOp` + - :attr:`xformers.ops.fmha.flash.BwOp` + - :attr:`xformers.ops.fmha.triton.FwOp` + - :attr:`xformers.ops.fmha.triton.BwOp` + - :attr:`xformers.ops.fmha.small_k.FwOp` + - :attr:`xformers.ops.fmha.small_k.BwOp` + """ + + OPERATOR: Any + SUPPORTED_DEVICES: Set[str] + CUDA_MINIMUM_COMPUTE_CAPABILITY: Tuple[int, int] = (5, 0) + SUPPORTED_DTYPES: Set[torch.dtype] + SUPPORTED_MAX_K: float + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None)} + SUPPORTS_DROPOUT: bool + SUPPORTS_CUSTOM_SCALE: bool = False + SUPPORTS_DIFFERENT_VALUE_EMBED: bool = False + IS_DETERMINISTIC: bool = True + SUPPORTS_BMGHK: bool = False + NAME: str + OPERATOR_CATEGORY = "memory_efficient_attention" + + _TEST_BATCH_SIZES: List[int] = [1, 300] + _TEST_K: List[int] = [32, 128] + + @classmethod + def supports(cls, d: Inputs) -> bool: + return not cls.not_supported_reasons(d) + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = [] + if not cls.SUPPORTS_DIFFERENT_VALUE_EMBED and K != Kv: + reasons.append("query.shape[-1] != value.shape[-1]") + if max(K, Kv) > cls.SUPPORTED_MAX_K: + reasons.append( + f"max(query.shape[-1] != value.shape[-1]) > {cls.SUPPORTED_MAX_K}" + ) + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + """ + Returns a list of reasons why this is not supported. 
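A compact sketch of the BMHK shape contract these checks enforce (sizes are made up); note that a tensor bias accompanying BMHK inputs is expected in ``[B, H, Mq, Mkv]`` layout:

.. code-block:: python

    import torch

    B, Mq, Mkv, H, K, Kv = 2, 7, 9, 4, 32, 32
    q = torch.randn(B, Mq, H, K)
    k = torch.randn(B, Mkv, H, K)
    v = torch.randn(B, Mkv, H, Kv)
    attn_bias = torch.zeros(B, H, Mq, Mkv)   # expected_shape in the check above

    assert q.shape == (B, Mq, H, K)
    assert k.shape == (B, Mkv, H, K)
    assert v.shape == (B, Mkv, H, Kv)
    assert attn_bias.shape == (B, H, Mq, Mkv)
    # A matching attention output would have shape [B, Mq, H, Kv].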
+ The kernel can run these inputs only if the returned list is empty + """ + reasons = cls.shape_not_supported_reasons( + Mq=d.query.shape[1], + Mkv=d.key.shape[1], + K=d.query.shape[-1], + Kv=d.query.shape[-1], + ) + device_type = d.query.device.type + dtype = d.query.dtype + if device_type not in cls.SUPPORTED_DEVICES: + reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") + if device_type == "cuda" and not _built_with_cuda: + reasons.append("xFormers wasn't build with CUDA support") + if device_type == "cuda": + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: + reasons.append( + f"requires device with capability > {cls.CUDA_MINIMUM_COMPUTE_CAPABILITY} " + f"but your GPU has capability {device_capability} (too old)" + ) + if dtype not in cls.SUPPORTED_DTYPES: + reasons.append(f"dtype={dtype} (supported: {cls.SUPPORTED_DTYPES})") + if type(d.attn_bias) not in cls.SUPPORTED_ATTN_BIAS_TYPES: + reasons.append(f"attn_bias type is {type(d.attn_bias)}") + if (d.p != 0.0) and not cls.SUPPORTS_DROPOUT: + reasons.append("dropout > 0.0") + if d.scale is not None and not cls.SUPPORTS_CUSTOM_SCALE: + reasons.append("has custom scale") + # bfloat16 is only supported on A100+ + # ... although the kernels can still run and give the + # correct result + if dtype is torch.bfloat16 and ( + not device_type.startswith("cuda") + ): + reasons.append("bf16 is only supported on A100+ GPUs") + if not cls.is_available(): + reasons.append( + "operator wasn't built - see `python -m xformers.info` for more info" + ) + if not cls.IS_DETERMINISTIC and torch.are_deterministic_algorithms_enabled(): + reasons.append( + "operator is non-deterministic, but `torch.use_deterministic_algorithms` is set" + ) + if not cls.SUPPORTS_BMGHK and d.query.ndim == 5: + reasons.append("operator does not support BMGHK format") + return reasons + + +class AttentionFwOpBase(AttentionOpBase): + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 3e-4, + torch.half: 4e-3, + torch.bfloat16: 2e-2, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 2e-5, + torch.half: 4e-4, + torch.bfloat16: 5e-3, + } + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + raise NotImplementedError() + + @classmethod + def attn_operator_flop( + cls, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + causal: bool = False, + seqstart_k: Optional[torch.Tensor] = None, + seqstart_q: Optional[torch.Tensor] = None, + ) -> int: + """ + Computes total flops for the attention + Assumes inputs in format BMHK + """ + assert query.ndim == 4 + + if seqstart_q is not None: + seqstart_q_py = seqstart_q.tolist() + else: + seqstart_q_py = [0, query.shape[1]] + if seqstart_k is not None: + seqstart_k_py = seqstart_k.tolist() + else: + seqstart_k_py = [0, key.shape[1]] + + total_flop = 0 + for q_start, q_end, k_start, k_end in zip( + seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:] + ): + num_q = q_end - q_start + num_kv = k_end - k_start + # (M,K) @ (K,N) GEMM needs M*N*K*2 flop + # Q @ K.transpose + total_flop += num_q * num_kv * query.shape[-1] * 2 + # (ignore softmax) + # attn @ V + total_flop += num_q * key.shape[-1] * num_kv * 2 + # Multiply by num_heads and batches + total_flop = total_flop * value.shape[2] * value.shape[0] + if causal: + total_flop //= 2 + return total_flop + + +class AttentionBwOpBase(AttentionOpBase): + ERROR_ATOL: Mapping[torch.dtype, 
float] = { + torch.float: 5e-4, + torch.half: 9e-2, + torch.bfloat16: 0.7, + } + ERROR_RTOL: Mapping[torch.dtype, float] = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 0.1, + } + SUPPORTS_ATTN_BIAS_GRAD = False + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(AttentionBwOpBase, cls).not_supported_reasons(d) + if ( + isinstance(d.attn_bias, torch.Tensor) + and d.attn_bias.requires_grad + and not cls.SUPPORTS_ATTN_BIAS_GRAD + ): + reasons.append( + "Computing the bias gradient is not supported (attn_bias.requires_grad = True)" + ) + + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + raise NotImplementedError() + + @classmethod + def attn_operator_flop( + cls, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + causal: bool = False, + seqstart_k: Optional[torch.Tensor] = None, + seqstart_q: Optional[torch.Tensor] = None, + ) -> int: + """ + Computes total flops for the attention + Assumes inputs in format BMHK + """ + assert query.ndim == 4 + + if seqstart_q is not None: + seqstart_q_py = seqstart_q.tolist() + else: + seqstart_q_py = [0, query.shape[1]] + if seqstart_k is not None: + seqstart_k_py = seqstart_k.tolist() + else: + seqstart_k_py = [0, key.shape[1]] + + total_flop = 0 + for q_start, q_end, k_start, k_end in zip( + seqstart_q_py, seqstart_q_py[1:], seqstart_k_py, seqstart_k_py[1:] + ): + num_q = q_end - q_start + num_kv = k_end - k_start + Kqk = query.shape[-1] + Kv = value.shape[-1] + # (M,K) @ (K,N) GEMM needs M*N*K*2 flop + # att = Q @ K.transpose + total_flop += num_q * num_kv * Kqk * 2 + # att @ dO + total_flop += num_kv * num_q * Kv * 2 + # dov = dO @ V + total_flop += num_q * Kv * num_kv * 2 + # dov @ K + total_flop += num_q * Kqk * num_kv * 2 + # dov @ Q + total_flop += num_q * Kqk * num_kv * 2 + # Multiply by num_heads and batches + total_flop = total_flop * value.shape[2] * value.shape[0] + if causal: + total_flop //= 2 + return total_flop + + +AttentionOp = Tuple[ + Optional[Type[AttentionFwOpBase]], Optional[Type[AttentionBwOpBase]] +] + + +@dataclass +class AttentionOpDispatch: + """Dispatcher to automatically select + the best operator to run memory-efficient attention. 
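The flop helpers above count two GEMMs per (batch, head) pair, ``Q @ K^T`` and ``attn @ V``, and halve the total for causal masks; worked numbers for illustrative sizes:

.. code-block:: python

    B, H, Mq, Mkv, K, Kv = 4, 16, 1024, 1024, 64, 64   # made-up sizes

    flop_qk = Mq * Mkv * K * 2        # Q @ K.transpose
    flop_av = Mq * Mkv * Kv * 2       # attn @ V (softmax ignored)
    total = (flop_qk + flop_av) * H * B

    print(f"{total / 1e9:.1f} GFLOP, {total // 2 / 1e9:.1f} GFLOP with a causal mask")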
+ + :Deprecated: + + This class is deprecated and will be removed in a later version + """ + + op: AttentionOp + + @classmethod + def from_arguments( + cls, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None, + p: float = 0.0, + scale: Optional[float] = None, + ) -> "AttentionOpDispatch": + """Here for backward compatibility""" + from .dispatch import _dispatch_bw, _dispatch_fw + + inp = Inputs( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=p, + scale=scale, + ) + return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp))) + + +def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor: + if tensor.ndim == 4: + return tensor + return tensor.reshape([-1, num_heads, tensor.shape[1], tensor.shape[2]]).permute( + (0, 2, 1, 3) + ) + + +def check_lastdim_alignment_stride1( + reasons: List[str], name: str, x: torch.Tensor, alignment: int +) -> None: + if x.shape[-1] % alignment != 0: + reasons.append(f"{name}.shape[-1] % {alignment} != 0") + elif x.stride(-2) % alignment != 0: + reasons.append( + f"{name}.stride(-2) % {alignment} != 0 ({name}.stride() = {x.stride()})" + ) + # We can have stride=0 sometimes if dimension=1 + if x.stride(-1) > 1: + reasons.append( + f"{name}.stride(-1) > 1 ({name}.stride() = {x.stride()}) - you should call `.contiguous()` on the input" + ) diff --git a/pkgs/xformers/ops/fmha/cutlass.py b/pkgs/xformers/ops/fmha/cutlass.py new file mode 100644 index 0000000..37058a9 --- /dev/null +++ b/pkgs/xformers/ops/fmha/cutlass.py @@ -0,0 +1,479 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import replace +from enum import Enum +from typing import Any, List, Mapping, Optional, Set, Tuple, Union + +import torch + +from ..common import get_xformers_operator, register_operator +from . 
import attn_bias +from .attn_bias import ( + AttentionBias, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalMask, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + check_lastdim_alignment_stride1, +) + + +def _uses_tensorcores(sm: int, is_half: bool) -> bool: + if sm >= 80: + return True + if sm >= 70: + return is_half + return False + + +def _minimum_gemm_alignment(inp: Inputs) -> int: + if inp.device.type != "cuda": + return 1 + cap = torch.cuda.get_device_capability(inp.device) + sm = cap[0] * 10 + cap[1] + bits_per_scalar = {torch.float: 32, torch.half: 16, torch.bfloat16: 16}[ + inp.query.dtype + ] + uses_tensorcores = _uses_tensorcores(sm, bits_per_scalar == 16) + matmul_alignment_mn = 1 + if sm >= 80: + matmul_alignment_mn = 4 + if uses_tensorcores: + matmul_alignment_mn = max(matmul_alignment_mn, 128 // bits_per_scalar) + return matmul_alignment_mn + + +def _get_seqlen_info( + inp: Inputs, +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], int, int]: + attn_bias = inp.attn_bias + if isinstance( + attn_bias, (BlockDiagonalMask, BlockDiagonalCausalWithOffsetPaddedKeysMask) + ): + attn_bias.k_seqinfo.to(inp.query.device) + attn_bias.q_seqinfo.to(inp.query.device) + seqstart_k = attn_bias.k_seqinfo.seqstart + seqstart_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + else: + seqstart_k = None + seqstart_q = None + max_seqlen_q = -1 + max_seqlen_k = -1 + + return seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k + + +def _get_tensor_bias( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> Optional[torch.Tensor]: + if isinstance(attn_bias, torch.Tensor): + return attn_bias + elif isinstance(attn_bias, LowerTriangularMaskWithTensorBias): + return attn_bias._bias + return None + + +def _check_bias_alignment( + reasons: List[str], attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> None: + attn_bias_tensor = _get_tensor_bias(attn_bias) + if attn_bias_tensor is not None: + alignment = 128 // torch.finfo(attn_bias_tensor.dtype).bits + show_padding_hint = False + for d in range(attn_bias_tensor.ndim - 1): + if attn_bias_tensor.stride(d) % alignment != 0: + reasons.append( + f"attn_bias.stride(-2) % {alignment} != 0 (attn_bias.stride() = {attn_bias_tensor.stride()})" + ) + show_padding_hint = True + if show_padding_hint: + reasons.append( + """\ +HINT: To use an `attn_bias` with a sequence length that is not a multiple of 8, \ +you need to ensure memory is aligned by slicing a bigger tensor. \ +Example: use `attn_bias = torch.zeros([1, 1, 5, 8])[:,:,:,:5]` instead of `torch.zeros([1, 1, 5, 5])`""" + ) + # We can have stride=0 sometimes if dimension=1 + if attn_bias_tensor.stride(-1) > 1: + reasons.append( + f"attn_bias.stride(-1) > 1 (attn_bias.stride() = {attn_bias_tensor.stride()}) - " + "you should call `.contiguous()` on the bias" + ) + + +class _CustomMaskType(int, Enum): + """ + (Matches CustomMaskType in C++.) 
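The padding hint emitted by ``_check_bias_alignment`` above amounts to allocating the bias with an aligned last dimension and slicing it back down; an illustrative fp16 sketch (alignment = 128 bits / 16 bits = 8 elements):

.. code-block:: python

    import torch

    Mq, Mkv = 5, 5

    naive = torch.zeros(1, 1, Mq, Mkv, dtype=torch.float16)
    aligned = torch.zeros(1, 1, Mq, 8, dtype=torch.float16)[:, :, :, :Mkv]

    print(naive.stride())    # (25, 25, 5, 1) -> stride(-2) == 5, not a multiple of 8
    print(aligned.stride())  # (40, 40, 8, 1) -> stride(-2) == 8, kernel-friendly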
+ """ + + NoCustomMask = 0 + CausalFromTopLeft = 1 + CausalFromBottomRight = 2 + + +def _custom_mask_type(bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int: + if isinstance( + bias, + ( + LowerTriangularMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + ), + ): + return int(_CustomMaskType.CausalFromTopLeft) + if isinstance( + bias, + ( + attn_bias.BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ): + return int(_CustomMaskType.CausalFromBottomRight) + return int(_CustomMaskType.NoCustomMask) + + +@register_operator +class FwOp(AttentionFwOpBase): + """xFormers' MHA kernel based on CUTLASS. + Supports a large number of settings (including without TensorCores, f32 ...) + and GPUs as old as P100 (Sm60) + """ + + OPERATOR = get_xformers_operator("efficient_attention_forward_cutlass") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.float, torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 65536 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + torch.Tensor, + LowerTriangularMask, + LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalWithOffsetPaddedKeysMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + attn_bias.BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + } + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = True + SUPPORTS_BMGHK = True + NAME = "cutlassF" + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + if inp.query.ndim in [3, 4]: + return cls.apply_bmhk(inp, needs_gradient=needs_gradient) + assert inp.query.ndim == 5, f"query has shape {inp.query.shape}" + + # Workaround until this is properly implemented in C++ + # run each head group in a different stream + n_groups = inp.key.shape[2] + main_stream = torch.cuda.current_stream() + streams = [main_stream] + [ + torch.cuda.Stream(device=inp.query.device) for _ in range(n_groups - 1) + ] + outs = [] + for group, stream in enumerate(streams): + stream.wait_stream(main_stream) + with torch.cuda.stream(stream): + query = inp.query[:, :, group] + key = inp.key[:, :, group] + value = inp.value[:, :, group] + bias = inp.attn_bias + if isinstance(bias, torch.Tensor): + bias = bias[:, group] + if isinstance(bias, attn_bias.LowerTriangularMaskWithTensorBias): + bias = attn_bias.LowerTriangularMaskWithTensorBias( + bias._bias[:, group] + ) + outs.append( + cls.apply_bmhk( + replace(inp, query=query, key=key, value=value, attn_bias=bias), + needs_gradient=needs_gradient, + ) + ) + for s in streams[1:]: + main_stream.wait_stream(s) + out = torch.stack([o[0] for o in outs], dim=2) + ctx: Optional[Context] = None + if needs_gradient: + ctx = Context( + out=out, + lse=torch.stack([o[1].lse for o in outs], dim=1), # type: ignore + op_bw=outs[0][1].op_bw, # type: ignore + ) + return out, ctx + + @classmethod + def apply_bmhk( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if type(inp.attn_bias) not in FwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias 
type") + seqstart_k, seqstart_q, max_seqlen_q, _ = _get_seqlen_info(inp) + out, lse, rng_seed, rng_offset = cls.OPERATOR( + query=inp.query, + key=inp.key, + value=inp.value, + attn_bias=_get_tensor_bias(inp.attn_bias), + seqstart_q=seqstart_q, + seqstart_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + dropout_p=inp.p, + compute_logsumexp=needs_gradient, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + seqlen_k=inp.attn_bias.k_seqinfo.seqlen + if isinstance(inp.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + else None, + window_size=inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ) + else None, + ) + ctx: Optional[Context] = None + if needs_gradient: + ctx = Context( + out=out, + lse=lse, + # cutlass forward is only compatible with cutlass backward if + # dropout is used (because of the way RNG states are passed and the + # way random numbers are generated during backward) + op_bw=BwOp if inp.p != 0 else None, + ) + if inp.p != 0: + ctx.rng_state = torch.tensor( + [rng_seed, rng_offset], dtype=torch.int64, device="cpu" + ) + return out, ctx + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + return reasons + + @classmethod + # type: ignore + def operator_flop( + cls, + q, + k, + v, + b, + seqstart_q, + seqstart_k, + max_seqlen_q_, + compute_lse, + custom_mask_type, + *a, + ) -> int: + return cls.attn_operator_flop( + q, + k, + v, + causal=custom_mask_type > 0, + seqstart_k=seqstart_k, + seqstart_q=seqstart_q, + ) + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_xformers_operator("efficient_attention_backward_cutlass") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + torch.Tensor, + LowerTriangularMask, + # TODO: Fix handling of gradient through the fMHA autograd function + # LowerTriangularMaskWithTensorBias, + BlockDiagonalMask, + BlockDiagonalCausalMask, + attn_bias.BlockDiagonalCausalFromBottomRightMask, + attn_bias.BlockDiagonalCausalLocalAttentionMask, + } + SUPPORTS_ATTN_BIAS_GRAD = True + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + NAME = "cutlassB" + + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 5e-4, + # increased from 9e-2, more opportunities for numerical errors when bias is + # used, noticed in gK on SM80 + torch.half: 1e-1, + torch.bfloat16: 7e-1, + } + + _TEST_K: List[int] = [ + 32, # 64x64 kernel + 128, # 64x128/128x128 kernel + 256, # 64x128 with accumulation in gmem + ] + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + matmul_alignment_mn = _minimum_gemm_alignment(d) + + check_lastdim_alignment_stride1(reasons, "query", d.query, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "key", d.key, matmul_alignment_mn) + check_lastdim_alignment_stride1(reasons, "value", d.value, 
matmul_alignment_mn) + _check_bias_alignment(reasons, d.attn_bias) + attn_bias_tensor = _get_tensor_bias(d.attn_bias) + + # Backprop of gradient through broadcasted bias is not supported + if attn_bias_tensor is not None and attn_bias_tensor.requires_grad: + # Don't forget that inputs are either in BMK or BMHK! + if d.query.ndim == 3 and attn_bias_tensor.ndim == 3: + expected_bias_shape = (*d.query.shape[:2], d.key.shape[1]) + else: + # bias is B H Mq Mk + expected_bias_shape = ( + d.query.shape[0], + d.query.shape[2] if d.query.ndim == 4 else 1, + d.query.shape[1], + d.key.shape[1], + ) + if tuple(attn_bias_tensor.shape) != expected_bias_shape: + reasons.append( + "Broadcasting the `attn_bias` tensor is not supported " + f"(shape: {tuple(attn_bias_tensor.shape)}" + f"/ expected: {expected_bias_shape})" + ) + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + if type(inp.attn_bias) not in BwOp.SUPPORTED_ATTN_BIAS_TYPES: + raise NotImplementedError("Unsupported attn_bias type") + + seqstart_k, seqstart_q, max_seqlen_q, max_seqlen_k = _get_seqlen_info(inp) + dtype = inp.query.dtype + + rng_seed = rng_offset = 0 + if inp.p != 0.0: + if ( + ctx.rng_state is None + or ctx.rng_state.dtype != torch.int64 + or ctx.rng_state.device.type != "cpu" + or ctx.rng_state.shape != (2,) + ): + raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}") + rng_seed, rng_offset = ctx.rng_state.tolist() + + force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5) + (grad_q, grad_k, grad_v, grad_bias) = cls.OPERATOR( + grad.to(dtype), + inp.query, + inp.key, + inp.value, + _get_tensor_bias(inp.attn_bias), + cu_seqlens_q=seqstart_q, + cu_seqlens_k=seqstart_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + logsumexp=ctx.get_padded_lse(32, force_pad_inf=force_pad_inf), + output=ctx.out.to(dtype), + dropout_p=inp.p, + # if not using dropout, seed and offset are irrelevant but still expected + # in function signature so just pass 0 + # seed and offset could be None if a different FW op other than cutlass + # was used. + rng_seed=rng_seed, + rng_offset=rng_offset, + custom_mask_type=_custom_mask_type(inp.attn_bias), + scale=inp.scale, + num_splits_key=-1, # Let C++ determine it + window_size=inp.attn_bias._window_size + if isinstance( + inp.attn_bias, + ( + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ) + else None, + ) + + # c++/CUDA implementation returns an uninitialized tensor if bias doesn't + # require grad + if not ( + isinstance(inp.attn_bias, torch.Tensor) and inp.attn_bias.requires_grad + ): + grad_bias = None + + return Gradients(dq=grad_q, dk=grad_k, dv=grad_v, db=grad_bias) + + @classmethod + # type: ignore + def operator_flop( + cls, + dO, + q, + k, + v, + b, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + logsumexp, + output, + dropout_p, + rng_seed, + rng_offset, + custom_mask_type, + scale, + ) -> int: + return cls.attn_operator_flop( + q, + k, + v, + seqstart_q=cu_seqlens_q, + seqstart_k=cu_seqlens_k, + causal=custom_mask_type > 0, + ) diff --git a/pkgs/xformers/ops/fmha/decoder.py b/pkgs/xformers/ops/fmha/decoder.py new file mode 100644 index 0000000..8f52c94 --- /dev/null +++ b/pkgs/xformers/ops/fmha/decoder.py @@ -0,0 +1,108 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
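``BwOp.apply`` above hands the kernel a logsumexp padded to a multiple of 32 via ``Context.get_padded_lse``; the padding itself is small, sketched here with made-up sizes:

.. code-block:: python

    import math
    import torch

    B, H, M = 2, 4, 50
    lse = torch.randn(B, H, M)              # [B, H, M], as stored in Context

    pad_to = 32
    pad = (pad_to - lse.shape[-1] % pad_to) % pad_to
    lse_padded = torch.nn.functional.pad(lse, [0, pad], value=math.inf)
    print(lse_padded.shape)                 # torch.Size([2, 4, 64])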
+ +from typing import Any, List, Optional, Set, Tuple + +import numpy as np +import torch + +from ..common import get_xformers_operator, register_operator +from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from .common import AttentionFwOpBase, Context, Inputs + + +@register_operator +class FwOp(AttentionFwOpBase): + """An operator optimized for very small values of K (``K <= 32``) \ + and f32 pre-Ampere as it does not use TensorCores. + Only supports contiguous inputs in BMK format, so an extra reshape \ + or contiguous call might be done. + + :Deprecated: + + This operator is deprecated and should not be used in new code + """ + + OPERATOR = get_xformers_operator("efficient_attention_forward_decoder") + SUPPORTED_DEVICES = {"cuda"} + SUPPORTED_DTYPES = {torch.bfloat16, torch.half, torch.float32} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (7, 0) + SUPPORTED_MAX_K: float = 128 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {BlockDiagonalCausalWithOffsetPaddedKeysMask} + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + NAME = "decoderF" + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + + attn_bias = d.attn_bias + if isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + # If we don't get here, we've an error elsewhere + if d.query.ndim != 4 or d.key.ndim != 4: + reasons.append("Inputs must be BMHK. BMK not supported") + + if d.query.shape[0] != 1: + reasons.append("One formal batch element expected") + + if d.query.shape[-1] != 128: + reasons.append("Only head_dim==128 for now.") + + if d.key.stride(-1) != 1: + reasons.append("expect keys to have last dim contiguous") + + if d.value.stride(-1) != 1: + reasons.append("expect values to have last dim contiguous") + + q_starts = attn_bias.q_seqinfo.seqstart_py + if attn_bias.q_seqinfo.max_seqlen != 1: + reasons.append("decoding expects one query") + elif d.query.shape[1] != len(q_starts) - 1: + reasons.append("empty lanes not supported yet") + + if attn_bias.k_seqinfo.padding > 8192: + reasons.append("key padding exceeds 8192") + + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if needs_gradient: + raise NotImplementedError("gradient") + attn_bias = inp.attn_bias + assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + + attn_bias.k_seqinfo.to(inp.query.device) + attn_bias.q_seqinfo.to(inp.query.device) + + padding = attn_bias.k_seqinfo.padding + multiquery = inp.key.stride(2) == 0 + if multiquery: + key = inp.key[0, :, :1].unflatten(0, (-1, padding)) + value = inp.value[0, :, :1].unflatten(0, (-1, padding)) + else: + key = inp.key[0].unflatten(0, (-1, padding)) + value = inp.value[0].unflatten(0, (-1, padding)) + + seq_positions = attn_bias.k_seqinfo.seqlen + + query = inp.query[0, :, None] + + if inp.scale is not None: + qk_scale = inp.scale + else: + qk_scale = 1.0 / np.sqrt(key.shape[-1]) + + out = cls.OPERATOR( + query=query, + key=key, + value=value, + seq_positions=seq_positions, + scale=qk_scale, + ) + return out, None diff --git a/pkgs/xformers/ops/fmha/dispatch.py b/pkgs/xformers/ops/fmha/dispatch.py new file mode 100644 index 0000000..82be053 --- /dev/null +++ b/pkgs/xformers/ops/fmha/dispatch.py @@ -0,0 +1,147 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
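As an illustrative aside (assumed example sizes, not part of the patch): the decoderF op defined above expects a single formal batch element whose sequences are padded to a fixed length, so K/V arrive as [1, B * padding, H, K] and are unfolded back to one row per sequence before the kernel runs, and the softmax scale defaults to 1/sqrt(head_dim) when `inp.scale` is None.

import math
import torch

B, padding, H, K = 4, 512, 8, 128                 # hypothetical sizes
key = torch.randn(1, B * padding, H, K)            # packed layout expected by decoderF
key_per_seq = key[0].unflatten(0, (-1, padding))   # -> [B, padding, H, K]
qk_scale = 1.0 / math.sqrt(K)                      # default scale when inp.scale is None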
+ + +import textwrap +from collections import deque +from typing import List, Sequence, Type, TypeVar + +from . import attn_bias, cutlass, decoder, flash, small_k, triton, triton_splitk +from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs + + +def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool: + return False + + +def _is_triton_fwd_fastest(inp: Inputs) -> bool: + # TODO: fill out + return False + + +T = TypeVar("T", Type[AttentionFwOpBase], Type[AttentionBwOpBase]) + + +def _format_inputs_description(inp: Inputs) -> str: + return f"""query : shape={tuple(inp.query.shape)} ({inp.query.dtype}) +key : shape={tuple(inp.key.shape)} ({inp.key.dtype}) +value : shape={tuple(inp.value.shape)} ({inp.value.dtype}) +attn_bias : {type(inp.attn_bias)} +p : {inp.p}""" + + +def _ensure_op_supports_or_raise(exc_type, name: str, op, inp: Inputs) -> None: + reasons = op.not_supported_reasons(inp) + if not reasons: + return + raise exc_type( + f"""Operator `{name}` does not support inputs: +{textwrap.indent(_format_inputs_description(inp), ' ')} +{_format_not_supported_reasons(op, reasons)}""" + ) + + +def _format_not_supported_reasons(op, reasons: List[str]) -> str: + return f"`{op.NAME}` is not supported because:\n " + "\n ".join(reasons) + + +def _run_priority_list(name: str, priority_list: Sequence[T], inp: Inputs) -> T: + not_supported_reasons: List[List[str]] = [] + for op in priority_list: + not_supported = op.not_supported_reasons(inp) + if not not_supported: + return op + not_supported_reasons.append(not_supported) + + # Let's write a nice message explaining what we tried and why it's not supported + msg = f"""No operator found for `{name}` with inputs: +{textwrap.indent(_format_inputs_description(inp), ' ')}""" + for op, not_supported in zip(priority_list, not_supported_reasons): + msg += "\n" + _format_not_supported_reasons(op, not_supported) + raise NotImplementedError(msg) + + +def _dispatch_fw_priority_list( + inp: Inputs, needs_gradient: bool +) -> Sequence[Type[AttentionFwOpBase]]: + priority_list_ops = deque( + [ + flash.FwOp, + triton.FwOp, + cutlass.FwOp, + small_k.FwOp, + ] + ) + if _is_cutlass_fwd_faster_than_flash(inp): + priority_list_ops.remove(cutlass.FwOp) + priority_list_ops.appendleft(cutlass.FwOp) + if _is_triton_fwd_fastest(inp): + priority_list_ops.remove(triton.FwOp) + priority_list_ops.appendleft(triton.FwOp) + if not needs_gradient: + mqa_or_gqa = ( + inp.key.ndim > 3 and inp.key.stride(-2) == 0 and inp.key.shape[-2] > 1 + ) + if not mqa_or_gqa: + # With multiquery, cutlass is sometimes faster than decoder + # but it's not currently clear when. 
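+            # appendleft() moves decoderF to the front of the deque, so non-MQA,
+            # inference-only calls try [decoder, flash, triton, cutlass, small_k];
+            # the MQA/short-query branch below may further prepend triton_splitk,
+            # and flash is moved back to the front when the bias is not a
+            # BlockDiagonalMask.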
+ priority_list_ops.appendleft(decoder.FwOp) + # Split-KV is useful with MQA + # for short Q-seqlen / long K-seqlen + if mqa_or_gqa and inp.query.shape[1] <= 32 and inp.key.shape[1] >= 256: + parallelism_BH = 0 # BMK + if inp.query.ndim == 3: + parallelism_BH = inp.query.shape[0] + elif inp.query.ndim == 4: # BMHK + parallelism_BH = inp.query.shape[0] * inp.query.shape[2] + elif inp.query.ndim == 5: # BMGHK + parallelism_BH = inp.query.shape[0] * inp.query.shape[2] + if parallelism_BH > 0 and parallelism_BH < 64: + priority_list_ops.appendleft(triton_splitk.FwOp) + # Without variable seqlen flash is fastest + if not isinstance(inp.attn_bias, attn_bias.BlockDiagonalMask): + priority_list_ops.remove(flash.FwOp) + priority_list_ops.appendleft(flash.FwOp) + + return priority_list_ops + + +def _dispatch_fw(inp: Inputs, needs_gradient: bool) -> Type[AttentionFwOpBase]: + """Computes the best operator for forward + + Raises: + NotImplementedError: if not operator was found + + Returns: + AttentionOp: The best operator for the configuration + """ + # return _run_priority_list( + # "memory_efficient_attention_forward", + # _dispatch_fw_priority_list(inp, needs_gradient), + # inp, + # ) + return flash.FwOp + + +def _is_cutlassB_faster_than_flash(inp: Inputs) -> bool: + return False + + +def _dispatch_bw(inp: Inputs) -> Type[AttentionBwOpBase]: + priority_list_ops: List[Type[AttentionBwOpBase]] = [ + flash.BwOp, + cutlass.BwOp, + # CUDA illegal memory issues, race conditions etc.. + # triton.BwOp, + # Deprecated + small_k.BwOp, + ] + # if _is_cutlassB_faster_than_flash(inp): + # priority_list_ops.remove(cutlass.BwOp) + # priority_list_ops.insert(0, cutlass.BwOp) + # return _run_priority_list( + # "memory_efficient_attention_backward", priority_list_ops, inp + # ) + return flash.BwOp diff --git a/pkgs/xformers/ops/fmha/flash.py b/pkgs/xformers/ops/fmha/flash.py new file mode 100644 index 0000000..d3e0626 --- /dev/null +++ b/pkgs/xformers/ops/fmha/flash.py @@ -0,0 +1,666 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import replace +from itertools import zip_longest +from typing import Any, List, Optional, Set, Tuple, Union +import os +import torch + +from ..common import _get_storage_base, get_operator, register_operator +from .attn_bias import ( + AttentionBias, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalMask, + BlockDiagonalMask, + LowerTriangularMask, +) +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + check_lastdim_alignment_stride1, +) +global enable_ixdnn +enable_ixdnn = os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') != '0' + +FLASH_VERSION = "0.0.0" +try: + try: + from ... 
import _C_flashattention # type: ignore[attr-defined] + from ..._cpp_lib import _build_metadata + + if _build_metadata is not None: + FLASH_VERSION = _build_metadata.flash_version + except ImportError: + import flash_attn + from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention + + FLASH_VERSION = flash_attn.__version__ + flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) + if flash_ver_parsed < (2, 3): + raise ImportError("Requires 2.3 for sliding window support") + + # create library so that flash-attn goes through the PyTorch Dispatcher + _flash_lib = torch.library.Library("xformers_flash", "DEF") + + _flash_lib.define( + "flash_fwd(Tensor query, Tensor key, Tensor value, " + "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " + "int max_seqlen_q, int max_seqlen_k, " + "float p, float softmax_scale, " + "bool is_causal, int window_size, bool return_softmax, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)" + ) + + _flash_lib.define( + "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " + "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " + "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " + "int max_seqlen_q, int max_seqlen_k, " + "float p, float softmax_scale, bool is_causal, int window_size, bool use_alibi, int alibi_mode, int imp_mode) -> (Tensor, Tensor, Tensor)" + ) + + def _flash_fwd( + query, + key, + value, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + is_causal, + window_size, + return_softmax, + use_alibi, + alibi_mode, + imp_mode + ): + if enable_ixdnn: + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + p, + ) = _C_flashattention.fwd_ixdnn( + query, + key, + value, + None, # out + cu_seq_lens_q, + cu_seq_lens_k, + p, + is_causal, + return_softmax, + use_alibi, + alibi_mode, + imp_mode, + -1, -1, + None, # rng + ) + else: + if cu_seq_lens_q is None: + assert cu_seq_lens_k is None + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + p, + ) = _C_flashattention.fwd( + query, + key, + value, + None, # out + p, + softmax_scale, + is_causal, + return_softmax, + use_alibi, + alibi_mode, + None, # rng + ) + else: + out = query.new_empty(query.shape[0], query.shape[1], value.shape[2]) + ( + out, + q_padded, + k_padded, + v_padded, + out_padded, + softmax_lse, + p, + ) = _C_flashattention.varlen_fwd( + query, + key, + value, + out, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + False, + is_causal, + return_softmax, + use_alibi, + alibi_mode, + None, + ) + return out, softmax_lse, q_padded, k_padded, v_padded, out_padded + + def _flash_bwd( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + is_causal, + window_size, + use_alibi, + alibi_mode, + imp_mode + ): + if enable_ixdnn: + _C_flashattention.bwd_ixdnn( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seq_lens_q, + cu_seq_lens_k, + p, + is_causal, + use_alibi, + alibi_mode, + imp_mode, + -1, -1, + None, + ) + else: + if cu_seq_lens_k is None: + assert cu_seq_lens_q is None + _C_flashattention.bwd( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + p, + softmax_scale, + is_causal, + use_alibi, + alibi_mode, + None, + ) + else: + _C_flashattention.varlen_bwd( + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seq_lens_q, 
+ cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + False, # zero_tensors + is_causal, + use_alibi, + alibi_mode, + None, + ) + return dq, dk, dv + + _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") + _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") +except ImportError: + pass + + +def _convert_input_format( + inp: Inputs, +) -> Tuple[Inputs, Optional[torch.Tensor], int, Optional[torch.Tensor], int]: + assert inp.query.ndim in [4, 5] + query, key, value = inp.query, inp.key, inp.value + batch = query.shape[0] + seqlen_q = query.shape[1] + seqlen_kv = key.shape[1] + head_dim_q = query.shape[-1] + head_dim_v = value.shape[-1] + + attn_bias = inp.attn_bias + if enable_ixdnn: + if query.shape[0] == 1: + cu_seqlen_q = torch.tensor([seqlen_q], dtype=torch.int32, device=query.device) + cu_seqlen_k = torch.tensor([seqlen_kv], dtype=torch.int32, device=key.device) + else: + cu_seqlen_q = torch.full((seqlen_q,), seqlen_q, dtype=torch.int32, device=query.device) + cu_seqlen_k = torch.full((seqlen_kv,), seqlen_kv, dtype=torch.int32, device=key.device) + max_seqlen_q = inp.query.shape[1] + max_seqlen_k = inp.key.shape[1] + else: + if isinstance(attn_bias, BlockDiagonalMask): + attn_bias.k_seqinfo.seqstart = attn_bias.k_seqinfo.seqstart.to( + inp.query.device, non_blocking=True + ) + attn_bias.q_seqinfo.seqstart = attn_bias.q_seqinfo.seqstart.to( + inp.query.device, non_blocking=True + ) + + cu_seqlen_k = attn_bias.k_seqinfo.seqstart + cu_seqlen_q = attn_bias.q_seqinfo.seqstart + max_seqlen_q = attn_bias.q_seqinfo.max_seqlen + max_seqlen_k = attn_bias.k_seqinfo.max_seqlen + else: + cu_seqlen_k = None + cu_seqlen_q = None + max_seqlen_q = inp.query.shape[1] + max_seqlen_k = inp.key.shape[1] + + if query.ndim == 5: # QGA + # Fold the group/head_in_group dimensions together + def fold(x): + # Either the head is replicated + if x.stride(3) == 0: + return x[:, :, :, 0] + # Or we reshape + return x.reshape( + [ + x.shape[0], + x.shape[1], + -1, + x.shape[4], + ] + ) + + query = fold(query) + key = fold(key) + value = fold(value) + # Optimize for MHA + if key.ndim == 4 and key.stride(2) == 0 and value.stride(2) == 0: + key = key[:, :, :1] + value = value[:, :, :1] + # Initially we have `query.shape = [batch, seqlen, head_dim_q]` + # We want format `[batch * seqlen, num_heads, head_dim_q]` + if cu_seqlen_k is not None and os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0': + query = query.reshape([batch * seqlen_q, -1, head_dim_q]) + key = key.reshape([batch * seqlen_kv, -1, head_dim_q]) + value = value.reshape([batch * seqlen_kv, -1, head_dim_v]) + new_inp = replace( + inp, + query=query, + key=key, + value=value, + ) + return new_inp, cu_seqlen_q, max_seqlen_q, cu_seqlen_k, max_seqlen_k + + +def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool: + return isinstance( + attn_bias, + ( + LowerTriangularMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalFromBottomRightMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + ), + ) + + +def _window_size(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> int: + if isinstance( + attn_bias, + (BlockDiagonalCausalLocalAttentionMask,), + ): + return attn_bias._window_size or 0 + if isinstance(attn_bias, BlockDiagonalCausalLocalAttentionFromBottomRightMask): + return attn_bias._window_size + return 0 + + +def _check_needs_no_topleft(d: Inputs, reasons: List[str]) -> None: + # Flash does not support TopLeft, so only allow causal masks with TopLeft + # if 
each batch element has equal number of queries and keys. + if isinstance(d.attn_bias, BlockDiagonalCausalMask): + # Flash does not support TopLeft, so only allow BlockDiagonalCausalMask + # if each batch element has equal number of queries and keys. + for k_start, q_start in zip_longest( + d.attn_bias.k_seqinfo.seqstart_py, d.attn_bias.q_seqinfo.seqstart_py + ): + if k_start != q_start: + reasons.append( + "Only support BlockDiagonalCausalMask if equal" + " numbers of keys and queries" + ) + break + elif isinstance(d.attn_bias, LowerTriangularMask): + if d.query.shape[1] != d.key.shape[1]: + reasons.append( + "Only support LowerTriangularMask if equal number of" "keys and queries" + ) + + +def _check_strides_for_bmghk(x: torch.Tensor, name: str, reasons: List[str]) -> None: + """ + We want to be able to collapse the G/H dimensions together + """ + if x.ndim == 5: + stride_g, stride_h = x.stride(2), x.stride(3) + if x.shape[2] == 1: + return + if x.shape[3] == 1 or stride_h == 0: + return + if stride_g != stride_h * x.shape[-2]: + reasons.append( + f"GQA is only supported when the G/H dimensions are contiguous\n" + f" {name}.stride: {x.stride()}\n" + f" {name}.shape : {list(x.shape)}" + ) + + +@register_operator +class FwOp(AttentionFwOpBase): + """Operator that computes memory-efficient attention using \ + `Flash-Attention `_ \ + implementation. + """ + + OPERATOR = get_operator("xformers_flash", "flash_fwd") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 256 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + LowerTriangularMask, + BlockDiagonalMask, + BlockDiagonalCausalMask, + BlockDiagonalCausalLocalAttentionMask, + BlockDiagonalCausalLocalAttentionFromBottomRightMask, + BlockDiagonalCausalFromBottomRightMask, + } + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_DIFFERENT_VALUE_EMBED = False + SUPPORTS_BMGHK = True + NAME = f"flshattF@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + _check_needs_no_topleft(d, reasons) + _check_strides_for_bmghk(d.query, "query", reasons) + _check_strides_for_bmghk(d.key, "key", reasons) + _check_strides_for_bmghk(d.value, "value", reasons) + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + return_softmax = False + out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + # no cumulative seqlen + ( + inp, + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + ) = _convert_input_format(inp) + out, softmax_lse, q_padded, k_padded, v_padded, o_padded = cls.OPERATOR( + inp.query, + inp.key, + inp.value, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + inp.p, + inp.scale_float, + _is_causal(inp.attn_bias), + _window_size(inp.attn_bias), + return_softmax, + inp.use_alibi, + inp.alibi_mode, + inp.imp_mode, + ) + out = out.reshape(out_shape) + ctx = Context(out=out, lse=softmax_lse, q_padded=q_padded, k_padded=k_padded, v_padded=v_padded, o_padded=o_padded) + if inp.p != 0.0: + ctx.op_bw = BwOp + return (out, ctx) + + @classmethod + # type: ignore + def operator_flop( + cls, + query, + key, + value, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + causal, + return_softmax, + ) -> int: + return 
cls.attn_operator_flop( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + causal=causal, + seqstart_k=cu_seq_lens_k, + seqstart_q=cu_seq_lens_q, + ) + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_operator("xformers_flash", "flash_bwd") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + IS_DETERMINISTIC = False + SUPPORTS_BMGHK = False # NOTE: Don't forget to update fmha doc when changing this! + NAME = f"flshattB@{FLASH_VERSION}" + VERSION = FLASH_VERSION + + MAX_HEADDIM_SM8x = 192 + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) + + # In fbcode in mode/dev-nosan, we get nans from flash v2.1 if there + # is a strange embedding dimension. + # if K not in {8, 16, 32, 64, 128, 256}: + # reasons.append(f"Embed dim {K} not supported") + + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + _check_needs_no_topleft(d, reasons) + if d.device.type == "cuda": + # Due to limited shared-memory, some GPUs are limited in head dimension + device_capability = torch.cuda.get_device_capability(d.device) + is_sm80_or_sm90 = device_capability in [(8, 0), (9, 0)] + if ( + max(d.key.shape[-1], d.query.shape[-1]) > cls.MAX_HEADDIM_SM8x + and not is_sm80_or_sm90 + ): + reasons.append( + "requires a GPU with compute capability 8.0 " + f"(A100) or 9.0 (H100) for 'query.shape[-1] > {cls.MAX_HEADDIM_SM8x}'" + ) + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape + ( + inp, + cu_seqlens_q, + max_seqlen_q, + cu_seqlens_k, + max_seqlen_k, + ) = _convert_input_format(inp) + assert ctx.lse.is_contiguous + ctx_lse = ctx.lse + if os.getenv('ENABLE_FLASH_ATTENTION_WITH_IXDNN', '1') == '0': + assert ctx_lse.shape[2] >= max_seqlen_q + if max_seqlen_q != ctx_lse.shape[2]: + ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous() + kernel_out_shape = [ + *inp.query.shape[:-1], + inp.value.shape[-1], + ] + + # Create dq,dk,dv + # If Q/K/V come from a single QKV tensor, let's put the gradient in the + # right strides, so we can avoid a `cat` + if ( + ctx.q_padded.shape[0] == ctx.k_padded.shape[0] + and ctx.q_padded.shape[-1] == ctx.v_padded.shape[-1] + and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded) + and _get_storage_base(ctx.q_padded) == _get_storage_base(ctx.k_padded) + ): + # Create one big contiguous chunk + # This is because q, k and v usually come from a single + # output of a linear layer that is chunked. 
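+            # (dq, dk and dv below are then just views into one [..., 3, H, K]
+            # buffer obtained via chunk.select, matching the packed QKV layout.)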
+ # Creating the gradients with the right layout saves us + # a `torch.cat` call in the backward pass + chunk = torch.empty( + (*ctx.q_padded.shape[0:-2], 3, ctx.q_padded.shape[-2], ctx.q_padded.shape[-1]), + dtype=ctx.q_padded.dtype, + device=inp.device, + ) + grads = Gradients( + dq=chunk.select(-3, 0), + dk=chunk.select(-3, 1), + dv=chunk.select(-3, 2), + ) + else: + grads = Gradients( + dq=torch.empty_like(ctx.q_padded), + dk=torch.empty_like(ctx.k_padded), + dv=torch.empty_like(ctx.v_padded), + ) + + assert grad.dtype in cls.SUPPORTED_DTYPES + cls.OPERATOR( + grad.reshape(kernel_out_shape).contiguous(), + ctx.q_padded, + ctx.k_padded, + ctx.v_padded, + # ctx.out.reshape(kernel_out_shape), + ctx.o_padded, + ctx_lse, + grads.dq, + grads.dk, + grads.dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + inp.p, + inp.scale_float, + _is_causal(inp.attn_bias), + _window_size(inp.attn_bias), + inp.use_alibi, + inp.alibi_mode, + inp.imp_mode, + ) + grads.dq = grads.dq[..., :dq_shape[-1]].reshape(dq_shape) # We could have padded the head dimension + grads.dk = grads.dk[..., :dk_shape[-1]].reshape(dk_shape) + grads.dv = grads.dv[..., :dv_shape[-1]].reshape(dv_shape) + + return grads + + @classmethod + # type: ignore + def operator_flop( + cls, + grad, + query, + key, + value, + out, + lse, + dq, + dk, + dv, + cu_seq_lens_q, + cu_seq_lens_k, + max_seq_len_q, + max_seq_len_k, + p, + softmax_scale, + causal, + ) -> int: + return cls.attn_operator_flop( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + causal=causal, + seqstart_k=cu_seq_lens_k, + seqstart_q=cu_seq_lens_q, + ) diff --git a/pkgs/xformers/ops/fmha/small_k.py b/pkgs/xformers/ops/fmha/small_k.py new file mode 100644 index 0000000..cfa9237 --- /dev/null +++ b/pkgs/xformers/ops/fmha/small_k.py @@ -0,0 +1,186 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List, Mapping, Optional, Set, Tuple, Union + +import torch + +from ..common import get_xformers_operator, register_operator +from .attn_bias import AttentionBias +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + bmk2bmhk, +) + + +def _bmhk2bmk_contiguous(tensor) -> torch.Tensor: + return ( + tensor.permute((0, 2, 1, 3)) + .contiguous() + .view([tensor.shape[0] * tensor.shape[2], tensor.shape[1], tensor.shape[3]]) + .contiguous() + ) + + +def _get_tensor_bias_bmk( + attn_bias: Optional[Union[torch.Tensor, AttentionBias]] +) -> Optional[torch.Tensor]: + if not isinstance(attn_bias, torch.Tensor): + assert attn_bias is None + return None + # BMK -> BMHK + if attn_bias.ndim == 4: + attn_bias = attn_bias.reshape([-1, *attn_bias.shape[2:]]) + return attn_bias + + +@register_operator +class FwOp(AttentionFwOpBase): + """An operator optimized for very small values of K (``K <= 32``) \ + and f32 pre-Ampere as it does not use TensorCores. + Only supports contiguous inputs in BMK format, so an extra reshape \ + or contiguous call might be done. 
+ + :Deprecated: + + This operator is deprecated and should not be used in new code + """ + + OPERATOR = get_xformers_operator("efficient_attention_forward_small_k") + SUPPORTED_DEVICES = {"cuda"} + SUPPORTED_DTYPES = {torch.float} + SUPPORTED_MAX_K: float = 32 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {type(None), torch.Tensor} + SUPPORTS_DROPOUT = True + SUPPORTS_CUSTOM_SCALE = False + NAME = "smallkF" + + BACKWARD_ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 4e-3, + } + # as this kernel is a bit slow, this should make tests run faster + _TEST_BATCH_SIZES = [1, 3] + _TEST_K = [2, 3, 8, 16, 32] + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0: + reasons.append("bias with non-zero stride not supported") + buffer_size = 8 + k = d.query.shape[-1] + for pack in [1, 2, 4]: + if (k % pack) == 0 and (k // pack) <= buffer_size: + return reasons + reasons.append(f"unsupported embed per head: {k}") + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + if inp.scale is not None: + raise NotImplementedError("Unsupport custom scale") + num_heads = inp.query.shape[2] + query = _bmhk2bmk_contiguous(inp.query) + key = _bmhk2bmk_contiguous(inp.key) + value = _bmhk2bmk_contiguous(inp.value) + + out, lse, rng_seed, rng_offset = cls.OPERATOR( + query=query, + key=key, + value=value, + compute_logsumexp=needs_gradient, + attn_bias=_get_tensor_bias_bmk(inp.attn_bias), + p=inp.p, + ) + out = bmk2bmhk(out, num_heads) + lse = lse.reshape([lse.shape[0] // num_heads, num_heads, lse.shape[1]]) + if not needs_gradient: + return out, None + ctx = Context(out=out, lse=lse) + if inp.p != 0.0: + ctx.op_bw = BwOp + ctx.rng_state = torch.tensor( + [rng_seed, rng_offset], dtype=torch.int64, device="cpu" + ) + return out, ctx + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = get_xformers_operator("efficient_attention_backward_small_k") + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + + # there is some extra precision loss in the CPU implementation due to an + # extra accumulation step in grad_q, which is not present in the CUDA + # implementation + ERROR_ATOL: Mapping[torch.dtype, float] = { + torch.float: 4e-3, + } + NAME = "smallkB" + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + if isinstance(d.attn_bias, torch.Tensor) and d.attn_bias.stride(1) != 0: + reasons.append("bias with non-zero stride not supported") + buffer_size = 8 + k = d.query.shape[-1] + for pack in [1, 2, 4]: + if (k % pack) == 0 and (k // pack) <= buffer_size: + return reasons + reasons.append(f"unsupported embed per head: {k}") + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + num_heads = grad.shape[2] + grad = _bmhk2bmk_contiguous(grad) + query = _bmhk2bmk_contiguous(inp.query) + key = _bmhk2bmk_contiguous(inp.key) + value = _bmhk2bmk_contiguous(inp.value) + out = _bmhk2bmk_contiguous(ctx.out) + + rng_seed = 
rng_offset = 0 + if inp.p != 0.0: + if ( + ctx.rng_state is None + or ctx.rng_state.dtype != torch.int64 + or ctx.rng_state.device.type != "cpu" + or ctx.rng_state.shape != (2,) + ): + raise NotImplementedError(f"Invalid rng_state: {ctx.rng_state}") + rng_seed, rng_offset = ctx.rng_state.tolist() + grad_q, grad_k, grad_v = cls.OPERATOR( + grad, + query, + key, + value, + # LSE: BHM -> (BH)M + ctx.lse.reshape([-1, ctx.lse.shape[-1]]), + out, + _get_tensor_bias_bmk(inp.attn_bias), + inp.p, + rng_seed, + rng_offset, + ) + return Gradients( + dq=bmk2bmhk(grad_q, num_heads), + dk=bmk2bmhk(grad_k, num_heads), + dv=bmk2bmhk(grad_v, num_heads), + ) diff --git a/pkgs/xformers/ops/fmha/triton.py b/pkgs/xformers/ops/fmha/triton.py new file mode 100644 index 0000000..2d6e2a0 --- /dev/null +++ b/pkgs/xformers/ops/fmha/triton.py @@ -0,0 +1,201 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import replace +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple + +import torch + +from ... import _is_triton_available +from ..common import register_operator + +# This implementation needs pre-MLIR triton +# The BW pass is not stable/well tested +# And also does not have the latest improvements +if TYPE_CHECKING or (False and _is_triton_available()): + try: + from flash_attn.flash_attn_triton import ( + _flash_attn_backward, + _flash_attn_forward, + ) + except ImportError: + import importlib + import pathlib + import sys + import types + + def import_module_from_path(path: str) -> types.ModuleType: + """Import a module from the given path, w/o __init__.py""" + module_path = pathlib.Path(path).resolve() + module_name = module_path.stem # 'path/x.py' -> 'x' + spec = importlib.util.spec_from_file_location(module_name, module_path) # type: ignore + assert isinstance(spec, importlib.machinery.ModuleSpec) + module = importlib.util.module_from_spec(spec) # type: ignore + sys.modules[module_name] = module + assert isinstance(spec.loader, importlib.abc.Loader) + spec.loader.exec_module(module) + return module + + flash_attn = import_module_from_path( + "third_party/flash-attention/flash_attn/flash_attn_triton.py" + ) + _flash_attn_backward = flash_attn._flash_attn_backward + _flash_attn_forward = flash_attn._flash_attn_forward + + triton_flash_backward = _flash_attn_backward + triton_flash_forward = _flash_attn_forward +else: + triton_flash_backward = None + triton_flash_forward = None + +from .attn_bias import LowerTriangularMask +from .common import ( + AttentionBwOpBase, + AttentionFwOpBase, + Context, + Gradients, + Inputs, + check_lastdim_alignment_stride1, +) + + +def _prepare_inputs(inp: Inputs) -> Inputs: + attn_bias = inp.attn_bias + if isinstance(attn_bias, torch.Tensor) and attn_bias.ndim == 3: + B = inp.query.shape[0] + h = attn_bias.shape[0] // B + attn_bias = attn_bias.reshape(B, h, attn_bias.shape[1], attn_bias.shape[2]) + + # Make sure that the last dimension is contiguous + query, key, value = [ + x if x.stride(-1) == 1 else x.contiguous() + for x in [inp.query, inp.key, inp.value] + ] + return replace(inp, attn_bias=attn_bias, query=query, key=key, value=value) + + +@register_operator +class FwOp(AttentionFwOpBase): + """Operator that computes memory-efficient attention using \ + `Tri Dao's `_ \ + implementation, based on + `Phil Tillet's code `_ + """ + + OPERATOR = triton_flash_forward + SUPPORTED_DEVICES = 
{"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = {torch.half, torch.bfloat16} + SUPPORTED_MAX_K = 128 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + LowerTriangularMask, + # TODO: backwards accuracy is failing for a few cases, perhaps we want to disable this for now. + # torch.Tensor, + } + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + NAME = "tritonflashattF" + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + check_lastdim_alignment_stride1(reasons, "key", d.key, 8) + check_lastdim_alignment_stride1(reasons, "value", d.value, 8) + if cls.OPERATOR is None: + reasons.append("triton is not available") + if d.device.type == "cuda": + # Has only been tested on 8.0 / 9.0. + # Fails on 7.5 with illegal memory access + if torch.cuda.get_device_capability(d.device) < (8, 0): + reasons.append( + "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" + ) + if _is_triton_available(): + import triton + + if triton.__version__ > "2.0.0": + reasons.append("Only work on pre-MLIR triton for now") + return reasons + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + inp = _prepare_inputs(inp) + + out, lse, softmax_scale = triton_flash_forward( + q=inp.query, + k=inp.key, + v=inp.value, + bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None, + softmax_scale=inp.scale_float, + causal=isinstance(inp.attn_bias, LowerTriangularMask), + ) + return out, Context(lse=lse, out=out) + + +@register_operator +class BwOp(AttentionBwOpBase): + __doc__ = FwOp.__doc__ + + OPERATOR = triton_flash_backward + SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES + CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY + SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES + SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K + SUPPORTED_ATTN_BIAS_TYPES = FwOp.SUPPORTED_ATTN_BIAS_TYPES + SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT + SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE + SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED + NAME = "tritonflashattB" + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(BwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + check_lastdim_alignment_stride1(reasons, "key", d.key, 8) + check_lastdim_alignment_stride1(reasons, "value", d.value, 8) + if cls.OPERATOR is None: + reasons.append("triton is not available") + if d.device.type == "cuda": + if torch.cuda.get_device_capability(d.device) != (8, 0): + reasons.append("requires A100 GPU") + if _is_triton_available(): + import triton + + if triton.__version__ > "2.0.0": + reasons.append("Only work on pre-MLIR triton for now") + return reasons + + @classmethod + def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients: + inp = _prepare_inputs(inp) + + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
+ with torch.inference_mode(): + grads = Gradients( + dq=torch.empty_like(inp.query), + dk=torch.empty_like(inp.key), + dv=torch.empty_like(inp.value), + ) + cls.OPERATOR( + grad, + inp.query, + inp.key, + inp.value, + ctx.out, + ctx.get_padded_lse(128), + grads.dq, + grads.dk, + grads.dv, + bias=inp.attn_bias if isinstance(inp.attn_bias, torch.Tensor) else None, + softmax_scale=inp.scale_float, + causal=isinstance(inp.attn_bias, LowerTriangularMask), + ) + return grads diff --git a/pkgs/xformers/ops/fmha/triton_splitk.py b/pkgs/xformers/ops/fmha/triton_splitk.py new file mode 100644 index 0000000..dd7d295 --- /dev/null +++ b/pkgs/xformers/ops/fmha/triton_splitk.py @@ -0,0 +1,738 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple + +import torch + +from ..common import _has_triton21, register_operator +from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask +from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1 + + +def _strides(x: torch.Tensor, *stride_names: str): + assert x.ndim == len(stride_names) + return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)} + + +if TYPE_CHECKING or _has_triton21(): + import triton + import triton.language as tl + + from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs + + @triton.jit + def _fwd_kernel_splitK( + Q, + K, + V, + sm_scale, + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Seq_len, + stride_qz, + stride_qm, + stride_qg, + stride_qh, + stride_qk, + stride_kz, + stride_kn, + stride_kg, + stride_kh, + stride_kk, + stride_vz, + stride_vn, + stride_vg, + stride_vh, + stride_vk, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + Z, + N_CTX_Q, + N_CTX_K, + BLOCK_N_PER_SPLIT, + H: tl.constexpr, + G: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + BOUNDS_CHECKS_N: tl.constexpr, + USE_SEQ_LEN: tl.constexpr, + PACKED_PER_VAL: tl.constexpr = 1, + N_GROUPS: tl.constexpr = 1, + ): + """This kernel can accept non-quantized or int4-quantized keys/values. + PACKED_PER_VAL determines the quantization type: + - PACKED_PER_VAL == 1 means no quantization + - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32) + For the quantized case K/V should be int32 tensors. + Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8. + Quantization coefficients are stored at the beginning of the row along the last dimension of K/V + So K[B, H, M, :] has a form + [ quant_coef0, quant_coef1, ...| + group0_quant_value0, group0_quant_value1,... | + group1_quant_value0, group1_quant_value1,...] + where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset. + + Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs + before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists. + See how FwOp.apply does it below. 
+ """ + tl.static_assert( + (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32)) + or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)), + f"Only 4-bit quantization is supported, K/V should have dtype int32 in " + f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}", + ) + tl.static_assert( + (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8), + "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.", + ) + + QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1 + PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS + D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS + + start_m = tl.program_id(0) + off_zhg = tl.program_id(1) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + splitk_idx = tl.program_id(2) + + lo = splitk_idx * BLOCK_N_PER_SPLIT + if USE_SEQ_LEN: + kv_len = tl.load(Seq_len + off_z) + else: + kv_len = N_CTX_K + hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len) + + Q_block_ptr = tl.make_block_ptr( + base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + + k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg + # Additional shift by 1 along the last dimension in the quantized case, since + # the first element along that dim contains packed quantization coefficients. + K_block_ptr = tl.make_block_ptr( + base=k_base + stride_kk * QUANTIZED * N_GROUPS, + shape=(PACKED_D_PER_GROUP, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(PACKED_D_PER_GROUP, BLOCK_N), + order=(0, 1), + ) + v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg + V_block_ptr = tl.make_block_ptr( + base=v_base + stride_vk * QUANTIZED * N_GROUPS, + shape=(hi, PACKED_D_PER_GROUP), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, PACKED_D_PER_GROUP), + order=(1, 0), + ) + + if QUANTIZED: + # Pointers to quantization coefficients. Even those they are 1D, + # we have to use block pointers, since usual pointers + # don't support boundary checks + K_scale_shift_block_ptr = tl.make_block_ptr( + base=k_base, + shape=(1, hi), + strides=(stride_kk, stride_kn), + offsets=(0, lo), + block_shape=(1, BLOCK_N), + order=(0, 1), + ) + V_scale_shift_block_ptr = tl.make_block_ptr( + base=v_base, + shape=(hi, 1), + strides=(stride_vn, stride_vk), + offsets=(lo, 0), + block_shape=(BLOCK_N, 1), + order=(1, 0), + ) + else: + K_scale_shift_block_ptr = None + V_scale_shift_block_ptr = None + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + + # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs. + # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS. + # This is a solution for Triton native lack of support for lists of tensors. 
+ acc: "VAR_ARGS_ARRAY" # noqa: F821 + + for i in range(len(acc)): # noqa: F821 + acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821 + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(acc)): # noqa: F821 + q[i] = tl.load( # noqa: F821 + tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,) + ) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + k: "VAR_ARGS_ARRAY" # noqa: F821 + v: "VAR_ARGS_ARRAY" # noqa: F821 + for i in range(len(acc)): # noqa: F821 + k[i], v[i] = load_dequantize_k_v_group( # noqa: F821 + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N, + PACKED_PER_VAL, + PACKED_D_PER_GROUP, + Q.dtype.element_ty, + i, + ) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + for i in range(len(acc)): # noqa: F821 + qk += tl.dot(q[i], k[i]) # noqa: F821 + qk *= qk_scale + + # TODO: This is slow, and only needed at the last iteration. + # Maybe we can unroll the last iteration instead? + if BOUNDS_CHECKS_N: + qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf")) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + p = p.to(Q.dtype.element_ty) + + # -- scale and update acc -- + for i in range(len(acc)): # noqa: F821 + acc[i] *= alpha[:, None] # noqa: F821 + acc[i] += tl.dot(p, v[i]) # noqa: F821 + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + if PACKED_PER_VAL > 1: + K_scale_shift_block_ptr = tl.advance( + K_scale_shift_block_ptr, (0, BLOCK_N) + ) + V_scale_shift_block_ptr = tl.advance( + V_scale_shift_block_ptr, (BLOCK_N, 0) + ) + + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s, + shape=(N_CTX_Q, D_PER_GROUP), + strides=(stride_osk_m, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, D_PER_GROUP), + order=(1, 0), + ) + for i in range(len(acc)): # noqa: F821 + tl.store( + tl.advance(O_block_ptr, (0, i * D_PER_GROUP)), + acc[i], # noqa: F821 + boundary_check=(0,), + ) + # Write metadata for split-K reduction + Metadata_ptr = ( + Metadata + + off_zhg * stride_mzhg + + splitk_idx * stride_ms + + start_m * BLOCK_M + + tl.arange(0, BLOCK_M) + ) + tl.store(Metadata_ptr, m_i) + tl.store(Metadata_ptr + stride_m2, l_i) + + @triton.jit + def load_dequantize_k_v_group( + K_block_ptr, + V_block_ptr, + K_scale_shift_block_ptr, + V_scale_shift_block_ptr, + BOUNDS_CHECKS_N: tl.constexpr, + PACKED_PER_VAL: tl.constexpr, + PACKED_D_PER_GROUP: tl.constexpr, + dtype: tl.constexpr, + group_id: tl.constexpr, + ): + """Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading. + If quantization is group-wise, use group_id to advance the pointers to the current group. 
+ """ + # Advance to the current quantization group + K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0)) + V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id)) + + # -- load k, v -- + k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()) + v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()) + + if PACKED_PER_VAL > 1: + # K/V are quantized, load quantization coefficients and dequantize + + K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0)) + V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id)) + + k_scale_shift = tl.load( + K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else () + ) + v_scale_shift = tl.load( + V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else () + ) + + k_scale, k_shift = cast_uint32_to_half2(k_scale_shift) + v_scale, v_shift = cast_uint32_to_half2(v_scale_shift) + v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype) + k_t = dequantize( + tl.trans(k), + tl.trans(k_scale), + tl.trans(k_shift), + PACKED_PER_VAL, + ).to(dtype) + k = tl.trans(k_t) + return k, v + + @triton.jit + def cast_uint32_to_half2(scale_shift): + """Extract two float16 packed into one int32""" + scale = scale_shift & 0xFFFF + shift = scale_shift >> 16 + scale = scale.to(tl.uint16).to(tl.float16, bitcast=True) + shift = shift.to(tl.uint16).to(tl.float16, bitcast=True) + return scale, shift + + @triton.jit + def dequantize( + x_, + scale, + shift, + PACKED_PER_VAL: tl.constexpr = 8, + ): + """PACKED_PER_VAL is the number of values packed into each element x_. + For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8. + """ + + # Axis along which offsets are applied matters here + # It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL) + # and expand K/V to that shape before applying offsets + # However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2 + # Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends + # on the implementation details which might change in the future. + # Ideally we would like to use tl.reshape, but it's not implemented yet. 
+ # See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501 + + # x_ : (BLOCK_N, D // PACKED_PER_VAL) + # scale: (BLOCK_N, 1) + # offsets: (PACKED_PER_VAL,) + BLOCK_N: tl.constexpr = x_.shape[0] + BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1] + offsets = tl.arange(0, PACKED_PER_VAL) * 4 + quant_offset = ( + x_[:, None, :] >> offsets[None, :, None] + ) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL) + + quant_offset = tl.view( + quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL) + ) + # Trick - instead of converting int4 to float16 we view it as float16 + # and then multiply by 32768 * 512 == 2**24 + quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True) + quant_offset = (quant_offset * 32768.0).to(tl.float16) + scale_512 = scale * 512 + + dequant = quant_offset * scale_512 + shift + return dequant + + @triton.jit + def _splitK_reduce( + Out_splitK, # [B, H, split_k, Mq, K] + Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li] + Out, # [B, H, M, K] + LSE, # [B, H, M] + split_k, + stride_osk_zhg, + stride_osk_s, + stride_osk_m, + stride_osk_k, + stride_mzhg, + stride_m2, + stride_ms, + stride_mm, + stride_oz, + stride_oh, + stride_og, + stride_om, + stride_ok, + stride_lse_zhg, + stride_lse_m, + BLOCK_SIZE: tl.constexpr, + H: tl.constexpr, + G: tl.constexpr, + ): + off_zhg = tl.program_id(0) + off_z = off_zhg // (H * G) + off_h = (off_zhg // G) % H + off_g = off_zhg % G + off_m = tl.program_id(1) + + Out_splitK_ptr = ( + Out_splitK + + stride_osk_zhg * off_zhg + + stride_osk_m * off_m + + tl.arange(0, BLOCK_SIZE) + ) + Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m + m = tl.load(Metadata_ptr) + l_sum = tl.load(Metadata_ptr + stride_m2) + acc = tl.load(Out_splitK_ptr) + + for split_k_idx in range(1, split_k): + Metadata_ptr = Metadata_ptr + stride_ms + Out_splitK_ptr = Out_splitK_ptr + stride_osk_s + + m_k = tl.load(Metadata_ptr) + l_k = tl.load(Metadata_ptr + stride_m2) + acc_k = tl.load(Out_splitK_ptr) + + m_new = tl.maximum(m, m_k) + if m_k < m: + # Scale incoming values + alpha = tl.math.exp2(m_k - m_new) + acc_k = acc_k * alpha + l_k = l_k * alpha + else: + # Scale our values + alpha = tl.math.exp2(m - m_new) + acc = acc * alpha + l_sum = l_sum * alpha + + m = m_new + l_sum = l_sum + l_k + acc = acc + acc_k + + acc = acc / l_sum + Out_ptr = ( + Out + + stride_oz * off_z + + stride_oh * off_h + + stride_og * off_g + + stride_om * off_m + + tl.arange(0, BLOCK_SIZE) + ) + tl.store(Out_ptr, acc) + + l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m + tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504) + +else: + _fwd_kernel_splitK = None + _splitK_reduce = None + + +@register_operator +class FwOp(AttentionFwOpBase): + """Flash-Attention with Split-K. Supports fused int-4 K/V quantization. + Quantized path will be taken if input K/V have type int32. + + Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along + the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported. + Quantization coefficients (scale and shift) are represented as two + float16 constants per group, packed into int32. Quantization coefficients of + all groups are placed at the beginning of the row. So, if unquantized K/V have head + dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS + and dtype int32. 
+ Pseudocode for dequantizing one row can look like: + group_size = D // 8 + for i in range(NUM_GROUPS): + group_start = NUM_GROUPS + i * group_size + group_quant = K[..., group_start: group_start + group_size] + scale, shift = unpack_int32_into_float16x2(group_quant[0]) + group_dequant = group_quant[..., 1:] * scale + shift + ... + + """ + + OPERATOR = _fwd_kernel_splitK + SUPPORTED_DEVICES = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES = { + torch.half, + torch.bfloat16, + } # Those are dtypes of Q. In the quantized case K/V has dtype int32 + SUPPORTED_MAX_K = 128 + SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = { + type(None), + BlockDiagonalCausalWithOffsetPaddedKeysMask, + } + SUPPORTS_DROPOUT = False + SUPPORTS_CUSTOM_SCALE = True + SUPPORTS_BMGHK = True + NAME = "triton_splitKF" + + SPLIT_K: Optional[int] = None + BLOCK_M = 16 + BLOCK_N = 64 + + NUM_GROUPS = 1 # Default quantization is row-wise + + @classmethod + def shape_not_supported_reasons( + cls, Mq: int, Mkv: int, K: int, Kv: int + ) -> List[str]: + reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv) + if K not in {16, 32, 64, 128}: + reasons.append(f"Embed dim {K} not supported") + return reasons + + @classmethod + def not_supported_reasons(cls, d: Inputs) -> List[str]: + reasons = super(FwOp, cls).not_supported_reasons(d) + check_lastdim_alignment_stride1(reasons, "query", d.query, 8) + if d.key.dtype != torch.int32: + check_lastdim_alignment_stride1(reasons, "key", d.key, 8) + check_lastdim_alignment_stride1(reasons, "value", d.value, 8) + if cls.OPERATOR is None: + reasons.append("triton is not available") + if d.device.type == "cuda": + # Has only been tested on 8.0 / 9.0. + if torch.cuda.get_device_capability(d.device) < (8, 0): + reasons.append( + "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4" + ) + + q_len = d.query.shape[1] + if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask): + seqinfo = d.attn_bias.q_seqinfo + if q_len != seqinfo.seqstart_py[-1]: + reasons.append( + f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}" + ) + q_len = seqinfo.min_seqlen + if q_len != seqinfo.max_seqlen: + reasons.append( + "Variable query len is not supported in the presence of causal mask." + ) + + if d.key.ndim in [4, 5] and d.key.shape[-2] != 1: + if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1: + reasons.append("multiquery is only supported with query seqlen=1") + + if d.attn_bias is not None and q_len > 1: + reasons.append( + "query with seqlen > 1 is not supported in the presence of causal mask" + ) + return reasons + + @classmethod + def get_split_k(cls, B: int, H: int, Mk: int) -> int: + """Heuristic for the number of splits""" + bh = B * H + split_k = max(Mk, 1024) // bh + max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128 + while split_k > 0 and Mk / split_k < max_chunk_size: + split_k = split_k // 2 + split_k = min(split_k, 64) + split_k = max(split_k, 1) + return split_k + + @classmethod + def apply( + cls, inp: Inputs, needs_gradient: bool + ) -> Tuple[torch.Tensor, Optional[Context]]: + attn_bias = inp.attn_bias + seq_len = None + q, k, v = inp.get_qkv_in_bmghk() + + if attn_bias is not None: + assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask) + # TODO: do we really need to do this cast? 
seems fishy but + # I just copied it from the decoder.py + attn_bias.k_seqinfo.to(inp.query.device) + attn_bias.q_seqinfo.to(inp.query.device) + seq_len = attn_bias.k_seqinfo.seqlen + B = len(seq_len) + G, H, Kq = q.shape[-3:] + Kkv = v.shape[-1] + + # assume kv has been padded + q = q.reshape(B, -1, G, H, Kq) + k = k.reshape(B, -1, G, H, Kkv) + v = v.reshape(B, -1, G, H, Kkv) + + # Transpose in the case of MQA/GQA + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + assert q.shape[1] == 1 + q = q.transpose(1, 3) + k = k[:, :, :, :1] + v = v[:, :, :, :1] + + if k.dtype == torch.int32: + # Quantized K/V + PACKED_PER_VAL = 8 + Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8 + else: + Lk = k.shape[-1] + PACKED_PER_VAL = 1 + + B, Mk, G, H, Kkv = k.shape + B, M, G, H, Kq = q.shape + assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}" + + BLOCK_M = cls.BLOCK_M + BLOCK_N = cls.BLOCK_N + if cls.SPLIT_K is not None: + split_k = cls.SPLIT_K + else: + # Use heuristics + split_k = cls.get_split_k(B, H, Mk) + + M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M + o_splitk = torch.empty( + [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device + ) + metadata = torch.empty( + [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device + ) + lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32) + grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k) + + num_warps = 2 + split_size = (Mk + split_k - 1) // split_k + use_seq_len = seq_len is not None + _fwd_kernel_splitK_unrolled = unroll_varargs( + _fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1 + ) + + _fwd_kernel_splitK_unrolled[grid]( + Q=q, + K=k, + V=v, + sm_scale=inp.scale_float, + Out_splitK=o_splitk, + Metadata=metadata, + Seq_len=seq_len, + **_strides(q, "qz", "qm", "qg", "qh", "qk"), + **_strides(k, "kz", "kn", "kg", "kh", "kk"), + **_strides(v, "vz", "vn", "vg", "vh", "vk"), + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + Z=B, + H=H, + G=G, + N_CTX_Q=M, + N_CTX_K=Mk, + BLOCK_N_PER_SPLIT=split_size, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len, + USE_SEQ_LEN=use_seq_len, + num_warps=num_warps, + num_stages=1, + PACKED_PER_VAL=PACKED_PER_VAL, + N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1, + ) + + if mqa_swap_seqlen_head: + out = torch.empty( + (B, H, G, M, Kq), device=q.device, dtype=q.dtype + ).transpose(1, 3) + else: + out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype) + + # Merge together + grid = (B * G * H, M, 1) + _splitK_reduce[grid]( + o_splitk, + metadata, + out, + lse, + split_k=split_k, + **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"), + **_strides(metadata, "mzhg", "m2", "ms", "mm"), + **_strides(out, "oz", "om", "og", "oh", "ok"), + **_strides(lse, "lse_zhg", "lse_m"), + BLOCK_SIZE=out.shape[-1], + G=G, + H=H, + # TODO: Tune num_warps + ) + lse = lse.reshape([B, G, H, M]) + if mqa_swap_seqlen_head: + # H/M dimensions have been swapped + out = out.transpose(1, 3) + lse = lse.transpose(2, 3) + if inp.query.ndim == 4: + # BMGHK -> BMHK + assert G == 1 + out = out[:, :, 0] + lse = lse[:, 0] + + return out, Context(out=out, lse=lse) + + +class FwOp_S1(FwOp): + SPLIT_K = 1 + NAME = "triton_splitK1" + + +class FwOp_S2(FwOp): + SPLIT_K = 2 + NAME = "triton_splitK2" + + +class FwOp_S4(FwOp): + SPLIT_K = 4 + NAME = 
"triton_splitK4" + + +class FwOp_S8(FwOp): + SPLIT_K = 8 + NAME = "triton_splitK8" + + +class FwOp_S16(FwOp): + SPLIT_K = 16 + NAME = "triton_splitK16" + + +class FwOp_S32(FwOp): + SPLIT_K = 32 + NAME = "triton_splitK32" + + +class FwOp_S64(FwOp): + SPLIT_K = 64 + NAME = "triton_splitK64" + + +class FwOp_S128(FwOp): + SPLIT_K = 128 + NAME = "triton_splitK128" diff --git a/pkgs/xformers/ops/indexing.py b/pkgs/xformers/ops/indexing.py new file mode 100644 index 0000000..a6e7478 --- /dev/null +++ b/pkgs/xformers/ops/indexing.py @@ -0,0 +1,223 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Sequence + +import torch + +from .common import BaseOperator, get_xformers_operator, register_operator + + +@register_operator +class ScaledIndexAddFw(BaseOperator): + OPERATOR = get_xformers_operator("scaled_index_addF") + OPERATOR_CATEGORY = "indexing" + NAME = "scaled_index_addF" + + +@register_operator +class ScaledIndexAddBw(BaseOperator): + OPERATOR = get_xformers_operator("scaled_index_addB") + OPERATOR_CATEGORY = "indexing" + NAME = "scaled_index_addB" + + +@register_operator +class IndexSelect(BaseOperator): + OPERATOR = get_xformers_operator("index_select") + OPERATOR_CATEGORY = "indexing" + NAME = "index_select" + + +class _ScaledIndexAdd(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx, + input: torch.Tensor, + index: torch.Tensor, + source: torch.Tensor, + scaling: Optional[torch.Tensor], + alpha: float, + ) -> torch.Tensor: + ScaledIndexAddFw.OPERATOR( + output=input, # in-place + input=input, + source=source, + index=index, + source_scaling=scaling, + alpha=alpha, + ) + ctx.mark_dirty(input) + ctx.save_for_backward(index, scaling, source) + ctx.source_shape = source.shape + ctx.alpha = alpha + return input + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + index, scaling, source = ctx.saved_tensors + grad_source = torch.empty_like(grad_output[: index.shape[0]]) + grad_source_scaling = ( + torch.empty( + ctx.source_shape, + dtype=scaling.dtype, + device=scaling.device, + ) + if scaling is not None + else None + ) + ScaledIndexAddBw.OPERATOR( + grad_source=grad_source, + grad_source_scaling=grad_source_scaling, + grad_output=grad_output, + source=source, + index=index, + source_scaling=scaling, + alpha=ctx.alpha, + ) + if grad_source_scaling is not None: + grad_source_scaling = grad_source_scaling.sum((0, 1)) + return ( + grad_output, # input + None, # index + grad_source, # source + grad_source_scaling, # scaling + None, # alpha + ) + + +def scaled_index_add( + input: torch.Tensor, # [B, M, D] + index: torch.Tensor, # [Bi] - int64 + source: torch.Tensor, # [Bi, M, D] + scaling: Optional[torch.Tensor] = None, # [D] + alpha: float = 1.0, +) -> torch.Tensor: + """ + In-place scaling+index_add + + Indices in ``index`` are assumed to be unique + + :Note: + + The FW pass is done in-place (``input`` is modified) + + :Note: + + This is experimental and has only been optimized for a few shapes + + :Equivalent pytorch code: + + .. 
code-block:: python + + return torch.index_add(inp, dim=0, source=scaling * src, index=indices, alpha=alpha) + """ + return _ScaledIndexAdd.apply( + input, + index, + source, + scaling, + alpha, + ) + + +class _IndexSelectCat(torch.autograd.Function): + @staticmethod + # type: ignore + def forward( + ctx, + *args: torch.Tensor, + ) -> torch.Tensor: + assert len(args) % 2 == 0 + sources = args[: len(args) // 2] + indices = args[len(args) // 2 :] + output_shape = 0 + total_source_elements = 0 + for source, index in zip(sources, indices): + output_shape += index.shape[0] * source.shape[1] + total_source_elements += source.shape[0] * source.shape[1] + output = torch.empty( + [output_shape], dtype=sources[0].dtype, device=sources[0].device + ) + output_i = 0 + for source, index in zip(sources, indices): + elements_here = index.shape[0] * source.shape[1] + IndexSelect.OPERATOR( + output=output[output_i : output_i + elements_here].view( + [index.shape[0], source.shape[1]] + ), + source=source, + index=index, + ) + output_i += elements_here + ctx.save_for_backward(*indices) + ctx.total_source_elements = total_source_elements + ctx.source_shapes = [s.shape for s in sources] + return output + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + indices = ctx.saved_tensors + grad_sources = torch.zeros( + [ctx.total_source_elements], + dtype=grad_output.dtype, + device=grad_output.device, + ) + grad_sources_i = 0 + grad_output_i = 0 + gradients = [] + for source_shape, index in zip(ctx.source_shapes, indices): + grad_output_slice = grad_output[ + grad_output_i : grad_output_i + index.shape[0] * source_shape[1] + ].reshape([index.shape[0], source_shape[1]]) + grad_output_i += index.shape[0] * source_shape[1] + + gradient_source = grad_sources[ + grad_sources_i : grad_sources_i + source_shape[0] * source_shape[1] + ].reshape(source_shape) + grad_sources_i += source_shape[0] * source_shape[1] + + ScaledIndexAddFw.OPERATOR( + output=gradient_source.unsqueeze(1), + input=None, + source=grad_output_slice.unsqueeze(1), + index=index, + source_scaling=None, + alpha=1.0, + ) + gradients.append(gradient_source) + return (*gradients, *([None] * len(gradients))) + + +def index_select_cat( + sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor] +) -> torch.Tensor: + """ + Indices in ``index`` are assumed to be unique + + :Note: + + This is experimental and has only been optimized for a few shapes + + :Example: + + Given: + - ``sources[0]`` of shape ``[S0, D0]`` + - ``indices[0]`` of shape ``[I0]`` + - ``sources[1]`` of shape ``[S1, D1]`` + - ``indices[1]`` of shape ``[I1]`` + returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]`` + + :Equivalent pytorch code: + + .. code-block:: python + + return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0) + """ + return _IndexSelectCat.apply(*sources, *indices) diff --git a/pkgs/xformers/ops/rmsnorm.py b/pkgs/xformers/ops/rmsnorm.py new file mode 100644 index 0000000..b9bda32 --- /dev/null +++ b/pkgs/xformers/ops/rmsnorm.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional + +import torch +from torch import nn + +from .. import _is_triton_available + + +def rms_norm(x, weight: Optional[torch.Tensor], eps: float = 1e-6): + """ + RMS Normalization along the last dimension. 
+ + This is similar to torch.nn.functional.normalize but with eps being added + instead of max. + + Expects x contiguous of shape (..., dim), and returns normalized data + of the same shape. For each dim-length vector x, the result has + + x / sqrt( x*x.sum() + eps) + + If weights are included, they are a contiguous parameter of length dim + which multiplies the result. + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. + """ + assert _is_triton_available() + from .triton.rmsnorm_kernels import _rms_norm_forward + + if torch.is_grad_enabled() and ( + x.requires_grad or (weight is not None and weight.requires_grad) + ): + raise ValueError("Gradients not supported.") + + return _rms_norm_forward(x, weight, eps) + + +def rms_norm_add( + x: torch.Tensor, y: torch.Tensor, weight: Optional[torch.Tensor], eps: float = 1e-6 +): + """ + An addition fused with rms_norm. + + z = rms_norm_add(x, y, weight, eps) + + is equivalent to + + x += y + z = rms_norm(x, weight, eps) + + where x, y and z are all contiguous. + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. + """ + if torch.is_grad_enabled() and ( + x.requires_grad + or y.requires_grad + or (weight is not None and weight.requires_grad) + ): + raise ValueError("Gradients not supported.") + assert _is_triton_available() + from .triton.rmsnorm_kernels import _rms_norm_add_forward + + return _rms_norm_add_forward(x, y, weight, eps) + + +class RMSNorm(torch.nn.Module): + """ + RMS Normalization layer along the last dimension. + + This is similar to torch.nn.functional.normalize but with eps being added + instead of max. + + Expects contiguous input of shape (..., dim), and returns normalized data + of the same shape. For each dim-length vector x, the result has + + x / sqrt( x*x.sum() + eps) + + If weights are included, they are a parameter of length dim which multiplies + the result. + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. + """ + + def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-6): + super().__init__() + self.eps = eps + if include_weight: + self.weight: Optional[nn.Parameter] = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, x: torch.Tensor): + return rms_norm(x, self.weight, self.eps) # type: ignore + + def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor): + """ + An addition fused with forward. + + z = layer.increment_and_forward_(x, y) + + is equivalent to + + x += y + z = layer(x) + """ + return rms_norm_add(x, y, self.weight, self.eps) # type: ignore diff --git a/pkgs/xformers/ops/rope_padded.py b/pkgs/xformers/ops/rope_padded.py new file mode 100644 index 0000000..8380097 --- /dev/null +++ b/pkgs/xformers/ops/rope_padded.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional, Tuple + +import torch + +from xformers.ops.fmha.attn_bias import ( # type: ignore + BlockDiagonalCausalWithOffsetPaddedKeysMask, +) + +from .. 
import _is_triton_available + + +def rope_padded( + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + attn_bias: BlockDiagonalCausalWithOffsetPaddedKeysMask, + *, + theta: float = 10000.0, + out_q: Optional[torch.Tensor] = None, + adjacents: bool = True, + internal_dtype: str = "", +): + """ + Performs RoPE (rotary embeddings) and kv-cache emplacement for a heterogeneous + batch for inference in the style given by + BlockDiagonalCausalWithOffsetPaddedKeysMask. + The batch is concatted along the sequence dimension, so the + actual dim-0 length of all tensors is 1. + + xq, xk and xv should be (1, slen, n_heads, dim), where + xq's n_heads can differ from xk and xv. + + This function places the roped xk in the right place in cache_k, and + xv (unmodified) in the right place in cache_v, and returns out_q + (the roped xq) such that things are ready to call + + xformers.ops.memory_efficient_attention( + out_q, cache_k, cache_v, attn_bias=attn_bias + ) + + This functionality is experimental. Its API might be changed without warnings. + Use it at your own risk. + + Arguments: + xq: tensor of queries to apply rope to + xk: tensor of keys to apply rope to + xv: tensor of values to copy into cache_v + cache_k: cache of keys, MODIFIED IN PLACE + cache_v: cache of values, MODIFIED IN PLACE + attn_bias: details the layout of caches. + Used to determine frequencies for the + RoPE calculation as well as the locations in cache_k and cache_v + to write to. Must be on the device. + adjacents: If True, the inputs are in adjacent pairs along the final dim axis. + This is like the released LLaMA model. + If False, the dim axis is split in two equal pieces. + I.e. the features are ordered with all the real parts before all + the imaginary parts. This matches HuggingFace, e.g. + https://github.com/huggingface/transformers/blob/ + f143037789288ba532dada934a118e648e715738/ + src/transformers/models/llama/modeling_llama.py#L126-L130 + internal_dtype: set to "f32" or "f64" to enforce dtype in the calculation + """ + if torch.is_grad_enabled() and ( + xq.requires_grad + or xk.requires_grad + or xv.requires_grad + or cache_k.requires_grad + or cache_v.requires_grad + or out_q is not None + ): + raise ValueError("Gradients not supported.") + assert _is_triton_available() + import triton + + from .triton.rope_padded_kernels import _rope_padded_kernel + + n_total_queries = attn_bias.q_seqinfo.seqstart_py[-1] + cache_length = attn_bias.k_seqinfo.seqstart_py[-1] + bsz, q_len, n_q_heads, dim = xq.shape + assert q_len == n_total_queries + if bsz != 1: + raise ValueError( + "Expected batch size dimension to be 1" "as batches should be concatenated." 
+ ) + xk_shape = xk.shape + n_kv_heads = xk_shape[2] + if xk_shape != (1, n_total_queries, n_kv_heads, dim): + raise ValueError("unexpected k shape") + if xv.shape != (1, n_total_queries, n_kv_heads, dim): + raise ValueError("unexpected v shape") + if cache_k.shape != (1, cache_length, n_kv_heads, dim): + raise ValueError("unexpected cache_k length") + if cache_v.shape != (1, cache_length, n_kv_heads, dim): + raise ValueError("unexpected cache_v length") + + xq_stride = xq.stride() + xk_stride = xk.stride() + xv_stride = xv.stride() + cache_k_stride = cache_k.stride() + cache_v_stride = cache_v.stride() + if xq_stride[3] != 1: + raise ValueError("Each q head must be contiguous") + if xk_stride[3] != 1: + raise ValueError("Each k head must be contiguous") + if xv_stride[3] != 1: + raise ValueError("Each v head must be contiguous") + if cache_k_stride[3] != 1: + raise ValueError("Each cache_k head must be contiguous") + if cache_v_stride[3] != 1: + raise ValueError("Each cache_v head must be contiguous") + n_total_heads = n_q_heads + 2 * n_kv_heads + v_start = n_total_heads - n_kv_heads + k_start = n_q_heads + if out_q is None: + out_q = xq.new_empty(1, n_total_queries, n_q_heads, dim) + out_q_stride: Tuple[int, ...] = (0, n_q_heads * dim, dim, 1) + else: + if out_q.shape != xq.shape: + raise ValueError("Unexpected shape of out_q") + out_q_stride = out_q.stride() + if out_q_stride[3] != 1: + raise ValueError("Each out_q head must be contiguous") + + assert out_q is not None + + logical_bsz = len(attn_bias.q_seqinfo.seqstart_py) - 1 + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // xq.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(dim)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + device = xq.device + # Move these to the right device, like fmha does. + attn_bias.k_seqinfo.to(device) + attn_bias.q_seqinfo.to(device) + seqstartq = attn_bias.q_seqinfo.seqstart + seqstartk = attn_bias.k_seqinfo.seqstart + seqlenk = attn_bias.k_seqinfo.seqlen + assert internal_dtype in ["", "f32", "f64"] + # experiment with the order of dims here. + _rope_padded_kernel[(logical_bsz, attn_bias.q_seqinfo.max_seqlen, n_total_heads)]( + xq, + xk, + xv, + out_q, + cache_k, + cache_v, + seqstartq, + seqstartk, + seqlenk, + theta, + k_start, + v_start, + dim, + xq_stride[1], + xq_stride[2], + xk_stride[1], + xk_stride[2], + xv_stride[1], + xv_stride[2], + cache_k_stride[1], + cache_k_stride[2], + cache_v_stride[1], + cache_v_stride[2], + seqstartq.stride(0), + seqstartk.stride(0), + seqlenk.stride(0), + out_q_stride[1], + out_q_stride[2], + internal_dtype, + const_batch_strides=False, + cache_padding_length=0, + seqlenk_shift=0, + BLOCK_SIZE=BLOCK_SIZE, + adjacents=adjacents, + num_warps=num_warps, + ) + return out_q diff --git a/pkgs/xformers/ops/swiglu_op.py b/pkgs/xformers/ops/swiglu_op.py new file mode 100644 index 0000000..8c7781c --- /dev/null +++ b/pkgs/xformers/ops/swiglu_op.py @@ -0,0 +1,467 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import Dict, Optional, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from .common import BaseOperator, get_xformers_operator, register_operator +from .unbind import stack_or_none, unbind + + +@register_operator +class DualGemmSiluOp(BaseOperator): + OPERATOR = get_xformers_operator("dual_gemm_silu_identity_mul") + OPERATOR_CATEGORY = "swiglu" + NAME = "dual_gemm_silu" + + @classmethod + # type: ignore + def operator_flop( + cls, x: torch.Tensor, w1: torch.Tensor, b1, w2: torch.Tensor, b2 + ) -> int: + """NOTE: we neglect the impact of biases / pointwises""" + M, N, K = x.shape[0], w1.shape[0], w1.shape[1] + return M * N * K * 2 * 2 + + +@register_operator +class GemmFusedSumOp(BaseOperator): + OPERATOR = get_xformers_operator("gemm_fused_operand_sum") + OPERATOR_CATEGORY = "swiglu" + NAME = "gemm_fused_operand_sum" + + @classmethod + # type: ignore + def operator_flop(cls, a: torch.Tensor, b: torch.Tensor, out1, out2) -> int: + M, N, K = a.shape[0], b.shape[1], a.shape[1] + return M * N * K * 2 + + +class _SwiGLUDecomposedFunc(torch.autograd.Function): + """ + This is just an example implementation with all + operations explicited. This implementation is worse + than pytorch, because pytorch is able to fuse some operations + (eg the linear forward ...) that are decomposed here. + + The time measurements were made on the ViT-Giant setting: + - A100/f16 + - input: [4440, 1536] + - hidden: [4440, 4096] + """ + + NAME = "decomposed" + FORCE_BW_F32 = False + + def _silu_backward(dy, x): + # https://github.com/pytorch/pytorch/blob/563b065f5a4b4055fa6b025c2514b566d5fd9439/aten/src/ATen/native/Activation.cpp#L483 + sigm = 1 / (1 + torch.exp(-x.float())) + return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype) + + # 952us + @classmethod + def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): + x1 = x @ w1.transpose(-2, -1) + b1 # 275us + x2 = x @ w2.transpose(-2, -1) + b2 # 275us + x3 = F.silu(x1) # 62us + x4 = x3 * x2 # 90us + x5 = x4 @ w3.transpose(-2, -1) + b3 # 250us + + ctx.save_for_backward(x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5) + return x5 + + # 1900us + @classmethod + def backward(cls, ctx, dx5): + saved_tensors = ctx.saved_tensors + if cls.FORCE_BW_F32: + dx5 = dx5.float() + saved_tensors = [t.float() for t in ctx.saved_tensors] + x, w1, b1, w2, b2, w3, b3, x1, x2, x3, x4, x5 = saved_tensors + dx4 = dx5 @ w3 # 255us (nn) + dw3 = dx5.transpose(-2, -1) @ x4 # 247us (nt) + db3 = dx5.sum(0) # 25us + dx3 = dx4 * x2 # 88us + dx2 = dx4 * x3 # 88us + dx1 = cls._silu_backward(dx3, x1) # 90us + dx = dx2 @ w2 # 260us (nn) + dw2 = dx2.transpose(-2, -1) @ x # 245us (nt) + db2 = dx2.sum(0) # 50us + dx += dx1 @ w1 # 260us (nn) + dw1 = dx1.transpose(-2, -1) @ x # 245us (nt) + db1 = dx1.sum(0) # 50us + return (dx, dw1, db1, dw2, db2, dw3, db3) + + +class _SwiGLUFusedFunc(torch.autograd.Function): + NAME = "fused.py" + + @classmethod + @torch.cuda.amp.custom_fwd + def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3): + x1, x2, x4 = DualGemmSiluOp.OPERATOR(x, w1, b1, w2, b2) + + x5 = F.linear(x4, w3, b3) + ctx.save_for_backward(x, w1, w2, w3, x1, x2) + ctx.bias = [b1 is not None, b2 is not None, b3 is not None] + return x5 + + @staticmethod + def _linear_bw( + dy: torch.Tensor, x: torch.Tensor, bias: bool + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if not bias: + return (dy.transpose(-2, -1) @ x), None + db = torch.empty([dy.shape[1]], dtype=dy.dtype, device=dy.device) + dw = 
torch.empty([dy.shape[1], x.shape[1]], dtype=dy.dtype, device=dy.device) + GemmFusedSumOp.OPERATOR(dy.transpose(-2, -1), x, dw, db) + return dw, db + + @classmethod + @torch.cuda.amp.custom_bwd + def backward(cls, ctx, dx5): + x, w1, w2, w3, x1, x2 = ctx.saved_tensors + w1w2 = stack_or_none([w1, w2], dim=0) + + dx4 = dx5 @ w3 # 255us (nn) + dx1dx2, x4 = torch.ops.xformers.silu_bw_fused(x1, x2, dx4) + dx1, dx2 = dx1dx2.unbind(1) + del x1, x2, dx4 + + dw3, db3 = cls._linear_bw(dx5, x4, bias=ctx.bias[2]) + del x4, dx5 + if w1w2 is not None: + assert dx1dx2.is_contiguous() + assert w1w2.is_contiguous() + w1w2 = w1w2.view([w1.shape[0] * 2, w1.shape[1]]) + dx = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]) @ w1w2 + + # backward of linear1 + linear2 - packed + dw1dw2 = dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]).transpose(-2, -1) @ x + dw1dw2, db1db2 = cls._linear_bw( + dx1dx2.view([dx1.shape[0], 2 * dx1.shape[1]]), x, bias=ctx.bias[0] + ) + dw1, dw2 = dw1dw2.view([2, *w1.shape]).unbind(0) + if ctx.bias[0]: + db1db2 = db1db2.view([2, dx1.shape[1]]) + db1, db2 = torch.unbind(db1db2, dim=0) + else: + db1 = db2 = None + else: + dx = dx2 @ w2 # 260us (nn) + torch.addmm( + dx, dx1, w1.to(dx1.dtype), beta=1, alpha=1, out=dx + ) # dx += dx1 @ w1 + dw2, db2 = cls._linear_bw(dx2, x, bias=ctx.bias[1]) + dw1, db1 = cls._linear_bw(dx1, x, bias=ctx.bias[0]) + return (dx, dw1, db1, dw2, db2, dw3, db3) + + +class SwiGLUOp: + """Base class for any swiglu operator in :attr:`xformers.ops.swiglu`""" + + def __init__(self, op, packed_weights: bool, name: str, constraints): + self.NAME = name + self.PACKED_WEIGHTS = packed_weights + self.op = op + self.constraints = constraints + + def supports(self, op: "SwiGLUOpDispatch") -> bool: + if self.PACKED_WEIGHTS and not op.packed_weights: + return False + return all(c(op) for c in self.constraints) + + def __call__(self, *args: Optional[torch.Tensor]) -> torch.Tensor: + pass + + def __str__(self) -> str: + return f"SwiGLUOp:{self.NAME}" + + +class _ForwardToPythonAutogradFunc(SwiGLUOp): + def supports(self, op: "SwiGLUOpDispatch") -> bool: + # Let's disable autocast in bf16 until this issue is fixed + # https://github.com/pytorch/pytorch/issues/87979 + if op.dtype_autocast_gpu == torch.bfloat16: + return False + return super().supports(op) + + def __call__(self, *args, **kwargs): + return self.op.apply(*args, **kwargs) + + +class _ForwardToFunc(SwiGLUOp): + def __call__(self, *args, **kwargs): + return self.op(*args, **kwargs) + + def info(self): + if self.op.__name__ == "no_such_operator": + return "not built" + return "available" + + +def _eager_functional_swiglu( + x: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor, + w2: torch.Tensor, + b2: torch.Tensor, + w3: torch.Tensor, + b3: torch.Tensor, +) -> torch.Tensor: + x1 = F.linear(x, w1, b1) + x2 = F.linear(x, w2, b2) + hidden = F.silu(x1) * x2 + return F.linear(hidden, w3, b3) + + +@dataclass +class SwiGLUOpDispatch: + """Dispatcher to automatically select + the best operator in :attr:`xformers.ops.swiglu` + """ + + device: Union[torch.device, str] + dtype: torch.dtype + dtype_autocast_gpu: Optional[torch.dtype] + packed_weights: bool + bias_enabled: bool + + @property + def op(self) -> SwiGLUOp: + """Computes the best operator + + Returns: + SwiGLUOp: The best operator for the configuration + """ + priorities: Sequence[SwiGLUOp] = [ + SwiGLUPackedFusedOp, + SwiGLUFusedOp, + ] + for op in priorities: + if op.supports(self): + return op + return SwiGLUEagerOp + + @staticmethod + def from_arguments( + x: 
torch.Tensor, + w1: torch.Tensor, + b1: Optional[torch.Tensor], + w2: torch.Tensor, + b2: Optional[torch.Tensor], + w3: torch.Tensor, + b3: Optional[torch.Tensor], + ) -> "SwiGLUOpDispatch": + return SwiGLUOpDispatch( + device=x.device, + dtype=x.dtype, + packed_weights=stack_or_none((w1, w2), dim=0) is not None, + dtype_autocast_gpu=torch.get_autocast_gpu_dtype() + if torch.is_autocast_enabled() + else w1.dtype, + bias_enabled=b1 is not None and b2 is not None and b3 is not None, + ) + + +def _only_sm80(op: SwiGLUOpDispatch) -> bool: + device_type = op.device if isinstance(op.device, str) else op.device.type + return device_type == "cuda" and torch.cuda.get_device_capability(op.device)[0] >= 8 + + +def _only_half_or_autocast(op: SwiGLUOpDispatch) -> bool: + HALF_DTYPES = [torch.half, torch.bfloat16] + return op.dtype in HALF_DTYPES or ( + op.dtype_autocast_gpu is not None and op.dtype_autocast_gpu in HALF_DTYPES + ) + + +def _bias_enabled(op: SwiGLUOpDispatch) -> bool: + return op.bias_enabled + + +_SwiGLUDecomposedOp = _ForwardToPythonAutogradFunc( + _SwiGLUDecomposedFunc, False, "decomposed", constraints=[_bias_enabled] +) +SwiGLUFusedOp = _ForwardToPythonAutogradFunc( + _SwiGLUFusedFunc, False, "fused", constraints=[_only_sm80, _only_half_or_autocast] +) +SwiGLUPackedFusedOp = _ForwardToFunc( + get_xformers_operator("swiglu_packedw"), + True, + "fused.p.cpp", + constraints=[_only_sm80, _only_half_or_autocast], +) +SwiGLUEagerOp = _ForwardToFunc( + _eager_functional_swiglu, + False, + "eager", + constraints=[], +) + + +def _info() -> Dict[str, str]: + return {op.NAME: op.info() for op in [SwiGLUPackedFusedOp]} + + +def swiglu( + x: torch.Tensor, + w1: torch.Tensor, + b1: Optional[torch.Tensor], + w2: torch.Tensor, + b2: Optional[torch.Tensor], + w3: torch.Tensor, + b3: Optional[torch.Tensor], + *, + op: SwiGLUOp = None, +) -> torch.Tensor: + """ + Computes a SwiGLU block given the weights/bias of the 3 + linear layers. + + - It is recommended to keep ``op=None`` so the best implementation \ + available for the inputs will be used. + + + :Equivalent pytorch code: + + .. code-block:: python + + x1 = F.linear(x, w1, b1) + x2 = F.linear(x, w2, b2) + hidden = F.silu(x1) * x2 + return F.linear(hidden, w3, b3) + + :Packing weights: + + To allow faster implementations, it's recommended to have w1/w2 come from the same storage, as in: + .. code-block:: python + + w1, w2 = xformers.ops.unbind(w12, 0) + + :Supported hardware: + + This operator is only optimized on A100+ on ``torch.half`` or ``torch.bfloat16`` \ + (autocast is supported), and will fallback to a functional pytorch \ + implementation otherwise. 
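+
+    :Example:
+
+        A minimal usage sketch with unpacked weights and no biases; the shapes,
+        device and dtype below are illustrative only, not requirements:
+
+        .. code-block:: python
+
+            import torch
+            from xformers.ops.swiglu_op import swiglu
+
+            x = torch.randn([4096, 1024], device="cuda", dtype=torch.half)
+            # w1/w2 project up to the hidden size, w3 projects back down
+            w1 = torch.randn([4096, 1024], device="cuda", dtype=torch.half)
+            w2 = torch.randn([4096, 1024], device="cuda", dtype=torch.half)
+            w3 = torch.randn([1024, 4096], device="cuda", dtype=torch.half)
+            out = swiglu(x, w1, None, w2, None, w3, None)  # shape [4096, 1024]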
+ """ + + batch_shape = x.shape[:-1] + x = x.reshape([-1, x.shape[-1]]) + if w1.ndim != 2 or w1.shape != w2.shape: + raise ValueError(f"Invalid shapes for w1: {w1.shape} / w2: {w2.shape}") + if b1 is not None: + if b1.ndim != 1 or b1.shape[0] != w1.shape[0]: + raise ValueError(f"Invalid shapes for b1: {b1.shape}") + if b2 is not None: + if b2.ndim != 1 or b2.shape[0] != w2.shape[0]: + raise ValueError(f"Invalid shapes for b2: {b2.shape}") + if w3.ndim != 2 or w3.shape[1] != w2.shape[0]: + raise ValueError(f"Invalid shape for w3: {w3.shape}") + if b3 is not None: + if b3.ndim != 1 or b3.shape[0] != w3.shape[0]: + raise ValueError(f"Invalid shapes for w3: {w3.shape} / b3: {b3.shape}") + + if op is None: + op = SwiGLUOpDispatch.from_arguments(x, w1, b1, w2, b2, w3, b3).op + + if not op.PACKED_WEIGHTS: + return op(x, w1, b1, w2, b2, w3, b3).reshape([*batch_shape, -1]) + w1w2 = stack_or_none((w1, w2), dim=0) + if b1 is not None and b2 is not None: + b1b2: Optional[torch.Tensor] = stack_or_none((b1, b2), dim=0) + if b1b2 is None: + raise NotImplementedError("b1/b2 needs to be properly packed") + else: + b1b2 = None + assert b1 is None and b2 is None + + if w1w2 is None: + raise NotImplementedError("w1/w2 needs to be properly packed") + return op(x, w1w2, b1b2, w3, b3).reshape([*batch_shape, -1]) + + +class SwiGLU(nn.Module): + """ + A Module that encapsulates the call to :attr:`xformers.ops.swiglu`, + and holds the weights for the 3 linear layers + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: Optional[int] = None, + bias: bool = True, + *, + _pack_weights: bool = True, + ) -> None: + """Create a SwiGLU module + + Args: + in_features (int): Number of features of the input + hidden_features (int): Number of hidden features + out_features (Optional[int], optional): Number of features of the input. Defaults to None. + bias (bool, optional): Whether linear layers also include a bias. Defaults to True. 
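+
+        Example (a brief sketch; the sizes are illustrative). On CPU this runs
+        through the eager fallback; on A100-class GPUs in fp16/bf16 the fused
+        path is selected instead:
+
+        .. code-block:: python
+
+            import torch
+
+            # out_features defaults to in_features
+            layer = SwiGLU(in_features=1024, hidden_features=4096)
+            x = torch.randn([8, 1024])
+            y = layer(x)  # shape [8, 1024]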
+ """ + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.w12: Optional[nn.Linear] + if _pack_weights: + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + else: + self.w12 = None + self.w1 = nn.Linear(in_features, hidden_features, bias=bias) + self.w2 = nn.Linear(in_features, hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + self.hidden_features = hidden_features + self.out_features = out_features + self.in_features = in_features + self.op = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes :attr:`swiglu` with the module's weights + + Args: + x (torch.Tensor): A Tensor of shape ``[..., in_features]`` + + Returns: + torch.Tensor: A Tensor of shape ``[..., out_features]`` + """ + return swiglu(x, *self._ordered_params(), op=self.op) + + def _ordered_params(self): + """Used for testing - returns ordered arguments for operators""" + b1: Optional[torch.Tensor] + b2: Optional[torch.Tensor] + if self.w12 is not None: + w1w2 = self.w12.weight + b1b2 = self.w12.bias + w1, w2 = unbind( + w1w2.view([2, w1w2.shape[0] // 2, w1w2.shape[1]]), + dim=0, + ) + if b1b2 is not None: + b1, b2 = unbind(b1b2.view([2, b1b2.shape[0] // 2]), dim=0) + else: + b1, b2 = None, None + else: + w1, w2 = self.w1.weight, self.w2.weight + b1, b2 = self.w1.bias, self.w2.bias + return [ + w1, + b1, + w2, + b2, + self.w3.weight, + self.w3.bias, + ] diff --git a/pkgs/xformers/ops/triton/__init__.py b/pkgs/xformers/ops/triton/__init__.py new file mode 100644 index 0000000..6677db0 --- /dev/null +++ b/pkgs/xformers/ops/triton/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pkgs/xformers/ops/triton/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/ops/triton/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..32d7eab Binary files /dev/null and b/pkgs/xformers/ops/triton/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/triton/__pycache__/rmsnorm_kernels.cpython-310.pyc b/pkgs/xformers/ops/triton/__pycache__/rmsnorm_kernels.cpython-310.pyc new file mode 100644 index 0000000..fe711c2 Binary files /dev/null and b/pkgs/xformers/ops/triton/__pycache__/rmsnorm_kernels.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/triton/__pycache__/rope_padded_kernels.cpython-310.pyc b/pkgs/xformers/ops/triton/__pycache__/rope_padded_kernels.cpython-310.pyc new file mode 100644 index 0000000..42e72b9 Binary files /dev/null and b/pkgs/xformers/ops/triton/__pycache__/rope_padded_kernels.cpython-310.pyc differ diff --git a/pkgs/xformers/ops/triton/rmsnorm_kernels.py b/pkgs/xformers/ops/triton/rmsnorm_kernels.py new file mode 100644 index 0000000..944dfc7 --- /dev/null +++ b/pkgs/xformers/ops/triton/rmsnorm_kernels.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+import torch +import triton +import triton.language as tl + +if hasattr(tl, "libdevice"): + tl_math = tl.libdevice +else: + tl_math = tl.math + + +@triton.jit +def _rms_norm_kernel( + x_ptr, + h1_ptr, + w_ptr, + eps, + stride, + N_COLS, + BLOCK_SIZE: tl.constexpr, + INCLUDE_WEIGHT: tl.constexpr, +): + row = tl.program_id(0) + x_ptr += row * stride + h1_ptr += row * stride + + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for offset in range(0, N_COLS, BLOCK_SIZE): + cols = offset + tl.arange(0, BLOCK_SIZE) + a = tl.load( + x_ptr + cols, mask=cols < N_COLS, other=0.0, eviction_policy="evict_last" + ).to(tl.float32) + _mean += a * a + rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps) + for offset in range(0, N_COLS, BLOCK_SIZE): + cols = offset + tl.arange(0, BLOCK_SIZE) + mask = cols < N_COLS + a = tl.load( + x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first" + ).to(tl.float32) + if INCLUDE_WEIGHT: + w = tl.load(w_ptr + cols, mask=mask) + tl.store(h1_ptr + cols, a * rstd * w, mask=mask) + else: + tl.store(h1_ptr + cols, a * rstd, mask=mask) + + +@triton.jit +def _rms_norm_add_kernel( + x_ptr, + y_ptr, + h1_ptr, + w_ptr, + eps, + stride, + N_COLS, + BLOCK_SIZE: tl.constexpr, + INCLUDE_WEIGHT: tl.constexpr, +): + row = tl.program_id(0) + x_ptr += row * stride + y_ptr += row * stride + h1_ptr += row * stride + + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for offset in range(0, N_COLS, BLOCK_SIZE): + cols = offset + tl.arange(0, BLOCK_SIZE) + mask = cols < N_COLS + ax = tl.load( + x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last" + ).to(tl.float32) + ay = tl.load( + y_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first" + ).to(tl.float32) + a = ax + ay + tl.store(x_ptr + cols, a, mask=mask) + _mean += a * a + rstd = tl_math.rsqrt((tl.sum(_mean, axis=0) / N_COLS) + eps) + for offset in range(0, N_COLS, BLOCK_SIZE): + cols = offset + tl.arange(0, BLOCK_SIZE) + mask = cols < N_COLS + a = tl.load( + x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_first" + ).to(tl.float32) + if INCLUDE_WEIGHT: + w = tl.load(w_ptr + cols, mask=mask) + tl.store(h1_ptr + cols, a * rstd * w, mask=mask) + else: + tl.store(h1_ptr + cols, a * rstd, mask=mask) + + +def _rms_norm_forward(x, attn_norm_weights, eps): + if not x.is_contiguous(): + raise ValueError("data must be contiguous") + if attn_norm_weights is not None: + if not attn_norm_weights.is_contiguous(): + raise ValueError("weights must be contiguous") + out = torch.empty_like(x) + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + _rms_norm_kernel[(M,)]( + x_arg, + out, + attn_norm_weights, + eps, + x_arg.stride(0), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + INCLUDE_WEIGHT=attn_norm_weights is not None, + ) + return out + + +def _rms_norm_add_forward(x, y, attn_norm_weights, eps): + # x, y contiguous of same shape [..., n] + # output of same shape, normed over the last dim. 
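+    # A rough eager-mode sketch of what the fused Triton kernel below computes
+    # (note that, as documented for rms_norm_add, x is updated in place with x + y):
+    #   x += y
+    #   rstd = torch.rsqrt((x.float() * x.float()).mean(-1, keepdim=True) + eps)
+    #   out = x * rstd * (attn_norm_weights if attn_norm_weights is not None else 1)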
+ if not x.is_contiguous(): + raise ValueError("x must be contiguous") + if not y.is_contiguous(): + raise ValueError("y must be contiguous") + if attn_norm_weights is not None: + if not attn_norm_weights.is_contiguous(): + raise ValueError("weights must be contiguous") + out = torch.empty_like(x) + x_arg = x.reshape(-1, x.shape[-1]) + y_arg = y.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + _rms_norm_add_kernel[(M,)]( + x_arg, + y_arg, + out, + attn_norm_weights, + eps, + x_arg.stride(0), + N, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + INCLUDE_WEIGHT=attn_norm_weights is not None, + ) + return out diff --git a/pkgs/xformers/ops/triton/rope_padded_kernels.py b/pkgs/xformers/ops/triton/rope_padded_kernels.py new file mode 100644 index 0000000..7b1846f --- /dev/null +++ b/pkgs/xformers/ops/triton/rope_padded_kernels.py @@ -0,0 +1,161 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. +import triton # type: ignore +import triton.language as tl # type: ignore + +if hasattr(tl, "libdevice"): + tl_math = tl.libdevice +else: + tl_math = tl.math + + +@triton.jit +def _rope_padded_kernel( + xq, + xk, + xv, + out_q, + cache_k, + cache_v, + seqstartq, + seqstartk, + seqlenk, + theta, + k_start: tl.constexpr, + v_start: tl.constexpr, + dim: tl.constexpr, # dimension of each head + stride_xqM, + stride_xqH, + stride_xkM, + stride_xkH, + stride_xvM, + stride_xvH, + stride_cachekM, + stride_cachekH, + stride_cachevM, + stride_cachevH, + stride_seqstartq, + stride_seqstartk, + stride_seqlenk, + stride_outqM, + stride_outqH, + internal_dtype: tl.constexpr, + # If True, seqstartq and seqstartk are not used but rather we + # assume that every batch element has the same number of + # queries (i.e. num_queries := tl.num_programs(1) ) + # and the same cache space cache_padding_length. + # Always False when called below. + const_batch_strides: tl.constexpr, + # If const_batch_strides==True, the common cache length for each batch element. + # (Only the first seqlenk[i] elements are actually in use, and only the last + # num_queries of those are actually written to.) + cache_padding_length, + # offset added to all values in seqlenk before using them. + # Always 0 when called below. + seqlenk_shift: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + adjacents: tl.constexpr, +): + """ + Each letter in this diagram is a whole row of length dim. 
+ + INPUT xq xk xv + + head_dim ─► + + batch qqqqqq kk vv + │ qqqqqq kk vv + ▼ qqqqqq kk vv + + head_idx: (goes across all heads of all 3 inputs) + ▲ ▲ ▲ ▲ ▲ ▲ + │ │ │ │ │ │ + │ │ + 0 k_start │v_start │n_total_heads + │ │ + │ │ + k_start v_start + + Output is to out_q (same shape as xq), an xk-shaped part + of cache_k and an xv-shaped part of cache_v + """ + batch_elt = tl.program_id(0) + query_pos_in_batch_elt = tl.program_id(1) + head_idx = tl.program_id(2) + + if internal_dtype == "f32": + theta = theta.to(tl.float32) + elif internal_dtype == "f64": + theta = theta.to(tl.float64) + + if const_batch_strides: + query_pos = query_pos_in_batch_elt + tl.num_programs(1) * batch_elt + end_query_pos = tl.num_programs(1) * (batch_elt + 1) + else: + query_pos = query_pos_in_batch_elt + tl.load( + seqstartq + batch_elt * stride_seqstartq + ) + end_query_pos = tl.load(seqstartq + (batch_elt + 1) * stride_seqstartq) + if query_pos >= end_query_pos: + return + + is_q = head_idx < k_start + is_v = head_idx >= v_start + + xq += query_pos * stride_xqM + head_idx * stride_xqH + out_q += query_pos * stride_outqM + head_idx * stride_outqH + + if const_batch_strides: + cache_start = cache_padding_length * batch_elt + else: + cache_start = tl.load(seqstartk + batch_elt * stride_seqstartk) + end_of_batch_elt_cache = ( + cache_start + tl.load(seqlenk + batch_elt * stride_seqlenk) + seqlenk_shift + ) + + cache_pos = end_of_batch_elt_cache - (end_query_pos - query_pos) + seq_pos = cache_pos - cache_start + cache_k += (head_idx - k_start) * stride_cachekH + cache_pos * stride_cachekM + xk += query_pos * stride_xkM + (head_idx - k_start) * stride_xkH + in_qk = tl.where(is_q, xq, xk) + out_qk = tl.where(is_q, out_q, cache_k) + + cache_v += (head_idx - v_start) * stride_cachevH + cache_pos * stride_cachevM + xv += query_pos * stride_xvM + (head_idx - v_start) * stride_xvH + + out = tl.where(is_v, cache_v, out_qk) + x_in = tl.where(is_v, xv, in_qk) + + for offset in range(0, dim // 2, BLOCK_SIZE // 2): + c = tl.arange(0, BLOCK_SIZE // 2) + powers = (offset + c) * 2.0 + if adjacents: + cols_re = (offset + c) * 2 + cols_im = cols_re + 1 + else: + cols_re = offset + c + cols_im = cols_re + dim // 2 + + mask = cols_im < dim + + re_x = tl.load(x_in + cols_re, mask=mask) + im_x = tl.load(x_in + cols_im, mask=mask) + # freqs = seq_pos / (theta ** (powers / dim)) + freqs = seq_pos * tl_math.pow(theta, powers / (-dim)) + sines = tl.sin(freqs) + cosines = tl.cos(freqs) + re_out = re_x * cosines - im_x * sines + im_out = im_x * cosines + re_x * sines + + re_out_ = tl.where(is_v, re_x, re_out) + im_out_ = tl.where(is_v, im_x, im_out) + if internal_dtype == "f64": + if re_x.dtype == tl.bfloat16: + # triton 2.0.0 crashes if you try to convert + # float64 directly to bfloat16, so make an intermediate step. + re_out_ = re_out_.to(tl.float32) + im_out_ = im_out_.to(tl.float32) + tl.store(out + cols_re, re_out_, mask=mask) + tl.store(out + cols_im, im_out_, mask=mask) diff --git a/pkgs/xformers/ops/unbind.py b/pkgs/xformers/ops/unbind.py new file mode 100644 index 0000000..e974d64 --- /dev/null +++ b/pkgs/xformers/ops/unbind.py @@ -0,0 +1,125 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import List, Optional, Sequence, Tuple, Union + +import torch + +from .common import _get_storage_base + + +def get_stack_strides( + tensors: Sequence[torch.Tensor], dim: int +) -> Optional[Tuple[int, ...]]: + """ + If the tensors are already stacked on dimension :code:`dim`, \ + returns the strides of the stacked tensors. \ + Otherwise returns :code:`None`. + """ + if len(tensors) <= 1 or dim > tensors[0].ndim: + return None + + final_stride = [] + for i in range(tensors[0].ndim + 1): + if i == dim: + final_stride.append( + tensors[1].storage_offset() - tensors[0].storage_offset() + ) + continue + if i > dim: + i -= 1 + final_stride.append(tensors[0].stride(i)) + + storage_data_ptr: Optional[int] = None + for i, x in enumerate(tensors[1:]): + # Sanity checks + if x.shape != tensors[0].shape: + return None + if x.stride() != tensors[0].stride(): + return None + if ( + x.storage_offset() + != tensors[0].storage_offset() + (i + 1) * final_stride[dim] + ): + return None + if storage_data_ptr is None: + storage_data_ptr = _get_storage_base(tensors[0]) + # Actual storage check + if _get_storage_base(x) != storage_data_ptr: + return None + return tuple(final_stride) + + +def _stack_or_none_fw( + tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], + dim: int, +) -> Optional[torch.Tensor]: + strides = get_stack_strides(tensors, dim) + if strides is not None: + input_shape = list(tensors[0].shape) + input_shape.insert(dim, len(tensors)) + return tensors[0].as_strided(input_shape, strides) + return None + + +def _stack_fw( + tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], + dim: int, +) -> torch.Tensor: + out = _stack_or_none_fw(tensors, dim) + if out is None: + out = torch.stack(tensors, dim=dim) + return out + + +class _Unbind(torch.autograd.Function): + """ + See function `unbind` + """ + + @staticmethod + # type: ignore + def forward(ctx, x: torch.Tensor, dim: int): + ctx.dim = dim + return x.unbind(dim) + + @classmethod + # type: ignore + def backward(cls, ctx, *tensors: torch.Tensor): + return _stack_fw(tensors, ctx.dim), None + + +class _StackOrNone(torch.autograd.Function): + """ + See function `stack_or_none` + """ + + @staticmethod + # type: ignore + def forward(ctx, dim: int, *tensors: torch.Tensor): + ctx.dim = dim + return _stack_or_none_fw(tensors, dim=dim) + + @classmethod + # type: ignore + def backward(cls, ctx, grad: torch.Tensor): + return (None, *grad.unbind(dim=ctx.dim)) + + +def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]: + """ + Does exactly the same as :attr:`torch.unbind` for the forward. + In backward, avoids a :attr:`torch.cat` if the gradients + are already multiple views of the same storage + """ + return _Unbind.apply(x, dim) + + +def stack_or_none(tensors: Sequence[torch.Tensor], dim: int) -> torch.Tensor: + """ + Does exactly the same as :attr:`torch.stack` if the tensors can be concatenated + without any memory operation. Otherwise returns None. + """ + return _StackOrNone.apply(dim, *tensors) diff --git a/pkgs/xformers/profiler/__init__.py b/pkgs/xformers/profiler/__init__.py new file mode 100644 index 0000000..ae14c7d --- /dev/null +++ b/pkgs/xformers/profiler/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .api import profile, step +from .profiler import MemSnapshotsProfiler, NsightProfiler, PyTorchProfiler +from .slow_ops_profiler import DetectSlowOpsProfiler + +__all__ = [ + "profile", + "step", + "MemSnapshotsProfiler", + "PyTorchProfiler", + "NsightProfiler", + "DetectSlowOpsProfiler", +] diff --git a/pkgs/xformers/profiler/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/profiler/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..b2a83de Binary files /dev/null and b/pkgs/xformers/profiler/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/profiler/__pycache__/api.cpython-310.pyc b/pkgs/xformers/profiler/__pycache__/api.cpython-310.pyc new file mode 100644 index 0000000..8232113 Binary files /dev/null and b/pkgs/xformers/profiler/__pycache__/api.cpython-310.pyc differ diff --git a/pkgs/xformers/profiler/__pycache__/device_limits.cpython-310.pyc b/pkgs/xformers/profiler/__pycache__/device_limits.cpython-310.pyc new file mode 100644 index 0000000..5426b75 Binary files /dev/null and b/pkgs/xformers/profiler/__pycache__/device_limits.cpython-310.pyc differ diff --git a/pkgs/xformers/profiler/__pycache__/profiler.cpython-310.pyc b/pkgs/xformers/profiler/__pycache__/profiler.cpython-310.pyc new file mode 100644 index 0000000..f0cc257 Binary files /dev/null and b/pkgs/xformers/profiler/__pycache__/profiler.cpython-310.pyc differ diff --git a/pkgs/xformers/profiler/__pycache__/slow_ops_profiler.cpython-310.pyc b/pkgs/xformers/profiler/__pycache__/slow_ops_profiler.cpython-310.pyc new file mode 100644 index 0000000..e6da4a8 Binary files /dev/null and b/pkgs/xformers/profiler/__pycache__/slow_ops_profiler.cpython-310.pyc differ diff --git a/pkgs/xformers/profiler/api.py b/pkgs/xformers/profiler/api.py new file mode 100644 index 0000000..37000be --- /dev/null +++ b/pkgs/xformers/profiler/api.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Optional, Sequence, Tuple + +import torch.nn as nn + +from .profiler import ( + MemSnapshotsProfiler, + NsightProfiler, + PyTorchProfiler, + PyTorchProfiler_CUDAOnly, + _Profiler, +) +from .slow_ops_profiler import DetectSlowOpsProfiler # noqa: F401 + +DEFAULT_SCHEDULE = ( + (MemSnapshotsProfiler, 0, 2), + (NsightProfiler, 4, 6), + (PyTorchProfiler, 6, 7), + (PyTorchProfiler_CUDAOnly, 7, 8), + # TODO: There are some issues in PyTorch stable + # which are now fixed on main, but might break this profiler + # https://github.com/pytorch/pytorch/issues/94403 + # (DetectSlowOpsProfiler, 9, 10), +) + + +def profile( + output_dir: str, + module: Optional[nn.Module] = None, + schedule: Sequence[Tuple[Any, int, int]] = DEFAULT_SCHEDULE, +): + """ + A pre-configured profiler that will run on the first ~20 steps of the training + It will provide multiple traces that can be exploited later. + Use it in a context manager around your training loop, and call `xformers.profiler.step` + before starting the next iteration. + + :Examples: + + .. 
code-block:: python + + import torch + import timm.models + import xformers.profiler + + dtype = torch.bfloat16 + device = "cuda" + model = timm.models.vit_large_patch16_224().to(device).to(dtype) + inp = torch.zeros([64, 3, 224, 224], device=device, dtype=dtype) + optim = torch.optim.Adam(model.parameters()) + + with xformers.profiler.profile( + output_dir="profile_data", + module=model, + schedule=[ + (MemSnapshotsProfiler, 0, 2), + (DetectSlowOpsProfiler, 2, 4), + (NsightProfiler, 4, 6), + (PyTorchProfiler, 6, 20), + ] + ): + for i in range(20): + model(inp).sum().backward() + optim.step() + optim.zero_grad() + xformers.profiler.step() + + # alternatively, use the profiler without context and with ``.start()`` / `.stop()` + # calls. + + xprofiler = xformers.profiler.profile(...) + xprofiler.start() + + for i in range(20): + model(inp).sum().backward() + optim.step() + optim.zero_grad() + xprofiler.step() + + xprofiler.stop() + """ + return _Profiler(output_dir=output_dir, schedule=schedule, module=module) + + +def step() -> None: + """See `xformers.profiler.profile`""" + # Silently return if no profiler is enabled + if _Profiler._CURRENT_PROFILER is None: + return + _Profiler._CURRENT_PROFILER.step() diff --git a/pkgs/xformers/profiler/device_limits.py b/pkgs/xformers/profiler/device_limits.py new file mode 100644 index 0000000..f2e9caf --- /dev/null +++ b/pkgs/xformers/profiler/device_limits.py @@ -0,0 +1,113 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import Mapping, Tuple + +import torch + + +@dataclass +class DeviceLimit: + name: str = "default" # pattern to match from `torch.cuda.get_device_name()` + source: str = "" + sm: Tuple[int, int] = (0, 0) + # bytes/s + gmem_bandwidth: float = math.inf + # dtype -> TFlop/s + gemm_tflops: Mapping[torch.dtype, float] = field(default_factory=dict) + + +# For f32, we assume we can use tf32 +DEVICE_LIMITS: Tuple[DeviceLimit, ...] 
= ( + DeviceLimit( + "H100", + "https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet", # noqa: E501 + sm=(9, 0), + gmem_bandwidth=3.35 * (1024**4), # NOTE: PCIe is 2 TB/s + gemm_tflops={ + torch.float64: 67, + # NOTE: NVIDIA gives all numbers "with 2:4 sparsity" + # but we want the full GEMM numbers + torch.float32: 989 // 2, + torch.float16: 1979 // 2, + torch.bfloat16: 1979 // 2, + torch.int8: 3958 // 2, + }, + ), + DeviceLimit( + "A100", + "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/a100/pdf/nvidia-a100-datasheet-us-nvidia-1758950-r4-web.pdf", # noqa: E501 + sm=(8, 0), + gmem_bandwidth=2 * (1024**4), # NOTE: PCIe is 1.5 TB/s + gemm_tflops={ + torch.float64: 19.5, + torch.float32: 156, + torch.float16: 312, + torch.bfloat16: 312, + torch.int8: 624, + }, + ), + DeviceLimit( + "A30", + "https://www.nvidia.com/content/dam/en-zz/Solutions/data-center/products/a30-gpu/pdf/a30-datasheet.pdf", + sm=(8, 0), + gmem_bandwidth=933 * (1024**3), + gemm_tflops={ + torch.float64: 10.3, + torch.float32: 82, + torch.float16: 165, + torch.bfloat16: 165, + torch.int8: 330, + }, + ), + DeviceLimit( + "T4", + "https://www.nvidia.com/content/dam/en-zz/Solutions/Data-Center/tesla-t4/t4-tensor-core-datasheet-951643.pdf", + sm=(7, 5), + gmem_bandwidth=300 * (1024**3), + gemm_tflops={ + torch.float32: 8.1, + torch.float16: 65, + torch.int8: 130, + }, + ), + # Assuming SXM2 + DeviceLimit( + "V100", + "https://images.nvidia.com/content/technologies/volta/pdf/tesla-volta-v100-datasheet-letter-fnl-web.pdf", + sm=(7, 0), + gmem_bandwidth=900 * (1024**3), + gemm_tflops={ + torch.float64: 7.8, + torch.float32: 15.7, + torch.float16: 125, + }, + ), + DeviceLimit( + "P100", + "https://images.nvidia.com/content/tesla/pdf/nvidia-tesla-p100-datasheet.pdf", + sm=(6, 0), + gmem_bandwidth=732 * (1024**3), + gemm_tflops={ + torch.float64: 5.3, + torch.float32: 10.6, + torch.float16: 21.2, + }, + ), +) + + +def get_device_limits(device) -> DeviceLimit: + """Currently only implemented for GPUs""" + if device is not None and device.type == "cuda": + device_sm = torch.cuda.get_device_capability(device) + device_name = torch.cuda.get_device_name(device) + for lim in DEVICE_LIMITS: + if lim.sm == device_sm: + if lim.name in device_name: + return lim + return DeviceLimit() diff --git a/pkgs/xformers/profiler/profiler.py b/pkgs/xformers/profiler/profiler.py new file mode 100644 index 0000000..63770fb --- /dev/null +++ b/pkgs/xformers/profiler/profiler.py @@ -0,0 +1,363 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import queue +import socket +import weakref +from dataclasses import dataclass +from pathlib import Path +from typing import Any, List, Optional, Sequence, Tuple + +import torch.cuda.memory +import torch.cuda.nvtx +import torch.nn as nn +import torch.profiler +import torch.utils.hooks + +logger = logging.getLogger(__name__) + + +def _normalize_tuple(x): + if not isinstance(x, tuple): + return (x,) + return x + + +class NsightProfiler: + """Profiler that triggers start of NSight profiler. + + NOTE: you need to ensure that the script running this code actually is running with + ``nsys profile`` and also has a flag ``--capture-range=cudaProfilerApi`` so the + capturing is performed by this profiler during certain steps. 
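+
+    A typical launch therefore looks like (a sketch; the script name is
+    illustrative)::
+
+        nsys profile --capture-range=cudaProfilerApi python train.py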
+ """ + + def __init__(self, main_profiler: "_Profiler") -> None: + self.main_profiler = main_profiler + # TODO figure out if there is a way to know if nsys is launched at this point + + def __enter__(self): + self.main_profiler._install_hooks() + torch.cuda.profiler.start() + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.cuda.profiler.stop() + self.main_profiler._remove_hooks() + + def step(self) -> None: + pass + + +class PyTorchProfiler: + """Profiler which relies on native Pytorch profiling. Current setting of the profiler + captures traces, memory footprint and other info that could be read via TensorBoard. + """ + + ACTIVITIES = [ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ] + + def __init__(self, main_profiler: "_Profiler") -> None: + self.main_profiler = main_profiler + activities_str = "_".join(a.name for a in self.ACTIVITIES) + trace_handler = torch.profiler.tensorboard_trace_handler( + dir_name=str( + main_profiler.output_dir + / f"profile_{activities_str}_{main_profiler.done_steps:06}" + ), + worker_name=main_profiler.worker_name, + use_gzip=True, + ) + self.hta = torch.profiler.profile( + on_trace_ready=trace_handler, + profile_memory=True, + record_shapes=True, + with_stack=True, + activities=self.ACTIVITIES, + ) + self.done_steps = 0 + + def __enter__(self): + torch.cuda.synchronize() + self.hta.__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.cuda.synchronize() + self.hta.__exit__(exc_type, exc_val, exc_tb) + + def step(self) -> None: + self.hta.step() + self.done_steps += 1 + + +class PyTorchProfiler_CUDAOnly(PyTorchProfiler): + # This profiler does not profile the CPU-side of things + # so we expect it to have almost no overhead + ACTIVITIES = [torch.profiler.ProfilerActivity.CUDA] + + +class MemSnapshotsProfiler: + """Profiler that captures memory traces for allocation and deallocation of memory for + tensors. 
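+
+    When available (requires a PyTorch build that provides
+    ``torch.cuda._memory_viz.trace_plot``), the captured snapshot is rendered to
+    a ``memory_trace_plot.html`` file in the profiler's output directory when
+    this profiler exits.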
+ """ + + def __init__(self, main_profiler: "_Profiler") -> None: + self.main_profiler = main_profiler + self.enabled = False + + @property + def _has_trace_plot(self) -> bool: + return hasattr(torch.cuda._memory_viz, "trace_plot") + + def __enter__(self): + if not self._has_trace_plot: + return + self.enabled = True + # TODO: This does not show the previous memory allocations + # We could at least have a placeholder with how much + # memory was allocated before + torch.cuda.memory._record_memory_history( + True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + # record stack information for the trace events + trace_alloc_record_context=True, + ) + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self._has_trace_plot: + self.main_profiler.summary.append( + ("MemTrace", "(not available with your Pytorch version)") + ) + return + assert self.enabled + snapshot = torch.cuda.memory._snapshot() + torch.cuda.memory._record_memory_history(False) + # No data was recorded - avoids a `ValueError` in `trace_plot` + if all(len(t) == 0 for t in snapshot["device_traces"]): + self.main_profiler.summary.append(("MemTrace", "(no allocation recorded)")) + return + # Dump to disk + filename = self.main_profiler._create_output_filename("memory_trace_plot.html") + self.main_profiler.summary.append(("MemTrace", filename)) + with open(filename, "w+") as fd: + fd.write( + torch.cuda._memory_viz.trace_plot( + snapshot, device=None, plot_segments=False + ) + ) + + def step(self) -> None: + pass + + +@dataclass +class _ProfilerState: + cls: Any + iter_begin: int + iter_end: int + object: Any = None + + +class _Profiler: + _CURRENT_PROFILER = None + + def __init__( + self, + output_dir: str, + schedule: Sequence[Tuple[Any, int, int]], + module: Optional[nn.Module], + ) -> None: + self.check_schedule(schedule) + self.done_steps = 0 + self.output_dir = Path(output_dir).absolute() + self.output_dir.mkdir(exist_ok=True, parents=True) + self.worker_name = "" + if torch.distributed.is_initialized(): + self.worker_name = "{}_{}".format(socket.gethostname(), str(os.getpid())) + + self.module = weakref.ref(module if module is not None else nn.Module()) + self.parents = ["Global"] + self.hooks: List[torch.utils.hooks.RemovableHandle] = [] + self.hooks_refcount = 0 + self.profilers: List[_ProfilerState] = sorted( + [_ProfilerState(cls, begin, end) for cls, begin, end in schedule], + key=lambda x: x.iter_begin, + ) + self.last_step = self.profilers[-1].iter_end if self.profilers else 0 + self.summary: List[Tuple[str, str]] = [] + + def check_schedule(self, schedule: Sequence[Tuple[Any, int, int]]) -> None: + if len(schedule) == 0: + logger.warning( + "You specified empty schedule for profiling. No data will be captured." + ) + + pq: Any = queue.PriorityQueue() + for cls, begin, end in schedule: + assert ( + begin >= 0 + ), f"Begin step of profiler must be non-negative, found: {begin}" + assert end > 0, f"End step of profiler must be positive, found: {end}" + assert ( + begin < end + ), f"Start must be before the end, found: begin={begin} and end={end}" + + pq.put((begin, end)) + + prev_end = -1 + for begin, end in pq.queue: + assert begin >= prev_end, ( + "There is some overlapping in profiler scheduling. Please do not" + + " overlap profilers by step as they may affect each other. 
Schedule:" + + f" {schedule}" + ) + prev_end = end + + def update_profilers_on_step(self) -> None: + for p in self.profilers: + if p.iter_begin <= self.done_steps and self.done_steps < p.iter_end: + if p.object is None: + o = p.cls(self) + logging.info(f"Starting {p.cls.__name__} profiler...") + o.__enter__() + p.object = o + else: + p.object.step() + else: + if p.object is not None: + o = p.object + p.object = None + logging.info(f"Shutting down {p.cls.__name__} profiler...") + o.__exit__(None, None, None) + + def _create_output_filename(self, filename: str) -> Path: + """ + Returns where to write a file with desired filename. + Handles the case where we are in distributed settings, or when + we need to output the same file multiple times (eg if a profiler + runs for several steps) + """ + if self.worker_name != "": + file = Path(filename) + folder = self.output_dir / file.stem + folder.mkdir(parents=True, exist_ok=True) + return folder / f"{self.done_steps:06}_{self.worker_name}{file.suffix}" + return self.output_dir / f"{self.done_steps:06}_{filename}" + + def _install_hooks(self) -> None: + self.hooks_refcount += 1 + # Already installed + if self.hooks: + return + module = self.module() + if module is None: + return + for name, sub_mod in module.named_modules(): + if name == "": + continue + name = name.split(".")[-1] + self.hooks += [ + sub_mod.register_forward_pre_hook(self._enter_module_hook(name)), + sub_mod.register_forward_hook(self._exit_module_hook(name)), + ] + + def _remove_hooks(self) -> None: + self.hooks_refcount -= 1 + if self.hooks_refcount == 0: + for h in self.hooks: + h.remove() + + def _enter_module_hook(self, name): + class PopState(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + if len(args) == 1: + return args[0] + return args + + @staticmethod + def backward(ctx, *grad_outs): + self._exit_module(name) + return grad_outs + + def f(module, inputs): + self._enter_module(name) + inputs = _normalize_tuple(inputs) + out = PopState.apply(*inputs) + return out + + return f + + def _exit_module_hook(self, name): + class PushState(torch.autograd.Function): + @staticmethod + def forward(ctx, *args): + if len(args) == 1: + return args[0] + return args + + @staticmethod + def backward(ctx, *grad_outs): + self._enter_module(name) + return grad_outs + + def f(module, inputs, outputs): + self._exit_module(name) + outputs = _normalize_tuple(outputs) + return PushState.apply(*outputs) + + return f + + def _enter_module(self, name) -> None: + self.parents.append(name) + torch.cuda.nvtx.range_push(name) + + def _exit_module(self, name) -> None: + torch.cuda.nvtx.range_pop() + assert self.parents[-1] == name + self.parents.pop() + + def start(self): + self.__enter__() + + def stop(self, exc_type=None, exc_val=None, exc_tb=None): + self.__exit__(exc_type, exc_val, exc_tb) + + def __enter__(self): + if _Profiler._CURRENT_PROFILER is not None: + raise ValueError("Only one xformers profiler can be active at a time") + _Profiler._CURRENT_PROFILER = self + self.update_profilers_on_step() + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + _Profiler._CURRENT_PROFILER = None + + for p in self.profilers: + if p.object is not None: + p.object.__exit__(exc_type, exc_val, exc_tb) + + def step(self) -> None: + """Signals the profiler that the next profiling step has started.""" + self.done_steps += 1 + + if self.done_steps <= self.last_step: + self.parents = ["Global"] + self.update_profilers_on_step() + if self.done_steps == self.last_step: + 
logger.info("xFormers profiler done. %s", self.format_summary()) + + def format_summary(self) -> str: + if len(self.summary) == 0: + return "" + pad_titles = max(len(title) for title, value in self.summary) + return "summary:\n" + "\n".join( + [f" {title.ljust(pad_titles)}: {value}" for title, value in self.summary] + ) diff --git a/pkgs/xformers/profiler/slow_ops_profiler.py b/pkgs/xformers/profiler/slow_ops_profiler.py new file mode 100644 index 0000000..5fd8318 --- /dev/null +++ b/pkgs/xformers/profiler/slow_ops_profiler.py @@ -0,0 +1,513 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import json +import math +from collections import defaultdict +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, List, Set, Tuple + +import torch.cuda.memory +import torch.cuda.nvtx +import torch.profiler +import torch.utils.hooks +from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily +from torch.utils._pytree import tree_map + +from ..ops.common import FUNC_TO_XFORMERS_OPERATOR +from .device_limits import get_device_limits +from .profiler import _Profiler + + +class TorchFuncMockNoDispatch: + """ + Wraps a method to call it without the custom + pytorch dispatcher + """ + + def __init__(self, pt_impl): + self.pt_impl = pt_impl + + def __get__(self, obj, c): + return partial(self, obj) + + def __call__(self, obj, *args, **kwargs): + with _pop_mode_temporarily(): + return self.pt_impl(obj, *args, **kwargs) + + +class DispatcherWithoutBrokenFuncs(TorchDispatchMode): + TENSOR_FUNCS_NO_DISPATCH = [ + # Can't convert Stream argument to Python object + # https://github.com/pytorch/pytorch/issues/94403 + "record_stream" + ] + + def __enter__(self) -> None: + self._pt_impls = {} + for k in self.TENSOR_FUNCS_NO_DISPATCH: + impl = getattr(torch.Tensor, k) + self._pt_impls[k] = impl + setattr(torch.Tensor, k, TorchFuncMockNoDispatch(impl)) + return super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + for k in self.TENSOR_FUNCS_NO_DISPATCH: + setattr(torch.Tensor, k, self._pt_impls[k]) + return super().__exit__(exc_type, exc_val, exc_tb) + + +def get_shape(i): + return i.shape + + +def prod(x): + res = 1 + for i in x: + res *= i + return res + + +class GemmOpComputeFlops: + def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]: + return (prod(inputs[0].shape[:-1]), inputs[1].shape[1], inputs[0].shape[-1]) + + def __call__(self, inputs: List[Any], outputs: List[Any]) -> float: + return 2 * prod(self._get_mnk(inputs)) + + def op_suffix(self, inputs: List[Any]) -> str: + m, n, k = self._get_mnk(inputs) + return f"_{m}x{n}x{k}" + + +class GemmOpComputeFlopsLinear(GemmOpComputeFlops): + def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]: + return (prod(inputs[0].shape[:-1]), inputs[1].shape[0], inputs[0].shape[-1]) + + +class GemmOpComputeFlopsMv(GemmOpComputeFlops): + def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]: + return (prod(inputs[0].shape[:-1]), 1, inputs[0].shape[-1]) + + +class GemmOpComputeFlopsBmm(GemmOpComputeFlops): + def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]: + a, b = inputs[0], inputs[1] + assert a.ndim == 3 + assert b.ndim == 3 + bs = max(inputs[0].shape[0], inputs[1].shape[0]) + return (bs * a.shape[1], b.shape[-1], b.shape[-2]) + + +class 
GemmOpComputeFlopsAddmm(GemmOpComputeFlops): + def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]: + return super()._get_mnk(inputs[1:]) + + +class GemmOpComputeFlopsAddbmm(GemmOpComputeFlopsBmm): + def _get_mnk(self, inputs: List[Any]) -> Tuple[int, int, int]: + return super()._get_mnk(inputs[1:]) + + +def conv_flop_count( + x_shape: List[int], + w_shape: List[int], + out_shape: List[int], + transposed: bool = False, +) -> float: + """ + Count flops for convolution. Note only multiplication is + counted. Computation for addition and bias is ignored. + Flops for a transposed convolution are calculated as + flops = (x_shape[2:] * prod(w_shape) * batch_size). + Args: + x_shape (list(int)): The input shape before convolution. + w_shape (list(int)): The filter shape. + out_shape (list(int)): The output shape after convolution. + transposed (bool): is the convolution transposed + Returns: + int: the number of flops + """ + batch_size = x_shape[0] + conv_shape = (x_shape if transposed else out_shape)[2:] + flop = batch_size * prod(w_shape) * prod(conv_shape) + return flop + + +def conv_flop(inputs: List[Any], outputs: List[Any]): + """ + Count flops for convolution. + """ + x, w = inputs[:2] + x_shape, w_shape, out_shape = (get_shape(x), get_shape(w), get_shape(outputs[0])) + transposed = inputs[6] + + return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) + + +def transpose_shape(shape): + return [shape[1], shape[0]] + list(shape[2:]) + + +def conv_backward_flop(inputs: List[Any], outputs: List[Any]): + grad_out_shape, x_shape, w_shape = [get_shape(i) for i in inputs[:3]] + output_mask = inputs[-1] + fwd_transposed = inputs[7] + flop_count = 0.0 + + if output_mask[0]: + grad_input_shape = get_shape(outputs[0]) + flop_count += conv_flop_count( + grad_out_shape, w_shape, grad_input_shape, not fwd_transposed + ) + if output_mask[1]: + grad_weight_shape = get_shape(outputs[1]) + flop_count += conv_flop_count( + transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed + ) + + return flop_count + + +def tensor_storage_size_in_mem(x: torch.Tensor): + total = 1 + for dim_sz, stride in zip(x.shape, x.stride()): + if stride >= 1: + total *= dim_sz + return total + + +def get_size(inputs: List[Any]): + total_bytes = 0 + + def process(x) -> None: + nonlocal total_bytes + if isinstance(x, torch.Tensor): + total_bytes += tensor_storage_size_in_mem(x) * x.element_size() + + tree_map(process, inputs) + return total_bytes + + +def operation_memory_rw_bytes(inputs: List[Any], outputs: List[Any]): + size_input, size_output = get_size(inputs), get_size(outputs) + return size_input + size_output + + +def output_read_from_input(inputs: List[Any], outputs: List[Any]): + size_input, size_output = get_size(inputs), get_size(outputs) + return size_output + min(size_input, size_output) + + +def output_total_size(inputs: List[Any], outputs: List[Any]): + return get_size(outputs) + + +def input_total_size(inputs: List[Any], outputs: List[Any]): + return get_size(inputs) + + +def guess_flops_unknown_op(inputs: List[Any], outputs: List[Any]): + # Approximation that isn't too bad + total_elements = 0 + + def process(x) -> None: + nonlocal total_elements + if isinstance(x, torch.Tensor): + total_elements += x.numel() + + tree_map(process, inputs) + tree_map(process, outputs) + return total_elements / 2 + + +def no_flop(inputs: List[Any], outputs: List[Any]): + return 0 + + +def no_io(inputs: List[Any], outputs: List[Any]): + return 0 + + +aten = torch.ops.aten 
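# A quick sanity sketch of the GEMM flop counters defined above (illustrative
# only; the shapes below are made up). For aten.mm with inputs of shape (M, K)
# and (K, N), _get_mnk returns (M, N, K) and the reported count is 2 * M * N * K.
import torch

_gemm = GemmOpComputeFlops()
_a, _b = torch.empty(8, 16), torch.empty(16, 32)
assert _gemm._get_mnk([_a, _b]) == (8, 32, 16)
assert _gemm([_a, _b], []) == 2 * 8 * 32 * 16       # 8192 flops for an 8x16 @ 16x32 matmul
assert _gemm.op_suffix([_a, _b]) == "_8x32x16"      # suffix appended to the op name in the trace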
+NO_FLOPS_NO_IO_OPS = [ + aten.permute, + aten.view, + aten.view_as, + aten.detach, + aten.t, + aten.transpose, + aten.expand, + aten._unsafe_view, + aten.select, + aten.split, + aten.split_with_sizes, + aten.empty, + aten.empty_strided, + aten.empty_like, + aten.is_same_size, +] +NO_FLOPS_OPS = [ + aten._reshape_alias, + aten.reshape, + aten.clone, + aten.cat, + aten.select_backward, + aten.slice, + aten.slice_backward, + aten.ones, + aten.ones_like, + aten.zeros_like, + aten.zero_, + aten.zeros, + aten.masked_fill, + aten.masked_fill_, +] + +flop_mapping = { + aten.mv: GemmOpComputeFlopsMv(), # mat-vec + aten.mm: GemmOpComputeFlops(), + aten.matmul: GemmOpComputeFlops(), + aten.addmm: GemmOpComputeFlopsAddmm(), + aten.bmm: GemmOpComputeFlopsBmm(), + aten.addbmm: GemmOpComputeFlopsAddbmm(), + aten.linear: GemmOpComputeFlopsLinear(), + aten.convolution: conv_flop, + aten._convolution: conv_flop, + aten.convolution_backward: conv_backward_flop, + # Operations with 0 flop + **{op: no_flop for op in NO_FLOPS_OPS}, + **{op: no_flop for op in NO_FLOPS_NO_IO_OPS}, +} +io_mapping = { + aten.clone: output_read_from_input, + aten.cat: output_read_from_input, + aten.slice: output_read_from_input, + aten.ones_like: output_total_size, + aten.zeros_like: output_total_size, + aten.zero_: input_total_size, + **{op: no_io for op in NO_FLOPS_NO_IO_OPS} + # TODO: Check how this is implemented in PT + # aten.slice_backward: no_flop, + # aten.select_backward: no_flop, +} + + +@dataclass +class _OpInfo: + flop_count: float = 0.0 + time_ms: float = 0.0 + io_bytes: int = 0 + is_exact_flop: bool = True + op_name: str = "" + op_suffix: str = "" + stacktrace: Tuple[str, ...] = field(default_factory=tuple) + ev_start: torch.cuda.Event = field( + default_factory=lambda: torch.cuda.Event(enable_timing=True) + ) + ev_end: torch.cuda.Event = field( + default_factory=lambda: torch.cuda.Event(enable_timing=True) + ) + + # Hardware limits for this operation (inf if unknown) + hardware_tflops_limit: float = math.inf + hardware_membw_limit: float = math.inf + + @property + def time_membound_ms(self) -> float: + assert self.time_ms > 0.0 + if self.io_bytes == 0: + return 0.0 + return min(self.time_ms, 1000 * self.io_bytes / self.hardware_membw_limit) + + @property + def time_computebound_ms(self) -> float: + assert self.time_ms > 0.0 + tflop = self.flop_count / (1000**4) + if tflop == 0.0: + return 0.0 + return min(self.time_ms, 1000 * tflop / self.hardware_tflops_limit) + + def finalize(self) -> None: + self.time_ms = self.ev_start.elapsed_time(self.ev_end) + + +@dataclass +class _OpInfoAggregated: + is_exact_flop: bool = True + total_flop_count: float = 0.0 + total_io_bytes: int = 0 + total_time_ms: float = 0.0 + total_time_membound_ms: float = 0.0 + total_time_computebound_ms: float = 0.0 + num: int = 0 + stacktraces: List[Tuple[str, ...]] = field(default_factory=list) + + def add(self, op: _OpInfo) -> None: + self.total_flop_count += op.flop_count + self.total_time_ms += op.time_ms + self.total_io_bytes += op.io_bytes + self.total_time_membound_ms += op.time_membound_ms + self.total_time_computebound_ms += op.time_computebound_ms + self.num += 1 + self.is_exact_flop = op.is_exact_flop + self.stacktraces.append(op.stacktrace) + + def as_dict(self, **kwargs) -> Dict[str, Any]: + mem_bound = min(1, self.total_time_membound_ms / self.total_time_ms) + tflops = self.total_flop_count / (self.total_time_ms / 1000) / (1000**4) + compute_bound = min(1, self.total_time_computebound_ms / self.total_time_ms) + return { + 
"is_exact_flop": self.is_exact_flop, + "total_flop_count": self.total_flop_count, + "total_time_ms": self.total_time_ms, + "total_io_bytes": self.total_io_bytes, + "num": self.num, + "Tflops": tflops, + "mem_bound": mem_bound, + "compute_bound": compute_bound, + **kwargs, + } + + +class DetectSlowOpsProfiler(DispatcherWithoutBrokenFuncs): + """ + Inspired from https://fb.workplace.com/groups/pytorch.dev/permalink/1054537595124720/ + """ + + def __init__(self, main_profiler: _Profiler) -> None: + self.main_profiler = main_profiler + self.trace: List[_OpInfo] = [] + self.temp_disabled = False + + def _hardware_tflops_membw_limit( + self, args: Tuple[Any, ...], outputs: Tuple[Any, ...] + ) -> Tuple[float, float]: + device = None + dtypes: List[torch.dtype] = [] + for a in itertools.chain(outputs, args): + if isinstance(a, torch.Tensor): + if device is None: + device = a.device + dtypes.append(a.dtype) + limits = get_device_limits(device) + dtypes = [dt for dt in dtypes if dt in limits.gemm_tflops] + if not dtypes or device is None: + return (math.inf, math.inf) + dtype = dtypes[0] + if torch.is_autocast_enabled() and dtype is torch.float32: + dtype = torch.get_autocast_gpu_dtype() + return limits.gemm_tflops[dtype], limits.gmem_bandwidth + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + func_packet = func._overloadpacket + if self.temp_disabled or func_packet.__name__ in [ + "_record_function_exit", + "_record_function_enter_new", + ]: + return func(*args, **kwargs) + + op = _OpInfo() + op.ev_start.record() + out = func(*args, **kwargs) + op.ev_end.record() + + ( + op.hardware_tflops_limit, + op.hardware_membw_limit, + ) = self._hardware_tflops_membw_limit( + args, out if isinstance(out, tuple) else (out,) + ) + op.op_name = func_packet.__name__ + # Prevent functions called by flop counting ops to be recorded + self.temp_disabled = True + flop_count = -1 + compute_flops = None + if func_packet in FUNC_TO_XFORMERS_OPERATOR: + flop_count = FUNC_TO_XFORMERS_OPERATOR[func_packet].operator_flop( + *args, **kwargs + ) + if flop_count == -1: + compute_flops = flop_mapping.get(func_packet, guess_flops_unknown_op) + flop_count = compute_flops(args, out if isinstance(out, tuple) else (out,)) + if isinstance(compute_flops, GemmOpComputeFlops): + op.op_name += compute_flops.op_suffix(args) + + compute_io = io_mapping.get(func_packet, operation_memory_rw_bytes) + op.io_bytes = compute_io(args, out if isinstance(out, tuple) else (out,)) + self.temp_disabled = False + + op.stacktrace = tuple(self.main_profiler.parents) + op.flop_count = flop_count + op.is_exact_flop = compute_flops is not guess_flops_unknown_op + self.trace.append(op) + + return out + + def __enter__(self): + self.main_profiler._install_hooks() + super().__enter__() + + def __exit__(self, exc_type, exc_val, exc_tb): + super().__exit__(exc_type, exc_val, exc_tb) + self.main_profiler._remove_hooks() + torch.cuda.synchronize() # Wait for the events to be recorded + for op in self.trace: + op.finalize() + self.save_json() + + def step(self) -> None: + pass + + def save_json(self) -> None: + # Aggregate data at the module + op level + all_paths: Set[Tuple[str, ...]] = set() + per_module_data: Dict[Tuple[str, ...], _OpInfoAggregated] = defaultdict( + _OpInfoAggregated + ) + per_op_data: Dict[str, _OpInfoAggregated] = defaultdict(_OpInfoAggregated) + for op in self.trace: + all_paths.add(op.stacktrace) + for op in self.trace: + for i in range(len(op.stacktrace)): + if op.stacktrace[: i + 1] 
in all_paths: + per_module_data[op.stacktrace[: i + 1]].add(op) + per_op_data[op.op_name].add(op) + + # Generate JSON + all_data = [] + for stacktrace, agg_info in per_module_data.items(): + all_data.append( + agg_info.as_dict( + agg="module", path=stacktrace, name=stacktrace[-1], op="" + ) + ) + for op_name, agg_info in per_op_data.items(): + # Find the most common path + paths_count: Dict[Tuple[str, ...], int] = defaultdict(int) + agg_info.stacktraces.sort() # In case of a draw, let's always return the same + for p in agg_info.stacktraces: + paths_count[p] += 1 + maxp = agg_info.stacktraces[0] + for p, count in paths_count.items(): + if count > paths_count[maxp]: + maxp = p + all_data.append( + agg_info.as_dict( + agg="opname", + path=f"{'.'.join(maxp)} (x{paths_count[maxp]})", + name="", + op=op_name, + ) + ) + + filename = self.main_profiler._create_output_filename("ops.json") + self.main_profiler.summary.append(("OpsSummary", str(filename))) + with open(filename, "w+") as f: + json.dump(all_data, f) diff --git a/pkgs/xformers/sparse/__init__.py b/pkgs/xformers/sparse/__init__.py new file mode 100644 index 0000000..7238f5b --- /dev/null +++ b/pkgs/xformers/sparse/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from .blocksparse_tensor import BlockSparseTensor # noqa: F401 +from .csr_tensor import SparseCSRTensor # noqa: F401 diff --git a/pkgs/xformers/sparse/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/sparse/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..aafcf3e Binary files /dev/null and b/pkgs/xformers/sparse/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/sparse/__pycache__/_csr_ops.cpython-310.pyc b/pkgs/xformers/sparse/__pycache__/_csr_ops.cpython-310.pyc new file mode 100644 index 0000000..821303b Binary files /dev/null and b/pkgs/xformers/sparse/__pycache__/_csr_ops.cpython-310.pyc differ diff --git a/pkgs/xformers/sparse/__pycache__/blocksparse_tensor.cpython-310.pyc b/pkgs/xformers/sparse/__pycache__/blocksparse_tensor.cpython-310.pyc new file mode 100644 index 0000000..cf692e5 Binary files /dev/null and b/pkgs/xformers/sparse/__pycache__/blocksparse_tensor.cpython-310.pyc differ diff --git a/pkgs/xformers/sparse/__pycache__/csr_tensor.cpython-310.pyc b/pkgs/xformers/sparse/__pycache__/csr_tensor.cpython-310.pyc new file mode 100644 index 0000000..b49ef12 Binary files /dev/null and b/pkgs/xformers/sparse/__pycache__/csr_tensor.cpython-310.pyc differ diff --git a/pkgs/xformers/sparse/__pycache__/utils.cpython-310.pyc b/pkgs/xformers/sparse/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..0e3f911 Binary files /dev/null and b/pkgs/xformers/sparse/__pycache__/utils.cpython-310.pyc differ diff --git a/pkgs/xformers/sparse/_csr_ops.py b/pkgs/xformers/sparse/_csr_ops.py new file mode 100644 index 0000000..3b3115e --- /dev/null +++ b/pkgs/xformers/sparse/_csr_ops.py @@ -0,0 +1,166 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
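# Illustrative sketch of reading back the ops.json written by save_json above;
# the file path is hypothetical (use whatever _create_output_filename("ops.json")
# produced for your run). Per-operator rows carry agg="opname", per-module rows agg="module".
import json

with open("profile_outputs/000010_ops.json") as f:
    rows = json.load(f)
slowest = sorted((r for r in rows if r["agg"] == "opname"),
                 key=lambda r: r["total_time_ms"], reverse=True)[:5]
for r in slowest:
    print(f"{r['op']:<40} {r['total_time_ms']:8.2f} ms  "
          f"{r['Tflops']:6.2f} TFlop/s  mem_bound={r['mem_bound']:.2f}")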
+ + +import torch + +from .utils import _csr_to_coo, _transpose_with_info + + +def _should_use_coo(a, sparsity): + if not a.is_cuda: + return False + B, M, K = a.shape + # amortize overhead of converting from csr to coo + if B < 32 and M < 4096: + return False + if sparsity > 0.995: + return False + if sparsity < 0.9: + return False + if K > 64: + return False + # let's be overly cautious here for now + return sparsity > 0.97 + + +def _should_use_csr_ge(a, sparsity): + if not a.is_cuda: + return False + return sparsity > 0.99 + + +def _sddmm_func(a, b, row_indices, row_offsets, column_indices): + sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1]) + if _should_use_coo(a, sparsity): + m = a.shape[-2] + n = b.shape[-2] + # converting from csr to coo has a constant overhead of ~150us + # so only dispatch to it for reasonably large problem sizes + ro, ci = _csr_to_coo(m, n, row_offsets, column_indices) + return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci) + elif _should_use_csr_ge(a, sparsity): + return torch.ops.xformers.csr_sddmm( + a, b, row_indices, row_offsets, column_indices + ) + return torch.ops.xformers.sddmm_sputnik( + a, b, row_indices, row_offsets, column_indices + ) + + +class _SparseSoftmax(torch.autograd.Function): + @staticmethod + def forward(ctx, m, n, row_indices, values, row_offsets, column_indices): + out = torch.ops.xformers.sparse_softmax_sputnik( + m, n, row_indices, values, row_offsets, column_indices + ) + # note: save out and not values, as an optimization step + ctx.save_for_backward(row_indices, out, row_offsets, column_indices) + ctx.size = (m, n) + return out + + @staticmethod + def backward(ctx, grad): + row_indices, out, row_offsets, column_indices = ctx.saved_tensors + m, n = ctx.size + + # gradients w.r.t. values + grad = grad.contiguous() + ga = torch.ops.xformers.sparse_softmax_backward_sputnik( + m, n, row_indices, out, grad, row_offsets, column_indices + ) + + return None, None, None, ga, None, None + + +class _sddmm(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b, row_indices, row_offsets, column_indices, _transp_info): + out = _sddmm_func(a, b, row_indices, row_offsets, column_indices) + + ctx.save_for_backward( + a, b, row_indices, row_offsets, column_indices, *_transp_info + ) + return out + + @staticmethod + def backward(ctx, grad): + ( + a, + b, + row_indices, + row_offsets, + column_indices, + *_transp_info, + ) = ctx.saved_tensors + m, n = a.shape[1], b.shape[1] + + # gradients w.r.t. 
values + grad = grad.contiguous() + a = a.contiguous() + b = b.contiguous() + + a_grad = torch.ops.xformers.spmm_sputnik( + b, row_indices, grad, row_offsets, column_indices, m + ) + + ( + row_indices_t, + grad_t, + row_offsets_t, + column_indices_t, + ) = _transpose_with_info(grad, _transp_info) + + b_grad = torch.ops.xformers.spmm_sputnik( + a, row_indices_t, grad_t, row_offsets_t, column_indices_t, n + ) + + return a_grad, b_grad, None, None, None, None + + +class _spmm(torch.autograd.Function): + @staticmethod + def forward( + ctx, b, row_indices, values, row_offsets, column_indices, m, _transp_info + ): + b = b.contiguous() + out = torch.ops.xformers.spmm_sputnik( + b, row_indices, values, row_offsets, column_indices, m + ) + + ctx.save_for_backward( + b, row_indices, values, row_offsets, column_indices, *_transp_info + ) + return out + + @staticmethod + def backward(ctx, grad): + ( + b, + row_indices, + values, + row_offsets, + column_indices, + *_transp_info, + ) = ctx.saved_tensors + k = b.shape[1] + + # gradients w.r.t. values + grad = grad.contiguous() + + grad_sparse = _sddmm_func(grad, b, row_indices, row_offsets, column_indices) + + ( + row_indices_t, + values_t, + row_offsets_t, + column_indices_t, + ) = _transpose_with_info(values, _transp_info) + + grad_dense = torch.ops.xformers.spmm_sputnik( + grad, row_indices_t, values_t, row_offsets_t, column_indices_t, k + ) + + return grad_dense, None, grad_sparse, None, None, None, None diff --git a/pkgs/xformers/sparse/blocksparse_tensor.py b/pkgs/xformers/sparse/blocksparse_tensor.py new file mode 100644 index 0000000..c70c448 --- /dev/null +++ b/pkgs/xformers/sparse/blocksparse_tensor.py @@ -0,0 +1,357 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
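# Worked numbers for the kernel-dispatch heuristics in _csr_ops._sddmm_func above
# (illustrative only; the shapes and nnz are made up). The estimated sparsity is
# 1 - nnz / (M * N); for CUDA inputs with K <= 64 and sparsity in (0.97, 0.995],
# _should_use_coo picks the COO kernel, otherwise the CSR/sputnik paths are used.
B, M, N, K = 64, 2048, 2048, 64
nnz = 41 * 2048                      # hypothetical number of stored entries
sparsity = 1 - nnz / (M * N)         # ~0.98
assert 0.97 < sparsity <= 0.995      # _should_use_coo would return True on a GPU tensor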
+ +import logging + +import torch + +from xformers import _is_triton_available +from xformers.ops import masked_matmul + +logger = logging.getLogger("xformers") + + +try: + if not _is_triton_available(): + raise ImportError("triton is not available") + from triton.ops.blocksparse import matmul as blocksparse_matmul + from triton.ops.blocksparse import softmax as blocksparse_softmax +except ImportError as e: + logger.warning( + "Triton is not available, some optimizations will not be enabled.\n" + + f"This is just a warning: {e}" + ) + blocksparse_matmul = None + blocksparse_softmax = None + + +def _can_use_triton(a): + if a.device.type == "cpu": + return False + + if blocksparse_matmul is None: + return False + + return True + + +def _spmm(b, layout, values): + N, nnz, _, block_size = values.shape + br = b.reshape( + b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3] + ) + # perform matmul on blocks + h, r, c = layout.nonzero(as_tuple=True) + temp = values @ br[:, h, c, :] + + linear_idx = h * (b.shape[2] // block_size) + r + out = torch.zeros( + N, + b.shape[1] * layout.shape[-2], + block_size, + b.shape[3], + dtype=b.dtype, + device=b.device, + ) + # now aggregate the results of the different blocks + out.index_add_(1, linear_idx.to(b.device), temp) + out = out.reshape(N, b.shape[1], -1, b.shape[3]) + return out + + +def _softmax(layout, values): + h, r, c = layout.nonzero(as_tuple=True) + norms = torch.logsumexp(values, dim=-1, keepdim=True) + linear_idx = h * layout.shape[1] + r + + out_t = torch.zeros( + norms.shape[0], + layout.shape[0] * layout.shape[1], + norms.shape[2], + norms.shape[3], + dtype=norms.dtype, + device=norms.device, + ) + max_val = norms.max() + out_t.index_add_( + 1, linear_idx.to(values.device), (norms - max_val).exp() + ).clamp_min_(1e-24).log_().add_(max_val) + out = torch.exp(values - out_t[:, linear_idx]) + return out + + +def _sddmm(a, b, layout): + block_size = a.shape[-2] // layout.shape[-2] + a = a.reshape( + a.shape[0], a.shape[1], a.shape[2] // block_size, block_size, a.shape[3] + ) + b = b.reshape( + b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3] + ) + + h, r, c = layout.nonzero(as_tuple=True) + + out = torch.einsum("nhik,nhjk->nhij", a[:, h, r, :, :], b[:, h, c, :, :]) + return out + + +class BlockSparseTensor(torch.Tensor): + @staticmethod + def __new__(cls, values, layout): + kwargs = {} + kwargs["device"] = values.device + kwargs["dtype"] = values.dtype + kwargs["layout"] = values.layout + kwargs["requires_grad"] = values.requires_grad + assert values.ndim == 4 + B, _, block_size, _ = values.shape + C, h, w = layout.shape + # TODO validate shape of layout vs values + shape = (B, C, block_size * h, block_size * w) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, values, layout): + assert values.shape[-2] == values.shape[-1] + assert ( + values.device == layout.device + ), "Both values and layout need to reside on the same device" + block_size = values.shape[-1] + # TODO: make this check conditioned on the use of Triton + assert block_size >= 16, "Minimum block size is 16, for now at least" + + # Pure blocksparse data + self.__values = values + self.__layout = layout + + # blocksparse operators for triton + if blocksparse_matmul: + self._initialize_triton_ops() + else: + self.__sparse_dot_sdd = None + self.__sparse_dot_dsd = None + self.__sparse_softmax = None + + def _initialize_triton_ops(self): + block_size = self.__values.shape[-1] + self.__sparse_dot_sdd = 
blocksparse_matmul( + self.__layout, + block_size, + "sdd", + trans_a=False, + trans_b=True, + device=self.__layout.device, + ) + self.__sparse_dot_dsd = blocksparse_matmul( + self.__layout, + block_size, + "dsd", + trans_a=False, + trans_b=False, + device=self.__layout.device, + ) + self.__sparse_softmax = blocksparse_softmax( + self.__layout, block_size, device=self.__layout.device + ) + + def __repr__(self): + return f"block_sparse_tensor(shape={self.shape}, values={self.__values})" + + def values(self): + return self.__values + + @classmethod + def _raw_wrap(cls, values, layout, sparse_dot_sdd, sparse_dot_dsd, sparse_softmax): + matrix = cls.__new__(cls, values, layout) + matrix.__values = values + matrix.__layout = layout + matrix.__sparse_dot_sdd = sparse_dot_sdd + matrix.__sparse_dot_dsd = sparse_dot_dsd + matrix.__sparse_softmax = sparse_softmax + return matrix + + @classmethod + def _wrap(cls, values, bmat): + matrix = cls.__new__(cls, values, bmat.__layout) + matrix.__values = values + matrix.__layout = bmat.__layout + matrix.__sparse_dot_sdd = bmat.__sparse_dot_sdd + matrix.__sparse_dot_dsd = bmat.__sparse_dot_dsd + matrix.__sparse_softmax = bmat.__sparse_softmax + return matrix + + @classmethod + def _bmm(cls, arg0, arg1): + if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor): + return NotImplemented + if _can_use_triton(arg1): + res = arg0.__sparse_dot_dsd(arg0.__values, arg1) + else: + res = _spmm(arg1, arg0.__layout, arg0.__values) + return res + + @classmethod + def _masked_matmul(cls, a, b, mask): + if not (type(a) == torch.Tensor and type(b) == torch.Tensor): + return NotImplemented + b = b.transpose(-2, -1) + assert b.is_contiguous() + if _can_use_triton(a): + res = mask.__sparse_dot_sdd(a, b) + else: + res = _sddmm(a, b, mask.__layout) + return cls._wrap(res, mask) + + @classmethod + def _softmax(cls, arg0, dim): + if not (dim == -1 or dim == 2): + return NotImplemented + if _can_use_triton(arg0): + res = arg0.__sparse_softmax(arg0.__values) + else: + res = _softmax(arg0.__layout, arg0.__values) + return cls._wrap(res, arg0) + + @classmethod + def _to(cls, arg0, device): + if isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + return cls( + arg0.__values.to(device=device), + arg0.__layout, + ) + + @classmethod + def _copy(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + assert arg0.shape == arg1.shape + av0, av1 = arg0.__values, arg1.__values + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__layout, arg1.__layout + av0.resize_as_(av1).copy_(av1) + out = cls(arg0.__values, arg0.__layout) + arg0.__sparse_dot_sdd = out.__sparse_dot_sdd + arg0.__sparse_dot_dsd = out.__sparse_dot_dsd + arg0.__sparse_softmax = out.__sparse_softmax + return arg0 + + @classmethod + def _equal(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + if arg0.shape != arg1.shape: + return False + if not torch.equal(arg0.__values, arg1.__values): + return False + if not torch.equal(arg0.__layout, arg1.__layout): + return False + return True + + @classmethod + def _to_dense(cls, arg0): + # out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device, requires_grad=arg0.requires_grad) + out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device) + values = arg0.__values + layout = arg0.__layout + block_size = values.shape[-1] + blocks_i = layout.shape[-2] + blocks_j = layout.shape[-1] + + out_r = out.reshape( + arg0.shape[0], 
arg0.shape[1], blocks_i, block_size, blocks_j, block_size + ) + + for idx, (h, i, j) in enumerate(zip(*layout.nonzero(as_tuple=True))): + out_r[:, h, i, :, j, :] = values[:, idx, :, :] + + return out + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in [ + torch.Tensor.bmm, + torch.bmm, + torch.Tensor.__matmul__, + torch.matmul, + torch.Tensor.matmul, + ]: + assert len(args) == 2 + return cls._bmm(args[0], args[1]) + + if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]: + return cls._softmax(args[0], kwargs["dim"]) + + if func == masked_matmul: + assert len(args) == 3 + return cls._masked_matmul(args[0], args[1], args[2]) + + if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]: + x = args[0] + values = x.__values.clone() + values = func(values, *args[1:], **kwargs) + return cls._wrap(values, x) + + if func == torch.Tensor.to: + # print(args, kwargs) + assert len(args) >= 2 + return cls._to(args[0], args[1]) + # return cls._to(args[0], kwargs["device"]) + + if func in [torch.Tensor.copy_]: + assert len(args) == 2 + return cls._copy(args[0], args[1]) + + if func in [torch.Tensor.equal, torch.equal]: + assert len(args) == 2 + return cls._equal(args[0], args[1]) + + if func == torch.Tensor.to_dense: + assert len(args) == 1 + return cls._to_dense(args[0]) + + if func == torch.Tensor.detach: + x = args[0] + values = x.__values.clone() + values = func(values, *args[1:], **kwargs) + return cls._wrap(values, x) + + if func == torch.Tensor.__deepcopy__: + x = args[0] + memo = args[1] + return cls._raw_wrap( + x.__values.__deepcopy__(memo), + x.__layout.__deepcopy__(memo), + # x.__sparse_dot_sdd.__deepcopy__(memo), + # x.__sparse_dot_dsd.__deepcopy__(memo), + # x.__sparse_softmax.__deepcopy__(memo), + x.__sparse_dot_sdd, + x.__sparse_dot_dsd, + x.__sparse_softmax, + ) + + if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]: + assert len(args) == 1 + assert len(kwargs) == 0 + x = args[0] + return cls._wrap(x.__values.grad, x) + + if func == torch.Tensor.requires_grad_: + func(args[0].__values) + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + # TODO: check this + if func in torch.overrides.get_default_nowrap_functions(): + return ret + return torch._tensor._convert(ret, cls) + + return NotImplemented + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + return NotImplemented diff --git a/pkgs/xformers/sparse/csr_tensor.py b/pkgs/xformers/sparse/csr_tensor.py new file mode 100644 index 0000000..a6897de --- /dev/null +++ b/pkgs/xformers/sparse/csr_tensor.py @@ -0,0 +1,438 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
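# A minimal block-sparse attention sketch built on the BlockSparseTensor wrapper
# above (illustrative only: head count, block layout and feature size are made up,
# and the pure-PyTorch fallback path is assumed when Triton is unavailable).
# masked_matmul computes q @ k^T restricted to the nonzero blocks of the mask;
# the softmax and the final matmul then stay block-sparse aware.
import torch
from xformers.ops import masked_matmul
from xformers.sparse import BlockSparseTensor

H, n_blocks, block = 4, 8, 16                         # heads, blocks per side, block size
layout = torch.tril(torch.ones(H, n_blocks, n_blocks)).bool()
values = torch.zeros(1, int(layout.sum()), block, block)
mask = BlockSparseTensor(values, layout)              # dense shape (1, 4, 128, 128)

q = torch.randn(1, H, n_blocks * block, 32)
k = torch.randn(1, H, n_blocks * block, 32)
v = torch.randn(1, H, n_blocks * block, 32)

att = masked_matmul(q, k.transpose(-2, -1), mask)     # block-sparse attention scores
att = att.softmax(dim=-1)                             # still a BlockSparseTensor
out = att @ v                                         # dense tensor of shape (1, 4, 128, 32)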
+ + +import torch + +from xformers.ops import masked_matmul +from xformers.sparse import _csr_ops +from xformers.sparse.utils import ( + _csr_to_coo, + _dense3d_to_sparse, + _diffsort, + _get_transpose_info, + _transpose_with_info, +) + + +class SparseCSRTensor(torch.Tensor): + @staticmethod + def __new__(cls, row_offsets, column_indices, values, shape): + kwargs = {} + kwargs["device"] = values.device + kwargs["dtype"] = values.dtype + kwargs["layout"] = values.layout + kwargs["requires_grad"] = values.requires_grad + assert len(shape) == 3 + assert torch.__version__ > (1, 10), "SparseCSRTensor requires PyTorch 1.11+" + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, row_offsets, column_indices, values, shape): + assert row_offsets.ndim == 1 + assert column_indices.ndim == 1 + assert values.ndim == 2 + + self.__row_offsets = row_offsets.contiguous() + self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype) + self.__column_indices = column_indices.contiguous() + self.__values = values.contiguous() + + self.__transp_info = _get_transpose_info( + self.shape[1], + self.shape[2], + self.__row_indices, + self.__row_offsets, + self.__column_indices, + ) + + def __repr__(self): + return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})" + + @classmethod + def from_dense(cls, matrix): + values, row_indices, row_offsets, column_indices = _dense3d_to_sparse( + matrix, matrix.device + ) + return cls(row_offsets, column_indices, values, matrix.shape) + + @classmethod + def from_sparse_coo(cls, arg0): + """ + assert arg0.is_sparse + x = arg0.coalesce() + rows, cols = x.indices().unbind(0) + vals = x.values() + _coo_to_csr() + """ + pass + + @classmethod + def _wrap( + cls, shape, values, row_indices, row_offsets, column_indices, _transp_info + ): + matrix = cls.__new__(cls, row_offsets, column_indices, values, shape) + matrix.__values = values + matrix.__row_indices = row_indices + matrix.__row_offsets = row_offsets + matrix.__column_indices = column_indices + matrix.__transp_info = _transp_info + return matrix + + def values(self): + return self.__values + + @property + def _csr_row_indices(self): + return self.__row_indices + + @property + def _csr_row_offsets(self): + return self.__row_offsets + + @property + def _csr_column_indices(self): + return self.__column_indices + + @property + def _csr_transp_info(self): + return self.__transp_info + + @classmethod + def _bmm(cls, arg0, arg1): + if not (isinstance(arg0, cls) and type(arg1) == torch.Tensor): + return NotImplemented + + assert arg0.ndim == 3 + assert arg1.ndim == 3 + + self = arg0 + b = arg1 + + _, m, n = self.shape + row_indices = self.__row_indices + values = self.__values + row_offsets = self.__row_offsets + column_indices = self.__column_indices + + out = _csr_ops._spmm.apply( + b, row_indices, values, row_offsets, column_indices, m, self.__transp_info + ) + return out + + @classmethod + def _softmax(cls, arg0, dim): + if not (dim == -1 or dim == 2): + return NotImplemented + + self = arg0 + _, m, n = self.shape + row_indices = self.__row_indices + values = self.__values + row_offsets = self.__row_offsets + column_indices = self.__column_indices + out = _csr_ops._SparseSoftmax.apply( + m, n, row_indices, values, row_offsets, column_indices + ) + return cls._wrap( + self.shape, + out, + row_indices, + row_offsets, + column_indices, + self.__transp_info, + ) + + @classmethod + def _transpose(cls, arg0, dim0, dim1): + # TODO: check if need to return this or not + if not (dim0 == 
1 or dim0 == -2): + return NotImplemented + if not (dim1 == 2 or dim1 == -1): + return NotImplemented + + B, m, n = arg0.shape + values = arg0.__values + + ( + output_row_indices, + output_values, + output_row_offsets, + output_column_indices, + ) = _transpose_with_info(values, arg0.__transp_info) + new_transp_info = _get_transpose_info( + n, m, output_row_indices, output_row_offsets, output_column_indices + ) + + return cls._wrap( + (B, n, m), + output_values, + output_row_indices, + output_row_offsets, + output_column_indices, + new_transp_info, + ) + + @classmethod + def _masked_matmul(cls, a, b, mask): + if not (type(a) == torch.Tensor and type(b) == torch.Tensor): + return NotImplemented + assert mask.shape[1] == a.shape[1] + assert mask.shape[2] == b.shape[2] + row_indices = mask.__row_indices + row_offsets = mask.__row_offsets + column_indices = mask.__column_indices + a = a.contiguous() + out = _csr_ops._sddmm.apply( + a, + b.transpose(-2, -1).contiguous(), + row_indices, + row_offsets, + column_indices, + mask.__transp_info, + ) + # TODO add bias here + return cls._wrap( + mask.shape, + out, + row_indices, + row_offsets, + column_indices, + mask.__transp_info, + ) + + @classmethod + def _to(cls, arg0, device): + if isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + return cls._wrap( + arg0.shape, + arg0.__values.to(device=device), + arg0.__row_indices.to(device=device), + arg0.__row_offsets.to(device=device), + arg0.__column_indices.to(device=device), + tuple(t.to(device=device) for t in arg0.__transp_info), + ) + + @classmethod + def _copy(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + assert arg0.shape == arg1.shape + av0, av1 = arg0.__values, arg1.__values + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__row_indices, arg1.__row_indices + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__row_offsets, arg1.__row_offsets + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__column_indices, arg1.__column_indices + av0.resize_as_(av1).copy_(av1) + for v0, v1 in zip(arg0.__transp_info, arg1.__transp_info): + v0.resize_as_(v1).copy_(v1) + return arg0 + + @classmethod + def _equal(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + if arg0.shape != arg1.shape: + return False + if not torch.equal(arg0.__values, arg1.__values): + return False + if not torch.equal(arg0.__row_offsets, arg1.__row_offsets): + return False + if not torch.equal(arg0.__column_indices, arg1.__column_indices): + return False + return True + + @classmethod + def _to_dense(cls, arg0): + _, m, n = arg0.shape + shape = arg0.shape + matrix = torch.zeros(shape, dtype=arg0.dtype, device=arg0.device) + row_offsets = arg0.__row_offsets.long() + column_indices = arg0.__column_indices.long() + row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices) + b_idxs = torch.arange(len(arg0.__values), device=arg0.device)[:, None] + matrix[b_idxs, row_coo, column_indices] = arg0.__values + return matrix + + @classmethod + def _binary_op(cls, func, arg0, arg1): + if not ( + isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float)) + ): + return NotImplemented + v0, v1 = arg0, arg1 + if isinstance(arg0, cls): + v0 = arg0.__values + if isinstance(arg1, cls): + v1 = arg1.__values + # assert arg0.shape == arg1.shape + if isinstance(arg0, cls) and isinstance(arg1, cls): + msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)" + if 
not arg0.__row_offsets.shape == arg1.__row_offsets.shape: + raise NotImplementedError(msg) + if not arg0.__column_indices.shape == arg1.__column_indices.shape: + raise NotImplementedError(msg) + if not arg0.__values.shape == arg1.__values.shape: + raise NotImplementedError(msg) + # TODO this is not always true, but is a fast approximation for now + if arg0.__row_offsets is not arg1.__row_offsets: + raise NotImplementedError(msg) + if arg0.__column_indices is not arg1.__column_indices: + raise NotImplementedError(msg) + out = func(v0, v1) + return cls._wrap( + arg0.shape, + out, + arg0.__row_indices, + arg0.__row_offsets, + arg0.__column_indices, + arg0.__transp_info, + ) + + @classmethod + def _binary_op_slow(cls, func, arg0, arg1): + # assert arg0.shape == arg1.shape + v0, v1 = arg0, arg1 + if isinstance(arg0, cls): + v0 = arg0.to_dense() + if isinstance(arg1, cls): + v1 = arg1.to_dense() + out = func(v0, v1) + return cls.from_dense(out) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in [ + torch.Tensor.bmm, + torch.bmm, + torch.Tensor.__matmul__, + torch.matmul, + torch.Tensor.matmul, + ]: + assert len(args) == 2 + return cls._bmm(args[0], args[1]) + + if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]: + return cls._softmax(args[0], kwargs["dim"]) + + if func in [torch.Tensor.transpose, torch.transpose]: + assert len(kwargs) == 0 + return cls._transpose(args[0], args[1], args[2]) + + if func == masked_matmul: + assert len(args) == 3 + return cls._masked_matmul(args[0], args[1], args[2]) + + if func in [ + torch.Tensor.add, + torch.add, + torch.Tensor.__add__, + ]: + assert len(args) == 2 + if not (isinstance(args[0], cls) and isinstance(args[1], cls)): + raise NotImplementedError( + f"{func} with {type(args[0])} and {type(args[1])} not implemented" + ) + return cls._binary_op(func, args[0], args[1]) + + if func in [ + torch.Tensor.mul, + torch.mul, + torch.Tensor.__mul__, + ]: + assert len(args) == 2 + return cls._binary_op(func, args[0], args[1]) + + if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]: + assert len(args) == 2 + return cls._binary_op_slow(func, args[0], args[1]) + + if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]: + x = args[0] + values = x.__values.clone() + values = func(values, *args[1:], **kwargs) + return cls._wrap( + x.shape, + values, + x.__row_indices, + x.__row_offsets, + x.__column_indices, + x.__transp_info, + ) + + if func == torch.Tensor.to: + # print(args, kwargs) + assert len(args) >= 2 + return cls._to(args[0], args[1]) + # return cls._to(args[0], kwargs["device"]) + + if func in [torch.Tensor.copy_]: + assert len(args) == 2 + return cls._copy(args[0], args[1]) + + if func in [torch.Tensor.equal, torch.equal]: + assert len(args) == 2 + return cls._equal(args[0], args[1]) + + if func == torch.Tensor.to_dense: + assert len(args) == 1 + return cls._to_dense(args[0]) + + if func == torch.Tensor.detach: + x = args[0] + return cls._wrap( + x.shape, + x.__values.detach(), + x.__row_indices, + x.__row_offsets, + x.__column_indices, + x.__transp_info, + ) + + if func == torch.Tensor.__deepcopy__: + x = args[0] + memo = args[1] + return cls._wrap( + x.shape, + x.__values.__deepcopy__(memo), + x.__row_indices.__deepcopy__(memo), + x.__row_offsets.__deepcopy__(memo), + x.__column_indices.__deepcopy__(memo), + tuple(v.__deepcopy__(memo) for v in x.__transp_info), + ) + + if func in [torch.Tensor.grad.__get__, 
torch.Tensor._grad.__get__]: + assert len(args) == 1 + assert len(kwargs) == 0 + x = args[0] + return cls._wrap( + x.shape, + x.__values.grad, + x.__row_indices, + x.__row_offsets, + x.__column_indices, + x.__transp_info, + ) + + if func == torch.Tensor.requires_grad_: + func(args[0].__values) + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + # TODO: check this + if func in torch.overrides.get_default_nowrap_functions(): + return ret + return torch._tensor._convert(ret, cls) + + return NotImplemented + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + return NotImplemented diff --git a/pkgs/xformers/sparse/utils.py b/pkgs/xformers/sparse/utils.py new file mode 100644 index 0000000..0031e4b --- /dev/null +++ b/pkgs/xformers/sparse/utils.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + + +def _coo_to_csr(m, n, row_indices, column_indices): + # assumes coalesced coo + row_offsets = row_indices.bincount(minlength=n).cumsum(0, dtype=row_indices.dtype) + row_offsets = torch.nn.functional.pad(row_offsets, (1, 0)) + return row_offsets, column_indices + + +def _csr_to_coo(m, n, row_offsets, column_indices): + # convert from compressed rows to uncompressed + indices = torch.arange(m, dtype=row_offsets.dtype, device=row_offsets.device) + row_sizes = torch.diff(row_offsets) + row_coo = torch.repeat_interleave(indices, row_sizes.long()) + return row_coo, column_indices + + +def _diffsort(a): + return torch.argsort(torch.diff(a), dim=0, descending=True) + + +def _get_transpose_info(m, n, row_indices, row_offsets, column_indices): + # strategy: + # - uncompress the rows to have data in COO format + # - get permutation for stable sort of the columns to get the rows for the transposed matrix + # - compress the new rows and return the permutation to be applied on the values + + # convert from compressed rows to uncompressed + row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices) + + # get the permutation for the stable sort + row_offsets_t, perm = column_indices.sort(dim=0, stable=True) + column_indices_t = row_coo[perm] + + row_offsets_t, _ = _coo_to_csr(m, n, row_offsets_t, column_indices) + row_indices_t = _diffsort(row_offsets_t).int() + + return row_indices_t, row_offsets_t, column_indices_t, perm + + +def _transpose_with_info(values, _transpose_info): + row_indices_t, row_offsets_t, column_indices_t, perm = _transpose_info + values_t = values[:, perm] + return row_indices_t, values_t, row_offsets_t, column_indices_t + + +def _transpose(m, n, row_indices, values, row_offsets, column_indices): + _transpose_info = _get_transpose_info( + m, n, row_indices, row_offsets, column_indices + ) + return _transpose_with_info(values, _transpose_info) + + +def _nonzero_mask_to_sparse_csr_indices(mask, device): + """Converts dense 2d matrix to a csr sparse matrix.""" + + assert len(mask.shape) == 2 + index_dtype = torch.int32 + + # Calculate the offset of each row. + row_offsets = mask.sum(dim=-1, dtype=index_dtype).cumsum(dim=-1, dtype=index_dtype) + row_offsets = torch.nn.functional.pad(row_offsets, (1, 0)) + + # Create the row indices and sort them. + row_indices = _diffsort(row_offsets).to(index_dtype) + + # Extract the column indices for the nonzero values. 
+ column_indices = torch.where(mask)[1].to(index_dtype).contiguous() + + row_indices = row_indices.to(device) + row_offsets = row_offsets.to(device) + column_indices = column_indices.to(device) + return row_indices, row_offsets, column_indices + + +def _dense_to_sparse(matrix, device): + """Converts dense 2d matrix to a csr sparse matrix.""" + + assert len(matrix.shape) == 2 + value_dtype = torch.float32 + + # Extract the nonzero values. + mask = matrix != 0 + values = matrix[mask].to(dtype=value_dtype, device=device) + + row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices( + mask, device + ) + return values, row_indices, row_offsets, column_indices + + +def _round_nnz(mask, divisible_by=4): + nonzero = torch.where(mask) + nnz = nonzero[0].shape[0] + nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero) + nm = torch.zeros_like(mask) + nm[nonzero] = True + return nm + + +def _dense3d_to_sparse(matrix, device): + assert len(matrix.shape) == 3 + mask = matrix != 0 + if not torch.all(mask == mask[0]): + raise ValueError("Expected the same sparsity pattern over the batch dimension") + + # for now, our kernels assume that we have the number of + # nnz to be divisible by 4 + mask = _round_nnz(mask[0], divisible_by=4) + mask = mask[None].expand(matrix.shape) + + values = matrix[mask].reshape(matrix.shape[0], -1).to(device) + row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices( + mask[0], device + ) + return values, row_indices, row_offsets, column_indices diff --git a/pkgs/xformers/test.py b/pkgs/xformers/test.py new file mode 100644 index 0000000..6677db0 --- /dev/null +++ b/pkgs/xformers/test.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pkgs/xformers/triton/__init__.py b/pkgs/xformers/triton/__init__.py new file mode 100644 index 0000000..4442204 --- /dev/null +++ b/pkgs/xformers/triton/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
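# A small worked example for the CSR helpers above (illustrative only). For a
# single 3 x 4 matrix with nonzeros at (0, 1), (0, 3) and (2, 0), the CSR
# representation is row_offsets = [0, 2, 2, 3], column_indices = [1, 3, 0];
# _csr_to_coo expands the offsets back into one row index per nonzero.
import torch
from xformers.sparse.utils import _csr_to_coo

row_offsets = torch.tensor([0, 2, 2, 3], dtype=torch.int32)
column_indices = torch.tensor([1, 3, 0], dtype=torch.int32)
row_coo, cols = _csr_to_coo(3, 4, row_offsets, column_indices)
assert row_coo.tolist() == [0, 0, 2]
assert cols.tolist() == [1, 3, 0]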
+ + +import torch + +_triton_available = torch.cuda.is_available() +if _triton_available: + try: + from .dropout import FusedDropoutBias, dropout # noqa + from .fused_linear_layer import FusedLinear # noqa + from .layer_norm import FusedLayerNorm, layer_norm # noqa + from .softmax import log_softmax, softmax # noqa + + __all__ = [ + "dropout", + "softmax", + "log_softmax", + "FusedDropoutBias", + "FusedLinear", + "FusedLayerNorm", + "layer_norm", + ] + except ImportError: + __all__ = [] diff --git a/pkgs/xformers/triton/__pycache__/__init__.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..7397e0b Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/__init__.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/dropout.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/dropout.cpython-310.pyc new file mode 100644 index 0000000..8c851c4 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/dropout.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/fused_linear_layer.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/fused_linear_layer.cpython-310.pyc new file mode 100644 index 0000000..56093a8 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/fused_linear_layer.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_activations.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_activations.cpython-310.pyc new file mode 100644 index 0000000..d9f2c11 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_activations.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_dropout.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_dropout.cpython-310.pyc new file mode 100644 index 0000000..6efc564 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_dropout.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_fused_matmul_bw.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_fused_matmul_bw.cpython-310.pyc new file mode 100644 index 0000000..7e7f176 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_fused_matmul_bw.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_fused_matmul_fw.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_fused_matmul_fw.cpython-310.pyc new file mode 100644 index 0000000..fbc1b95 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_fused_matmul_fw.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_layer_norm.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_layer_norm.cpython-310.pyc new file mode 100644 index 0000000..8441eb1 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_layer_norm.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_softmax.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_softmax.cpython-310.pyc new file mode 100644 index 0000000..f8d0d01 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_softmax.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/k_sum.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/k_sum.cpython-310.pyc new file mode 100644 index 0000000..ef8a920 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/k_sum.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/layer_norm.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/layer_norm.cpython-310.pyc new file mode 100644 index 0000000..fd6cb50 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/layer_norm.cpython-310.pyc 
differ diff --git a/pkgs/xformers/triton/__pycache__/softmax.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/softmax.cpython-310.pyc new file mode 100644 index 0000000..18176e5 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/softmax.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/sum_strided.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/sum_strided.cpython-310.pyc new file mode 100644 index 0000000..45d9c76 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/sum_strided.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/utils.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..d198331 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/utils.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/__pycache__/vararg_kernel.cpython-310.pyc b/pkgs/xformers/triton/__pycache__/vararg_kernel.cpython-310.pyc new file mode 100644 index 0000000..8f104c2 Binary files /dev/null and b/pkgs/xformers/triton/__pycache__/vararg_kernel.cpython-310.pyc differ diff --git a/pkgs/xformers/triton/dropout.py b/pkgs/xformers/triton/dropout.py new file mode 100644 index 0000000..d36a056 --- /dev/null +++ b/pkgs/xformers/triton/dropout.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# CREDITS: This is heavily inspired by the Triton dropout tutorial +# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py + +from typing import Optional + +import torch +import triton +from torch.cuda.amp import custom_bwd, custom_fwd + +from xformers.components.activations import Activation, build_activation +from xformers.triton.k_activations import get_triton_activation_index +from xformers.triton.k_dropout import k_dropout_bw, k_dropout_fw + +BLOCK_M = 32 +BLOCK_N = 64 # NOTE: This should ideally be GPU dependent, big impact on perf + + +# Helper to handle the SPMD launch grid and error cases +class _dropout(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, x, p, bias, activation, trainable_bias): + # Soft-flatten an hypothetical 3rd dimension + x_ = x.reshape(-1, x.shape[-1]).contiguous() + y = torch.empty_like(x_) + M, N = x_.shape + + assert bias is None or (bias.dtype == x.dtype and bias.shape[0] == N) + assert p > 0.0 + + def grid(meta): + return ( + triton.cdiv(M, meta["BLOCK_M"]), + triton.cdiv(N, meta["BLOCK_N"]), + ) + + N_BLOCK_N = triton.cdiv(N, BLOCK_N) + + # Generate one seed per sample + # seed max is int32 max for positive numbers: 2**16 + seeds = torch.randint(65536, (N_BLOCK_N,), device=x.device, dtype=torch.int32) + + # fmt: off + bias_ptr = bias if bias is not None else x_ # Possibly not being used + + k_dropout_fw[grid]( + y, x_, + bias_ptr, + seeds, + y.stride(0), + M, N, + p, + x.dtype == torch.float16, + USE_BIAS=bias is not None, + ACTIVATION=activation, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + # fmt: on + + if activation is not None: + ctx.save_for_backward(seeds, bias, x) + else: + ctx.save_for_backward(seeds, bias, None) + + ctx.trainable_bias = bias is not None and trainable_bias + ctx.activation = activation + ctx.p = p + + return y.reshape_as(x) + + @staticmethod + @custom_bwd + def backward( + ctx, grad_out + ): # pragma: no cover # This is covered, but called from C++ and not tracked + 
(seeds, bias, inputs) = ctx.saved_tensors + + # Soft-flatten an hypothetical 3rd dimension + grad_out_ = grad_out.reshape(-1, grad_out.shape[-1]).contiguous() + grad_in = torch.empty_like(grad_out_) + + M, N = grad_out_.shape + + # Optional inputs to compute the activation contribution to the gradient + assert inputs is not None or ctx.activation is None + + if inputs is None: + inputs = grad_out_ + elif inputs.ndim > 2: + inputs = inputs.reshape(-1, N) + + # We split the problem in tiles: + # - over M there will be a follow up reduction + # - over N we compromise in between trying to use as much memory paralellism as possible, + # (fill in the warps, there are 32 threads per warps, and 4 warps default), and not being too + # big because of register spilling + N_BLOCKS_M = triton.cdiv(M, BLOCK_M) + + if ctx.trainable_bias: + grad_bias = torch.empty( + ( + N_BLOCKS_M, + N, + ), + device=grad_in.device, + dtype=grad_in.dtype, + ) + + else: + grad_bias = grad_in # will not be used + + def grid(meta): + # NOTE: We use Triton Philox random number generator, which optimally generates 4 blocks for + # a given seed and offsets. "BLOCK_M" here describes the size of one of these blocks + # but we need to take this factor of 4 into account when scheduling all the kernels + return ( + N_BLOCKS_M, + triton.cdiv(N, meta["BLOCK_N"]), + ) + + # fmt: off + k_dropout_bw[grid]( + grad_in, grad_bias, grad_out_, + inputs, bias if bias is not None else inputs, + seeds, + grad_out_.stride(0), inputs.stride(0), + M, N, + ctx.p, + grad_in.dtype == torch.float16, + USE_BIAS=bias is not None, + ACTIVATION=ctx.activation, + TRAINABLE_BIAS=ctx.trainable_bias, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + # fmt: on + + return ( + grad_in.reshape_as(grad_out), + None, + torch.sum(grad_bias, dim=0) if ctx.trainable_bias else None, + None, + None, + None, + ) + + +def dropout( + x: torch.Tensor, + p: float, + bias: Optional[torch.Tensor] = None, + activation: Optional[Activation] = None, +): + """ + Apply dropout on the input tensor. + Optionally add a bias, the computation will be fused. 
+ """ + + assert p <= 1.0 and p >= 0.0 + + if p == 1.0: + return torch.zeros_like(x) + + # Micro optim, skip dropout + if p == 0.0: + x = x + bias if bias is not None else x + if activation is not None: + activation_fn = build_activation(activation) + return activation_fn(x) + return x + + # The normal triton enabled codepath + activation_index = get_triton_activation_index(activation) + return _dropout.apply( + x, + float(p), + bias, + activation_index, + bias is not None and bias.requires_grad, + ) + + +class FusedDropoutBias(torch.nn.Module): + """ + A layer which fuses the computation of Dropout(Activation(x)) + in a single GPU kernel + """ + + def __init__( + self, + p: float, + bias_shape: Optional[int], + activation: Optional[Activation] = None, + ) -> None: + super().__init__() + + self.p = float(p) + + assert ( + self.p < 1.0 + ), f"We don't want to drop all the values, most probably p={p} is not properly set" + + self.activation_type = activation + self.bias = ( + torch.zeros(bias_shape, requires_grad=True) + if bias_shape is not None + else None + ) + + self.activation = get_triton_activation_index(self.activation_type) + self.activation_pytorch = build_activation(self.activation_type) + + def init_weights(self, *args, **kwargs): + with torch.no_grad(): + if self.bias is not None: + self.bias.fill_(0.0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Convenience, catch a possible type or device mismatch + if self.bias is not None: + self.bias = self.bias.to(dtype=x.dtype, device=x.device) # type: ignore + + # Train/inference + p = self.p if self.training else 0.0 + + # This kernel is slower than pytorch for small buffers, bypassing it in that case + perf_check = x.shape[-1] > 512 + + # Catch a non-cuda setup, fallback to pytorch + if not x.is_cuda or not perf_check or p == 0.0: + x = x + self.bias if self.bias is not None else x + x = self.activation_pytorch(x) + return torch.nn.functional.dropout(x, p) if p > 0.0 else x + + # The normal, Triton-backed path + return _dropout.apply(x, p, self.bias, self.activation, True) diff --git a/pkgs/xformers/triton/fused_linear_layer.py b/pkgs/xformers/triton/fused_linear_layer.py new file mode 100644 index 0000000..998cdbd --- /dev/null +++ b/pkgs/xformers/triton/fused_linear_layer.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional + +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd + +from xformers.components.activations import Activation +from xformers.triton.k_activations import get_triton_activation_index +from xformers.triton.k_fused_matmul_bw import fused_matmul_backward +from xformers.triton.k_fused_matmul_fw import fused_matmul + + +class _fused_linear_triton(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx, + x, + weight, + bias, + activation, + trainable_weight, + trainable_bias, + save_activation_inputs, + ): + + # Kick the fused Triton kernel, handling bias and activation in one go + y, activation_inputs = fused_matmul( + x, weight, bias, activation, save_activation_inputs + ) + + ctx.activation = activation + ctx.trainable_weight = trainable_weight + ctx.trainable_bias = trainable_bias + + # Micro-optimization: saving these is not always needed (?) 
+ if x.requires_grad or ctx.trainable_weight or ctx.trainable_bias: + ctx.save_for_backward(weight, activation_inputs, x) + + return y + + @staticmethod + @custom_bwd + def backward( + ctx: Any, grad_out: torch.Tensor + ) -> Any: # pragma: no cover # this is covered, but called directly from C++ + """ + Compute the derivative with respect to x, other tensors were not trainable inputs. + """ + (weight, activation_inputs, x) = ctx.saved_tensors + + grad_input, grad_weight, grad_bias = fused_matmul_backward( + grad_out=grad_out, + inputs=x, + act_in=activation_inputs, + weight=weight, + trainable_weight=ctx.trainable_weight, + trainable_bias=ctx.trainable_bias, + activation_grad=ctx.activation, + ) + + return grad_input, grad_weight, grad_bias, None, None, None, None, None, None + + +class FusedLinear(nn.Module): + """ + Handle a linear transform, like torch.nn.Linear_, and a given activation, in a single kernel. + The whole transform: is :math:`y = activation(xA^T + b)`. + + This is typically significantly faster than PyTorch while using fp16 and non-sigmoid activations, + as of September 2021. + + .. _torch.nn.Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + activation: Optional[Activation] = None, + **_, + ): + super().__init__() + self.weight = nn.Parameter( + torch.empty(out_features, in_features), requires_grad=True + ) + self.bias = ( + nn.Parameter(torch.empty(out_features), requires_grad=True) + if bias + else None + ) + + self._activation_index = get_triton_activation_index(activation) + self.reset_parameters() + + def reset_parameters(self) -> None: + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + torch.nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, x): + return _fused_linear_triton.apply( + x, + self.weight, + self.bias, + self._activation_index, + self.weight.requires_grad, + self.bias.requires_grad if self.bias is not None else False, + self.training and x.requires_grad and self._activation_index > 0, + ) diff --git a/pkgs/xformers/triton/k_activations.py b/pkgs/xformers/triton/k_activations.py new file mode 100644 index 0000000..e60b6ed --- /dev/null +++ b/pkgs/xformers/triton/k_activations.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional + +import triton +import triton.language as tl + +from xformers.components import Activation + +_kAlpha = math.sqrt(2.0 / math.pi) + + +def get_triton_activation_index(activation: Optional[Activation]) -> int: + return ( + { + Activation.ReLU: 1, + Activation.LeakyReLU: 2, + Activation.GeLU: 3, + Activation.SquaredReLU: 4, + Activation.SmeLU: 5, + Activation.StarReLU: 6, + }[activation] + if activation is not None + else 0 + ) + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def cosh(x): + exp_x = tl.exp(x) + return (exp_x + 1.0 / exp_x) * 0.5 + + +# a Triton implementation of the most used activations +# See for instance http://arxiv.org/abs/1606.08415 for an overview + +# ReLU +@triton.jit +def relu(x): + """ + ReLU_ activation function + + .. 
_ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html + """ + return tl.where(x >= 0, x, 0.0) + + +@triton.jit +def relu_grad(x): + # ReLU is different from other activations + # in that it does not require the input to retrospectively compute its gradient + # here the input is the downstream gradient, and we return the upstream gradient directly + return tl.where(x >= 0, 1.0, 0.0) + + +@triton.jit +def squared_relu(x): + """ + Squared ReLU activation, as proposed in the Primer_ paper. + + .. _Primer: https://arxiv.org/abs/2109.08668 + """ + x_sq = x * x + return tl.where(x > 0.0, x_sq, 0.0) + + +@triton.jit +def squared_relu_grad(x): + return tl.where(x >= 0.0, 2 * x, 0.0) + + +@triton.jit +def star_relu(x): + """ + Star ReLU activation, as proposed in the "MetaFormer Baselines for Vision"_ paper. + + .. _ "MetaFormer Baselines for Vision": https://arxiv.org/pdf/2210.13452.pdf + """ + x_sq = x * x + return 0.8944 * tl.where(x > 0.0, x_sq, 0.0) - 0.4472 + + +@triton.jit +def star_relu_grad(x): + return tl.where(x >= 0.0, 1.7888 * x, 0.0) + + +# Leaky ReLU +@triton.jit +def leaky_relu(x): + """ + LeakyReLU_ activation + + .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html + """ + return tl.where(x >= 0.0, x, 0.01 * x) + + +@triton.jit +def leaky_relu_grad(x): + return tl.where(x >= 0.0, 1.0, 0.01) + + +@triton.jit +def gelu(x): + """ + GeLU_ activation - Gaussian error linear unit + + .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf + """ + return 0.5 * x * (1 + tanh(_kAlpha * (x + 0.044715 * x * x * x))) + + +@triton.jit +def gelu_grad(x): + # CREDITS: Fast implementation proposed in + # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_bias_gelu.py#L30 + tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x)) + return 0.5 * x * ( + (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x) + ) + 0.5 * (1 + tanh_out) + + +@triton.jit +def smelu(x): + """ + SmeLU_ activation - Smooth ReLU with beta=2.0 + + .. _SmeLU: https://arxiv.org/pdf/2202.06499.pdf + """ + beta = 2.0 + + relu = tl.where(x >= beta, x, 0.0) + return tl.where(tl.abs(x) <= beta, (x + beta) * (x + beta) / (4.0 * beta), relu) + + +@triton.jit +def smelu_grad(x): + beta = 2.0 + + relu_grad = tl.where(x >= beta, 1.0, 0.0) + return tl.where(tl.abs(x) <= beta, (beta + x) / (2.0 * beta), relu_grad) diff --git a/pkgs/xformers/triton/k_dropout.py b/pkgs/xformers/triton/k_dropout.py new file mode 100644 index 0000000..95f1c4e --- /dev/null +++ b/pkgs/xformers/triton/k_dropout.py @@ -0,0 +1,211 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
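# The fused kernels select an activation through the integer codes returned by
# get_triton_activation_index above (e.g. Activation.GeLU -> 3). As a host-side
# sanity check (pure PyTorch, an illustration rather than part of this diff),
# the Triton gelu() below uses the standard tanh approximation:
import math

import torch

def gelu_ref(x: torch.Tensor) -> torch.Tensor:
    # Same expression as the @triton.jit gelu kernel:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    k_alpha = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + torch.tanh(k_alpha * (x + 0.044715 * x * x * x)))

x = torch.randn(1024)
# approximate="tanh" requires PyTorch >= 1.12
torch.testing.assert_close(
    gelu_ref(x), torch.nn.functional.gelu(x, approximate="tanh"), rtol=1e-4, atol=1e-4
)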
+ + +# CREDITS: This is heavily inspired by the Triton dropout tutorial +# https://raw.githubusercontent.com/openai/triton/master/python/tutorials/04-low-memory-dropout.py + +import triton +import triton.language as tl + +from xformers.triton.k_activations import ( + gelu, + gelu_grad, + leaky_relu, + leaky_relu_grad, + relu, + relu_grad, + smelu, + smelu_grad, + squared_relu, + squared_relu_grad, +) + +_configs = [ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), +] + + +# fmt: off +@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]}) +@triton.autotune( + configs=_configs, + key=["M", "N", "is_fp16"], +) +@triton.jit +def k_dropout_fw( + Y, X, BIAS, SEEDS, + stride, + M, N, + p: tl.constexpr, + is_fp16: tl.constexpr, # autotune + ACTIVATION: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + SIZE_RAND_BLOCK: tl.constexpr, + USE_BIAS: tl.constexpr, +): + """ + Apply dropout on an input tensor + Y : Output (M, N) + X : Input (M, N) + BIAS (N,) + SEEDS (M,) + p : dropout probability + """ + # fmt: on + + row_id = tl.program_id(axis=0) + rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) + + col_id = tl.program_id(axis=1) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # pointers starting point + x_ptrs = X + rows[:, None] * stride + cols[None, :] + y_ptrs = Y + rows[:, None] * stride + cols[None, :] + + # good to go, start the layer computations + col_mask = cols[None, :] < N + p_scale = 1. / (1. - p) + if USE_BIAS: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=cols[None, :] < N, other=0.) + else: + bias = x_ptrs # will not be used + + block_mask = (rows[:, None] < M) & col_mask + x = tl.load(x_ptrs, mask=block_mask, other=0.0) + + # optionally apply a fused bias + if USE_BIAS: + x += bias + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION == 1: + x = relu(x) + elif ACTIVATION == 2: + x = leaky_relu(x) + elif ACTIVATION == 3: + x = gelu(x) + elif ACTIVATION == 4: + x = squared_relu(x) + elif ACTIVATION == 5: + x = smelu(x) + + # get the random keep mask + rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + seed_int = tl.load(SEEDS + col_id) + r = tl.rand(seed_int, rand_offsets) + keep_mask = r > p + + # prune and normalize in one go + keep = tl.view(keep_mask, x.shape) + output = tl.where(keep, (x * p_scale).to(x.dtype), 0.) 
+ + tl.store(y_ptrs, output, mask=block_mask) # output + + +# fmt: off +@triton.heuristics({"SIZE_RAND_BLOCK": lambda args: args["BLOCK_N"] * args["BLOCK_M"]}) +@triton.autotune( + configs=_configs, + key=["M", "N", "is_fp16"], +) +@triton.jit +def k_dropout_bw( + GRAD_IN, GRAD_BIAS, GRAD_OUT, + INPUTS, BIAS, SEEDS, + stride_grad, stride_inputs, + M, N, + p: tl.constexpr, + is_fp16: tl.constexpr, # autotune + ACTIVATION: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, # heuristics + BLOCK_N: tl.constexpr, + SIZE_RAND_BLOCK: tl.constexpr, + TRAINABLE_BIAS: tl.constexpr, + USE_BIAS: tl.constexpr, +): + """ + Apply dropout on an input tensor + GRAD_OUT (M, N) + GRAD_BIAS (N,) + GRAD_IN (M, N) + BIAS (N,) + SEEDS (N,) + p : dropout probability + """ + # fmt: on + + row_id = tl.program_id(axis=0) + rows = row_id * BLOCK_M + tl.arange(0, BLOCK_M) + + col_id = tl.program_id(axis=1) + cols = col_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # pointers starting point + grad_out_ptrs = GRAD_OUT + rows[:, None] * stride_grad + cols[None, :] + grad_in_ptrs = GRAD_IN + rows[:, None] * stride_grad + cols[None, :] + input_ptrs = INPUTS + rows[:, None] * stride_inputs + cols[None, :] + + # now go over the tiles + grad_bias = tl.zeros((BLOCK_N,), dtype=tl.float32) + col_mask = cols[None, :] < N + p_scale = 1. / (1. - p) + + if USE_BIAS: + b_ptrs = BIAS + cols[None, :] + bias = tl.load(b_ptrs, mask=col_mask, other=0.) + + block_mask = (rows[:, None] < M) & col_mask + grad_out = tl.load(grad_out_ptrs, mask=block_mask, other=0.) + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION: + inputs = tl.load(input_ptrs, mask=block_mask, other=0.) + + # optionally apply a fused bias + if USE_BIAS: + inputs += bias + + if ACTIVATION == 1: + act_grad = relu_grad(inputs) + elif ACTIVATION == 2: + act_grad = leaky_relu_grad(inputs) + elif ACTIVATION == 3: + act_grad = gelu_grad(inputs) + elif ACTIVATION == 4: + act_grad = squared_relu_grad(inputs) + elif ACTIVATION == 5: + act_grad = smelu_grad(inputs) + + grad_out *= act_grad + + # randomly prune (and scale) the resulting buffer, possibly a no-op + # note that even if we did not save the mask from the FW pass, it is generated + # from the same seeds, so the same drop mask is applied here + rand_offsets = tl.arange(0, SIZE_RAND_BLOCK) + seed_int = tl.load(SEEDS + col_id) + r = tl.rand(seed_int, rand_offsets) + r = tl.view(r, grad_out.shape) + output = tl.where(r > p, (grad_out * p_scale).to(grad_out.dtype), 0.) + + # write-back + tl.store(grad_in_ptrs, output, mask=block_mask) + + # optionally accumulate the bias gradient + if TRAINABLE_BIAS: + grad_bias += tl.sum(output, axis=0) + + if TRAINABLE_BIAS: + grad_bias_ptr = GRAD_BIAS + row_id * N + cols + tl.store(grad_bias_ptr, grad_bias, mask=cols < N) diff --git a/pkgs/xformers/triton/k_fused_matmul_bw.py b/pkgs/xformers/triton/k_fused_matmul_bw.py new file mode 100644 index 0000000..dd1776f --- /dev/null +++ b/pkgs/xformers/triton/k_fused_matmul_bw.py @@ -0,0 +1,161 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
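# k_dropout_fw stores one int32 seed per column block and k_dropout_bw rebuilds the
# exact same keep-mask from those seeds, so the mask tensor itself is never saved.
# A minimal pure-PyTorch sketch of that property (plain torch RNG standing in for
# Triton's Philox generator, an assumption for illustration):
import torch

def make_keep_mask(seed: int, shape, p: float) -> torch.Tensor:
    g = torch.Generator().manual_seed(seed)
    return torch.rand(shape, generator=g) > p

seed, p = 12345, 0.1
fw_mask = make_keep_mask(seed, (32, 64), p)   # forward pass
bw_mask = make_keep_mask(seed, (32, 64), p)   # regenerated in the backward pass
assert torch.equal(fw_mask, bw_mask)          # identical, only the seed was kept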
+ + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from xformers.triton.k_activations import ( + gelu_grad, + leaky_relu_grad, + relu_grad, + smelu_grad, + squared_relu_grad, + star_relu_grad, +) + + +# fmt: off +@triton.autotune( + configs=[ + triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2), + triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=2), + triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=4), + triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=4), + ], + key=["N"], +) +@triton.heuristics({ + 'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0, +}) +@triton.jit +def kernel_bw( + # Pointers to matrices + GRAD_ACT, GRAD_OUT, ACT_INPUTS, + # Matrix dimensions + N, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_gom, stride_aim, + # Meta-parameters + BLOCK_N: tl.constexpr, + EVEN_N: tl.constexpr, + ACTIVATION_GRAD: tl.constexpr, +): + # fmt: on + + """ + Go over all the activation inputs, compute the corresponding gradient + """ + + # this kernel is relatively simple in terms of scheduling: + # - per row (pid_m) + # - each program a given chunk on the col axis, + # since it's more effective memory and occupancy wise + pid_m, pid_n = tl.program_id(axis=0), tl.program_id(axis=1) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # the memory addresses of elements in the first block of + # A and W can be computed using numpy-style broadcasting + act_input_ptrs = ACT_INPUTS + pid_m * stride_aim + rn + + # compute the gradient which is related to this activation + if EVEN_N: + act_in = tl.load(act_input_ptrs) + else: + act_in = tl.load(act_input_ptrs, mask=rn < N, other=0.0) + + if ACTIVATION_GRAD == 1: + grad_act = relu_grad(act_in) + elif ACTIVATION_GRAD == 2: + grad_act = leaky_relu_grad(act_in) + elif ACTIVATION_GRAD == 3: + grad_act = gelu_grad(act_in) + elif ACTIVATION_GRAD == 4: + grad_act = squared_relu_grad(act_in) + elif ACTIVATION_GRAD == 5: + grad_act = smelu_grad(act_in) + elif ACTIVATION_GRAD == 6: + grad_act = star_relu_grad(act_in) + else: + grad_act = act_in + + # now read the incoming gradient, the backpropagated one is the multiple of both + grad_out_ptrs = GRAD_OUT + pid_m * stride_gom + rn + if EVEN_N: + grad_out = tl.load(grad_out_ptrs) + else: + grad_out = tl.load(grad_out_ptrs, mask=rn < N) + + grad_act *= grad_out + + # write back result + grad_act_ptrs = GRAD_ACT + pid_m * stride_gom + rn + tl.store(grad_act_ptrs, grad_act, mask=rn < N) + + +def fused_matmul_backward( + grad_out: torch.Tensor, + inputs: torch.Tensor, + act_in: Optional[torch.Tensor], + weight: torch.Tensor, + trainable_weight: bool, + trainable_bias: bool, + activation_grad: int = 0, +): + """ + Compute grad_in = activation^-1(grad_out) @ weight.transpose() + + .. note: The weight buffer is transposed on the fly + .. 
note: Activation gradient needs to be a Triton kernel + """ + + # Make sure that we don't have to handle the stride over cols + if not grad_out.is_contiguous(): + grad_out = grad_out.contiguous() + + grad_out_ = grad_out if grad_out.ndim == 2 else grad_out.flatten(0, -2) + inputs_ = inputs if inputs.ndim == 2 else inputs.flatten(0, -2) + + assert grad_out_.shape[1] == weight.shape[0], "Incompatible dimensions in between grad_out and weight" + + M, N = grad_out_.shape + N, _ = weight.shape + + # Compute the gradient for the activation + if activation_grad > 0: + grad_act = torch.empty_like(grad_out_) + + # Some activations do not require their inputs to + # know of their grad, the downstream grad is enough + if act_in is None: + act_in = grad_out_ + + grid = lambda META: (M, triton.cdiv(N, META["BLOCK_N"])) # noqa + + # fmt: off + kernel_bw[grid]( + grad_act, grad_out_, act_in, # data ptrs + N, # shapes + grad_act.stride(0), act_in.stride(0), # strides + ACTIVATION_GRAD=activation_grad, # optional fused activation + ) + # fmt: on + + # Backpropagation going up, the reference gradient is now + # just before the activation + grad_out_ = grad_act + + # The following ops can also be handled by pytorch + grad_in = triton.ops.matmul(grad_out_, weight) + grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None + grad_bias = torch.sum(grad_out_, dim=0) if trainable_bias else None + + return grad_in.reshape_as(inputs), grad_weight, grad_bias diff --git a/pkgs/xformers/triton/k_fused_matmul_fw.py b/pkgs/xformers/triton/k_fused_matmul_fw.py new file mode 100644 index 0000000..0f0e381 --- /dev/null +++ b/pkgs/xformers/triton/k_fused_matmul_fw.py @@ -0,0 +1,253 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
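# A reference sketch of the math fused_matmul_backward implements above, in plain
# PyTorch with a ReLU activation (shapes and the activation choice are assumptions):
# grad_out is first pushed through the activation gradient, then
# grad_in = grad_act @ W, grad_W = grad_act^T @ x and grad_b = sum(grad_act).
import torch

M, K, N = 8, 16, 32
x = torch.randn(M, K)
w = torch.randn(N, K)                  # weight is stored as (out_features, in_features)
act_in = x @ w.t()                     # saved "activation inputs" (pre-activation)
grad_out = torch.randn(M, N)

grad_act = grad_out * (act_in > 0).to(grad_out.dtype)   # relu_grad, fused in kernel_bw
grad_in = grad_act @ w                  # triton.ops.matmul(grad_out_, weight)
grad_w = grad_act.t() @ x               # grad_out_.transpose(1, 0) @ inputs_
grad_b = grad_act.sum(dim=0)            # torch.sum(grad_out_, dim=0)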
+ +from typing import Optional + +import torch +import triton +import triton.language as tl + +from xformers.triton.k_activations import ( + gelu, + leaky_relu, + relu, + smelu, + squared_relu, + star_relu, +) + +# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications + + +def get_configs(block_k): + return [ + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": block_k}, + num_stages=4, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": block_k}, + num_stages=4, + num_warps=2, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + triton.Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": block_k}, + num_stages=3, + num_warps=4, + ), + # Fails on small GPUS + # triton.Config( + # {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": block_k}, + # num_stages=3, + # num_warps=8, + # ), + # triton.Config( + # {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": block_k}, + # num_stages=3, + # num_warps=8, + # ), + ] + + +# fmt: off +@triton.autotune( + configs=[c for block_k in [32, 64] for c in get_configs(block_k)], + key=["M", "N", "K"], +) +@triton.heuristics({ + 'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0, +}) +@triton.jit +def kernel_fma( + # Pointers to matrices + OUT, ACT_INPUTS, INPUT, WEIGHT, bias, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. stride_am is how much to increase a_ptr + # by to get the element one row down (A has M rows) + stride_om, stride_im, + stride_wn, + # Meta-parameters + BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr, + BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + EVEN_N: tl.constexpr, + BIAS: tl.constexpr, + SAVE_ACT_INPUTS: tl.constexpr, + ACTIVATION: tl.constexpr, + is_fp16: tl.constexpr, # autotune +): + # fmt: on + + """ + Kernel for computing Out = activation(A x W + C) + + - Input has shape (M, K) + - Weight has shape (K, N) + - Bias has shape (N,) + - Output has shape (M, N) + - ActInputs (optional) has shape (M, N) + + 'ActInputs' optionally saves the A x W + C intermediate for backward computations + + This kernel will consolidate over K + """ + + # programs are grouped together to improve L2 hit rate + # the logic is that we'll consolidate over K. If the programs were not grouped, + # then multiple cols/rows in the result would end up pulling in the same row and lines + # from the inputs. 
By grouping the computation we ensure some data reuse, which the hardware + # covers via the L2 cache + pid = tl.program_id(axis=0) + + num_pid_m = tl.cdiv(M, BLOCK_M) # number of program ids along the M axis + num_pid_n = tl.cdiv(N, BLOCK_N) # number of programs ids along the N axis + num_pid_in_group = GROUP_M * num_pid_n # number of programs in group + group_id = pid // num_pid_in_group # id of the group this program is in + first_pid_m = group_id * GROUP_M # row-id of the first program in the group + GROUP_M = min( + num_pid_m - first_pid_m, GROUP_M + ) # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller + + # *within groups*, programs are ordered in a column-major order + # row-id /col-id of the program in the *launch grid* + pid_m = first_pid_m + (pid % GROUP_M) + pid_n = (pid % num_pid_in_group) // GROUP_M + + # now compute the block that each program will go through + # rm (resp. rn) denotes a range of indices + # for rows (resp. col) of C + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + rk = tl.arange(0, BLOCK_K) + + # the memory addresses of elements can follow numpy broadcasting + input_ptrs = INPUT + rm[:, None] * stride_im + weight_ptrs = WEIGHT + rn[None, :] * stride_wn + + # initialize and iteratively update accumulator + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + + if BIAS: + if EVEN_N: + bias = tl.load(bias + rn).to(tl.float32) + else: + bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32) + acc += bias[None, :] + + # block level matrix multiplication. + # We fetch a block memory block from both inputs, matmul and accumulate, then repeat + mask_rn = rn < N + mask_rm = rm < M + + for i in range(0, K, BLOCK_K): + rk = tl.arange(0, BLOCK_K) + i + a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0) + w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0) + + acc += tl.dot(a, w) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + # optional: save the activation inputs + if SAVE_ACT_INPUTS: + act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :] + tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :]) + + # optional: fused activation (while the data is in shared memory) + if ACTIVATION == 1: + acc = relu(acc) + elif ACTIVATION == 2: + acc = leaky_relu(acc) + elif ACTIVATION == 3: + acc = gelu(acc) + elif ACTIVATION == 4: + acc = squared_relu(acc) + elif ACTIVATION == 5: + acc = smelu(acc) + elif ACTIVATION == 6: + acc = star_relu(acc) + + # write back result + out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :] + tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :]) + + +# Activation needs to be a triton kernel +def fused_matmul( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + activation=0, + save_act_inputs: bool = False +): + """ + Compute e = activation(x @ weight + bias). 
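# A pure-Python illustration (added for clarity, not kernel code) of the grouped
# program-id swizzling described in the docstring above: within each group of
# GROUP_M row-blocks the programs are ordered column-major, improving L2 reuse.
def swizzle(pid: int, num_pid_m: int, num_pid_n: int, group_m: int):
    num_pid_in_group = group_m * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * group_m
    group_m = min(num_pid_m - first_pid_m, group_m)   # last group may be smaller
    pid_m = first_pid_m + pid % group_m
    pid_n = (pid % num_pid_in_group) // group_m
    return pid_m, pid_n

# With a 4x4 grid of output blocks and GROUP_M=2, consecutive programs walk down
# two row-blocks before moving one column-block to the right:
print([swizzle(p, 4, 4, 2) for p in range(8)])
# [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]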
+ This wrapper kicks the `kernel_fma` Triton kernel + """ + + if not x.is_contiguous(): + x = x.contiguous() + + x_ = x if x.ndim == 2 else x.flatten(0, -2) + + assert ( + x_.shape[1] == weight.shape[1] + ), f"Incompatible dimensions in between inputs and weight, {x_.shape} - {weight.shape}" + assert bias is None or bias.is_contiguous() + assert ( + bias is None or bias.shape[0] == weight.shape[0] + ), "Incompatible dimensions in between weight and bias" + assert weight.is_contiguous() + + M, K = x_.shape + N, K = weight.shape + + outputs = torch.empty((M, N), device=x.device, dtype=x.dtype) + act_inputs = torch.empty_like(outputs) if save_act_inputs else x # will not be used in that case + + # 1D launch kernel where each block gets its own program. + grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),) # noqa + + # fmt: off + kernel_fma[grid]( + outputs, act_inputs, x_, weight, # data ptrs + bias if bias is not None else x, # auto skip bias if not present + M, N, K, # shapes + outputs.stride(0), x_.stride(0), # strides + weight.stride(0), + ACTIVATION=activation, # optional fused activation + BIAS=bias is not None, # optional fused bias + GROUP_M=8, # speed optimization: group the programs + SAVE_ACT_INPUTS=save_act_inputs, + is_fp16=x_.dtype == torch.float16 + ) + # fmt: on + + outputs = outputs if x.ndim == 2 else outputs.reshape(*x.shape[:-1], N) + + return outputs, act_inputs if save_act_inputs else None diff --git a/pkgs/xformers/triton/k_layer_norm.py b/pkgs/xformers/triton/k_layer_norm.py new file mode 100644 index 0000000..cc1c804 --- /dev/null +++ b/pkgs/xformers/triton/k_layer_norm.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +# CREDITS: This comes almost as-is from the Triton layer norm tutorial +# https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py + + +import triton +import triton.language as tl + + +# fmt: off +@triton.jit +def layer_norm_fw(X, Y, W, B, M, V, stride, N, eps, affine: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + # fmt: on + """ + Fused layernorm kernel over a 3d tensor. + The layer norm is applied over the last dimension. 
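# A hypothetical call into the fused_matmul wrapper above (shapes, device and dtype
# are assumptions; activation=3 selects the fused GeLU branch, matching the
# get_triton_activation_index codes):
import torch
from xformers.triton.k_fused_matmul_fw import fused_matmul

x = torch.randn(4, 128, 512, device="cuda", dtype=torch.float16)   # (batch, seq, K)
w = torch.randn(1024, 512, device="cuda", dtype=torch.float16)     # (N, K)
b = torch.zeros(1024, device="cuda", dtype=torch.float16)

# save_act_inputs=True additionally returns the pre-activation x @ w^T + b
# (flattened to 2d) so the backward pass can recompute the activation gradient.
y, act_inputs = fused_matmul(x, w, b, activation=3, save_act_inputs=True)
assert y.shape == (4, 128, 1024)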
+ + Compute + y = (x - E(x))/(sqrt(var(x) + epsilon)) * gamma + beta + """ + + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N + + # Move to this row + x_ptrs = X + row * stride + cols + x = tl.load(x_ptrs, mask=mask, other=0.0).to(tl.float32) + + # Compute mean and variance + mean = tl.sum(x, axis=0) / N + x_zm = tl.where(mask, x - mean, 0.0) + tl.store(M + row, mean) + + x_var = tl.sum(x_zm * x_zm, axis=0) / N + rstd = 1.0 / tl.sqrt(x_var + eps) + + # Normalize, optionally affine + y = x_zm * rstd + tl.store(V + row, rstd) + + mask = cols < N + if affine: + w = tl.load(W + cols, mask=mask, other=1.0) + b = tl.load(B + cols, mask=mask, other=0.0) + y = y * w + b + + y_ptrs = Y + row * stride + cols + tl.store(y_ptrs, y, mask=mask) + + +# Backward pass (DX + partial DW + partial DB) +# fmt: off +@triton.jit +def layer_norm_bwd_dx_fused( + DX, DY, DW, DB, + X, W, M, V, + Lock, stride, N, + # META-parameters + affine: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + # fmt: on + + # position of elements processed by this program + row = tl.program_id(0) + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N + + # offset data pointers to start at the row of interest + x_ptrs = X + row * stride + cols + dy_ptrs = DY + row * stride + cols + + # load data to SRAM + x = tl.load(x_ptrs, mask=mask, other=0) + dy = tl.load(dy_ptrs, mask=mask, other=0) + mean = tl.load(M + row) + rstd = tl.load(V + row) + + # compute dx + xhat = (x - mean) * rstd + + if affine: + w = tl.load(W + cols, mask=mask, other=0) + wdy = w * dy + else: + wdy = dy + + xhat = tl.where(mask, xhat, 0.) + wdy = tl.where(mask, wdy, 0.) + mean1 = tl.sum(xhat * wdy, axis=0) / N + mean2 = tl.sum(wdy, axis=0) / N + dx = (wdy - (xhat * mean1 + mean2)) * rstd + + # write-back dx + cols = tl.arange(0, BLOCK_SIZE_N) + mask = cols < N # re-materialize the mask to save registers + dx_ptrs = DX + row * stride + cols + tl.store(dx_ptrs, dx, mask=mask) + + if affine: + # accumulate partial sums for dw/db + partial_dw = (dy * xhat).to(w.dtype) + partial_db = dy.to(w.dtype) + + # offset locks and weight/bias gradient pointer + # each kernel instance accumulates partial sums for + # DW and DB into one of GROUP_SIZE_M independent buffers + # these buffers stay in the L2, which allow this kernel + # to be fast + lock_id = row % GROUP_SIZE_M + Lock += lock_id + Count = Lock + GROUP_SIZE_M + + # - wait for a lock on the accumulated dw/db + while tl.atomic_cas(Lock, 0, 1) == 1: + pass + count = tl.load(Count) + + # - we got the lock, accumulate this kernel's results with + # the stored values. + dw_ptrs = DW + lock_id * N + cols + db_ptrs = DB + lock_id * N + cols + + if count == 0: + # first store doesn't accumulate + tl.atomic_xchg(Count, 1) + else: + partial_dw += tl.load(dw_ptrs, mask=mask, other=0.) + partial_db += tl.load(db_ptrs, mask=mask, other=0.) 
+ + tl.store(dw_ptrs, partial_dw, mask=mask) + tl.store(db_ptrs, partial_db, mask=mask) + + # release lock + tl.atomic_xchg(Lock, 0) + + +# Backward pass (total DW + total DB) +# fmt: off +@triton.jit +def layer_norm_bwd_dwdb( + DW, DB, FINAL_DW, FINAL_DB, + M, N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr +): + # fmt: on + + pid = tl.program_id(0) + + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cols = cols < N + + dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for i in range(0, M, BLOCK_SIZE_M): + rows = i + tl.arange(0, BLOCK_SIZE_M) + offs = rows[:, None] * N + cols[None, :] + mask_rm = rows < M + + dw += tl.load(DW + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0) + db += tl.load(DB + offs, mask=mask_rm[:, None] & mask_cols[None, :], other=0.0) + + sum_dw = tl.sum(dw, axis=0) + sum_db = tl.sum(db, axis=0) + + cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask_cols = cols < N + + tl.store(FINAL_DW + cols, sum_dw, mask=mask_cols) + tl.store(FINAL_DB + cols, sum_db, mask=mask_cols) diff --git a/pkgs/xformers/triton/k_softmax.py b/pkgs/xformers/triton/k_softmax.py new file mode 100644 index 0000000..4dc99de --- /dev/null +++ b/pkgs/xformers/triton/k_softmax.py @@ -0,0 +1,175 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import triton +import triton.language as tl + +# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/ +# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html + + +def get_depth(args): + return triton.next_power_of_2(args["K"]) + + +# autotune: Triton will test out these configurations, and automatically pick the fastest one. +# heuristic: add arguments to the kernel call automatically given some heuristics. These arguments are passed in "meta" +# fmt: off +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + triton.Config({}, num_warps=32), + ], + key=["K"], +) +@triton.heuristics(values={"depth": get_depth}) +@triton.jit +def _softmax( + Y, X, M, + stride_ym, stride_yn, + stride_xm, stride_xn, + stride_mn, + K, + # Meta-params + depth: tl.constexpr, + causal: tl.constexpr, + use_mask: tl.constexpr, + log: tl.constexpr, +): + # fmt: om + + """ + Fused softmax kernel over a 3d tensor. + The softmax is applied over the last dimension, meaning that this is equivalent to torch.softmax(tensor, dim=-1) + + Note, if the last dimension is large, say 128K elements, the kernel compile time can shot up to many minutes when + the kernel is run for the first time. 
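# A plain-PyTorch sketch (an illustration, not the kernels above) of the two-stage
# dw/db reduction used by layer_norm_bwd_dx_fused and layer_norm_bwd_dwdb: rows
# accumulate into GROUP_SIZE_M lock-protected buffers, then a second kernel reduces
# those buffers to the final gradients.
import torch

M, N, GROUP_SIZE_M = 1000, 768, 32
dy = torch.randn(M, N, dtype=torch.float64)
xhat = torch.randn(M, N, dtype=torch.float64)    # normalized inputs from the fwd pass

partial_dw = torch.zeros(GROUP_SIZE_M, N, dtype=torch.float64)
partial_db = torch.zeros(GROUP_SIZE_M, N, dtype=torch.float64)
for row in range(M):
    lock_id = row % GROUP_SIZE_M                 # same bucketing as the Triton kernel
    partial_dw[lock_id] += dy[row] * xhat[row]
    partial_db[lock_id] += dy[row]

dw = partial_dw.sum(dim=0)                        # second pass: reduce the buffers
db = partial_db.sum(dim=0)
torch.testing.assert_close(dw, (dy * xhat).sum(dim=0))
torch.testing.assert_close(db, dy.sum(dim=0))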
+ """ + + m = tl.program_id(0) + n = tl.program_id(1) + + # col indices + k = tl.arange(0, depth) + + # the memory address of all the elements that we want to load can be computed as follows + x_ptrs = X + m * stride_xm + n * stride_xn + k + + # load input data; pad out-of-bounds elements with 0 + io_mask = k < K + + # Causal - 1: skip on the loads directly + if causal: + io_mask = io_mask & (k <= n) + + x = tl.load(x_ptrs, mask=io_mask, other=float("-inf")).to(tl.float32) + + # Causal - 2: enforce correctness over a couple of misloaded values + if causal: + off = float("-inf") + off = off.to(x.dtype) # type: ignore + x = tl.where(k > n, off, x) + + if use_mask: + mask_ptrs = M + n * stride_mn + k + add_mask = tl.load(mask_ptrs, io_mask, other=float("-inf")).to(tl.float32) + x += add_mask + + # compute numerically-stable softmax + z = x - tl.max(x, axis=0) + num = tl.exp(z) + denom = tl.sum(num, axis=0) + + if log: + y = z - tl.log(denom) + else: + y = num / denom + + # write back to Y. + # we only write once, hence the "fused" softmax naming + y_ptrs = Y + m * stride_ym + n * stride_yn + k + + # technically we could write only the lower triangular matrix in the causal case + # but this is deemed to error prone + tl.store(y_ptrs, y, mask=k < K) + + +# fmt: off +@triton.autotune( + configs=[ + triton.Config({}, num_warps=1), + triton.Config({}, num_warps=2), + triton.Config({}, num_warps=4), + triton.Config({}, num_warps=8), + triton.Config({}, num_warps=16), + ], + key=["K"], +) +@triton.jit +def _softmax_backward( + GradIn, GradOut, Out, + stride_bm, stride_bn, + stride_gm, stride_gn, + stride_om, stride_on, + K, + # meta-params + depth: tl.constexpr, + causal: tl.constexpr, + log: tl.constexpr, +): + # fmt: on + + """ + Compute the softmax gradients. + ..Note: Not autotuning for now because this would lead to broken accumulated gradients + """ + + m = tl.program_id(0) + n = tl.program_id(1) + + # col indices + k = tl.arange(0, depth) + + # the memory address of all the elements that we want to load can be computed as follows + grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k + out_ptrs = Out + m * stride_om + n * stride_on + k + + # load input data; pad out-of-bounds elements with 0 + io_mask = k < K + + # Causal - 1: skip on the loads directly + if causal: + io_mask = io_mask & (k <= n) + + g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0)).to(tl.float32) + o = tl.load(out_ptrs, mask=io_mask, other=float(0)).to(tl.float32) + + # Causal - 2: enforce correctness over a couple of misloaded values + if causal: + zero = float(0) + zero = zero.to(g.dtype) # type: ignore + g = tl.where(k > n, zero, g) + o = tl.where(k > n, zero, o) + + if log: + s = tl.sum(g, 0) + grad_in = g - tl.exp(o) * s + else: + # Step 1: Compute the intermediate sum used for the gradient + s = tl.sum(g * o, 0) + + # Step 2: Compute the gradients + grad_in = o * (g - s) + + # write back to the input gradients + # technically we could write only the lower triangular matrix in the causal case + # but this is deemed to error prone + grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k + tl.store(grad_in_ptrs, grad_in, mask=k < K) diff --git a/pkgs/xformers/triton/k_sum.py b/pkgs/xformers/triton/k_sum.py new file mode 100644 index 0000000..339c368 --- /dev/null +++ b/pkgs/xformers/triton/k_sum.py @@ -0,0 +1,55 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import triton +import triton.language as tl + + +# fmt: off +@triton.jit +def k_sum_0( + Y, X, + stride_xm, + M, N, + is_fp16, + # META-params + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # fmt: om + + """ + Sum a 2d tensor over the first (strided) dimension. + This extracts some speed through a parallel sum across the second dimension + """ + + # partial row indices. We'll reduce over this dimension + m = tl.arange(0, BLOCK_M) + + # To get some extra parallelization, we handle several columns in the same thread block + rn = tl.program_id(axis=0) * BLOCK_N + tl.arange(0, BLOCK_N) + + # the memory address of all the elements that we want to load can be computed as follows + x_ptrs = X + m[:, None] * stride_xm + rn[None, :] + x_sum = tl.zeros((BLOCK_N,), dtype=tl.float32) + + tiles = M // BLOCK_M + if M % BLOCK_M > 0: + tiles += 1 + + col_mask = (rn[None, :] < N) + + for _ in range(tiles): + # load input data; pad out-of-bounds elements with 0 + # NOTE: make sure to accumulate in fp32 to prevent a trivial overflow + mask = (m[:, None] < M) & col_mask + x = tl.load(x_ptrs, mask=mask, other=0.0) + x_sum += tl.sum(x, 0) + + # move the load pointer + x_ptrs += BLOCK_M * stride_xm + m += BLOCK_M # update the mask check + + tl.store(Y + rn, x_sum, mask=rn < N) diff --git a/pkgs/xformers/triton/layer_norm.py b/pkgs/xformers/triton/layer_norm.py new file mode 100644 index 0000000..cec4cc9 --- /dev/null +++ b/pkgs/xformers/triton/layer_norm.py @@ -0,0 +1,236 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# CREDITS: the underlying kernel comes straight from the Triton tutorials +# see https://github.com/openai/triton/blob/master/python/tutorials/05-layer-norm.py + +import logging +from typing import Optional + +import torch +import torch.nn as nn +import triton +from torch.cuda.amp import custom_bwd, custom_fwd + +from xformers.triton.k_layer_norm import ( + layer_norm_bwd_dwdb, + layer_norm_bwd_dx_fused, + layer_norm_fw, +) + +logger = logging.getLogger("xformers") + + +_triton_layernorm_fp16_enabled = False # NOTE: PyTorch keeps layernorm as fp32 +_triton_registered_warnings = False + + +class _LayerNorm(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16 if _triton_layernorm_fp16_enabled else None) + def forward(ctx, x, weight, bias, eps): + # catch eps being too small if the tensors are fp16 + if x.dtype == torch.float16: + eps = max(eps, 1.6e-5) + + # allocate output + y = torch.empty_like(x) + + # reshape input data into 2D tensor + x_arg = x.reshape(-1, x.shape[-1]) + M, N = x_arg.shape + + # allocate mean and std, they'll be used in the backward pass + mean = torch.empty((M,), dtype=torch.float32, device="cuda") + rstd = torch.empty((M,), dtype=torch.float32, device="cuda") + + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + if not x_arg.is_contiguous() or not y.is_contiguous(): + global _triton_registered_warnings + if not _triton_registered_warnings: + logger.warning( + "Non-contiguous input tensor found. 
Making it contiguous," + + " but could have perf or trainer implications" + ) + + _triton_registered_warnings = True + + x_arg = x_arg.contiguous() + y = y.contiguous() + + # heuristics for number of warps. + num_warps = min(max(BLOCK_SIZE_N // 256, 1), 16) + + # enqueue kernel + # fmt: off + layer_norm_fw[(M,)]( + x_arg, y, weight, bias, mean, rstd, + x_arg.stride(0), + N, + eps, + num_warps=num_warps, + BLOCK_SIZE_N=BLOCK_SIZE_N, + affine=weight is not None + ) + # fmt: on + + ctx.save_for_backward(x, mean, rstd, weight) + ctx.BLOCK_SIZE_N = BLOCK_SIZE_N + ctx.num_warps = num_warps + + return y.reshape_as(x) + + @staticmethod + @custom_bwd + def backward( + ctx, dy + ): # pragma: no cover # this is covered, but called directly from C++ + x, mean, rstd, weight = ctx.saved_tensors + + # flatten the batch dimension, if any. + # We're interested in 'samples' x norm_dimension + x = x.reshape(-1, x.size(-1)) + M, N = x.size() + + # heuristics for amount of parallel reduction stream for DG/DB + GROUP_SIZE_M = 32 + if N <= 8192: + GROUP_SIZE_M = 64 + if N <= 4096: + GROUP_SIZE_M = 96 + if N <= 2048: + GROUP_SIZE_M = 128 + if N <= 1024: + GROUP_SIZE_M = 256 + + if dy.dtype == torch.float32: + GROUP_SIZE_M = GROUP_SIZE_M // 2 + + # allocate output + locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device="cuda") + t_args = {"dtype": x.dtype, "device": x.device} + _dw = torch.empty((GROUP_SIZE_M, x.size(-1)), **t_args) + _db = torch.empty_like(_dw) + dw = torch.empty((x.size(-1),), **t_args) + db = torch.empty_like(dw) + dy = dy.contiguous() + dx = torch.empty_like(dy) + + # Check the tensor shapes and layouts + # we suppose in the kernel that they have the same size and are contiguous + assert ( + dy.numel() == x.numel() + ), "Something is wrong in the backward graph, possibly because of an inplace operation after the layernorm" + + # enqueue kernel using forward pass heuristics + # also compute partial sums for DW and DB + num_warps = min(max(ctx.BLOCK_SIZE_N // 256, 1), 16) + + # fmt: off + layer_norm_bwd_dx_fused[(M,)]( + dx, dy, _dw, _db, x, + weight if weight is not None else x, + mean, rstd, + locks, + x.stride(0), + N, + affine=weight is not None, + GROUP_SIZE_M=GROUP_SIZE_M, + BLOCK_SIZE_N=ctx.BLOCK_SIZE_N, + num_warps=num_warps + ) + # fmt: on + + def grid(meta): + return [triton.cdiv(N, meta["BLOCK_SIZE_N"])] + + # accumulate partial sums in separate kernel + # fmt: off + layer_norm_bwd_dwdb[grid]( + _dw, _db, dw, db, + GROUP_SIZE_M, + N, + BLOCK_SIZE_M=32, + BLOCK_SIZE_N=64 + ) + # fmt: on + + dx = dx.reshape_as(dy) + return dx, dw, db, None + + +class FusedLayerNorm(nn.Module): + """ + Handle a layer normalization, like torch.nn.LayerNorm_. + + This implementation should be measurably faster than the default PyTorch layernorm (as of PyTorch 1.9), + both for training and inference worloads. + + .. NOTE: Computations under Torch AMP are kept as float32 by default, one can change this to be float16 + by setting the flag `xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True` + + .. 
_torch.nn.LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html + + """ + + def __init__(self, normalized_shape, affine=True, eps=1e-06): + super().__init__() + if affine: + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + else: + self.weight = self.bias = None + self.epsilon = eps + + def forward(self, x): + return layer_norm(x, self.weight, self.bias, self.epsilon) + + def init_weights(self, *args, **kwargs): + with torch.no_grad(): + if self.weight is not None: + self.weight.fill_(1.0) + + if self.bias is not None: + self.bias.fill_(0.0) + + +def layer_norm( + x: torch.Tensor, + weight: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + eps: float = 1e-06, +) -> torch.Tensor: + + global _triton_registered_warnings + + r"""Applies normalization over a mini batch of inputs""" + + try: + if ( + not _triton_registered_warnings + and torch.cuda.is_available() + and x.is_cuda + and weight is not None + and bias is not None + ): + return _LayerNorm.apply(x, weight, bias, eps) + except RuntimeError as e: + # Catch cases where the current GPU does not have enough registers to hold a full tensor line + # fallback to PyTorch's implementation, which streams the tensor in and out + _triton_registered_warnings = True + logger.warning( + "Triton layernorm kernel register spillover or invalid image caught. " + "Deactivating this kernel, please file an issue in the xFormers repository" + ) + logger.warning(e) + + return torch.nn.functional.layer_norm( + x, [x.shape[-1]], weight=weight, bias=bias, eps=eps + ) diff --git a/pkgs/xformers/triton/softmax.py b/pkgs/xformers/triton/softmax.py new file mode 100644 index 0000000..0de7be9 --- /dev/null +++ b/pkgs/xformers/triton/softmax.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from typing import Optional + +import torch +import triton +from torch.cuda.amp import custom_bwd, custom_fwd + +from xformers.triton.k_softmax import _softmax, _softmax_backward + +# CREDITS: This is adapted from the vanilla Triton example. See https://openai.com/blog/triton/ +# and https://triton-lang.org/getting-started/tutorials/02-fused-softmax.html + + +logger = logging.getLogger("xformers") + + +_triton_softmax_fp16_enabled = False # NOTE: PyTorch keeps softmax as fp32 +_triton_registered_warnings = False + + +# Helper to handle the SPMD launch grid and error cases +class _softmax_triton(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16 if _triton_softmax_fp16_enabled else None) + def forward(ctx, x, mask, log_outputs, causal): + """ + Fused softmax implementation, using the Triton programming model. 
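# A hypothetical drop-in usage of the FusedLayerNorm module above (shapes, device and
# dtype are assumptions). The wrapper silently falls back to
# torch.nn.functional.layer_norm when CUDA is unavailable, when weight/bias are None,
# or after a kernel failure has been recorded.
import torch
from xformers.triton import FusedLayerNorm

ln = FusedLayerNorm(1024).cuda().half()
x = torch.randn(8, 128, 1024, device="cuda", dtype=torch.float16)
y = ln(x)                       # one fused kernel launch per row on the Triton path
assert y.shape == x.shape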
+ This only supports a reduction over the last dimension for now + """ + + # Handle 2D/3D tensors + x_ = x.unsqueeze(0) if x.ndim == 2 else x + x_ = x_.flatten(0, -3) + + if not x_.is_contiguous(): + x_ = x_.contiguous() + + y = torch.empty_like(x_) + assert ( + y.stride(2) == 1 and x_.stride(2) == 1 + ), f"{x.shape} - {x_.shape} - {x_.stride()}" + + # SPMD launch grid + grid_2d = ( + x_.shape[0], + x_.shape[1], + ) + + # enqueue GPU kernel + use_mask = True + if mask is None: + # placeholder, will not be used + mask = x_ + use_mask = False + else: + # Make sure that the mask is binary + assert mask.dtype == x.dtype, "An additive mask is requested" + + _softmax[grid_2d]( + y, + x_, + mask, + y.stride(0), + y.stride(1), + x_.stride(0), + x_.stride(1), + mask.stride(0), + x_.shape[2], + log=log_outputs, + use_mask=use_mask, + causal=causal, + ) + + ctx.save_for_backward(y) + ctx.log_outputs = log_outputs + ctx.causal = causal + return y.reshape_as(x) + + @staticmethod + @custom_bwd + def backward( + ctx, grad_out + ): # pragma: no cover # this is covered, but called directly from C++ + (out,) = ctx.saved_tensors + + # Handle 2D/3D tensors + grad_out_ = grad_out.unsqueeze(0) if grad_out.ndim == 2 else grad_out + grad_out_ = grad_out_.flatten(0, -3) + + # SPMD launch grid + grid_2d = ( + grad_out_.shape[0], + grad_out_.shape[1], + ) + + depth = triton.next_power_of_2(grad_out_.shape[2]) + grad_in = torch.empty_like( + out + ) # torch.zeros is measurably slower, we'll zero out in the kernel + + # Make sure that the tensor are contiguous + grad_in, grad_out, out = map(lambda x: x.contiguous(), [grad_in, grad_out, out]) + + # fmt: off + _softmax_backward[grid_2d]( + grad_in, grad_out_, out, + grad_in.stride(0), grad_in.stride(1), + grad_out_.stride(0), grad_out_.stride(1), + out.stride(0), out.stride(1), + out.shape[2], + depth=depth, + log=ctx.log_outputs, + causal=ctx.causal + ) + # fmt: on + return grad_in.reshape_as(grad_out), None, None, None + + +def softmax( + x: torch.Tensor, mask: Optional[torch.Tensor] = None, causal: bool = False +) -> torch.Tensor: + r"""Applies the Softmax function to an 3-dimensional input Tensor + rescaling them so that the elements of the n-dimensional output Tensor + lie in the range [0,1] and sum to 1. + + Softmax is defined as: + + .. math:: + \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} + + .. warning: softmax is computed on the last dimension of the input tensor. + + + Args: + x: input tensor. + mask: optional mask, its application will be fused to the softmax computation if triton is used + causal: optional performance optimization, if triton is used and the attention is causal + + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [0, 1] and sum to 1 + """ + return _softmax_dispatch(x, log=False, mask=mask, causal=causal) + + +def log_softmax( + x: torch.Tensor, mask: Optional[torch.Tensor] = None, causal: bool = False +) -> torch.Tensor: + r"""Applies the :math:`\log(\text{Softmax}(x))` function to an 3-dimensional + input Tensor. The LogSoftmax formulation can be simplified as: + + .. math:: + \text{LogSoftmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right) + + Args: + x: input tensor. 
+ + Returns: + a Tensor of the same dimension and shape as the input with + values in the range [-inf, 0) + """ + return _softmax_dispatch(x, log=True, mask=mask, causal=causal) + + +def _softmax_dispatch( + x: torch.Tensor, log: bool, mask: Optional[torch.Tensor], causal: bool = False +) -> torch.Tensor: + # Triton is used if + # - CUDA + # - there's enough data to make it faster than pytorch. This could change over time, Triton is improving + # - there was no previous failure + + global _triton_registered_warnings + + try: + if torch.cuda.is_available() and x.is_cuda and not _triton_registered_warnings: + return _softmax_triton.apply(x, mask, log, causal) + except RuntimeError as e: + # Catch cases where the current GPU does not have enough registers to hold a full tensor line + # fallback to PyTorch's implementation, which streams the tensor in and out + _triton_registered_warnings = True + logger.warning( + "Triton softmax kernel register spillover or invalid image caught." + "Deactivating this kernel, please file an issue int the xFormers repository" + ) + logger.warning(e) + + if mask is not None: + x = x + mask + + if causal: + x = x + torch.triu(torch.full_like(x, float("-inf")), diagonal=1) + + if log: + return torch.log_softmax(x, dim=-1) + else: + return torch.softmax(x, dim=-1) diff --git a/pkgs/xformers/triton/sum_strided.py b/pkgs/xformers/triton/sum_strided.py new file mode 100644 index 0000000..ed23384 --- /dev/null +++ b/pkgs/xformers/triton/sum_strided.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import triton + +from xformers.triton.k_sum import k_sum_0 + + +def sum_2d_dim_0(x: torch.Tensor): + """ + Sum a 2D tensor across the first dimension + """ + + out = torch.empty(x.shape[1], device=x.device, dtype=x.dtype) + + assert ( + x.ndim == 2 + ), "This is a very specific kernel, only for 2-dim tensors and summing along dim 0" + M, N = x.shape + + # This kernel is not competitive for these sizes + if M > 2048 or M < 8: + return x.sum(dim=0) + + assert ( + M >= 4 + ), "This is a very specific kernel, requires the reduction dimension to be bigger than 4" + + assert x.stride(1) == 1, ( + "We're expecting x to be contiguous along dim 1, and non contiguous along dim 0.\n" + " You would probably be better served with torch.sum()" + ) + + BLOCK_M = min(triton.next_power_of_2(M), 2048) + BLOCK_N = 32 + if BLOCK_M > 256: + BLOCK_N = 16 + if BLOCK_M > 1024: + BLOCK_N = 8 + + def grid(meta): + return (triton.cdiv(N, meta["BLOCK_N"]),) + + # fmt: off + k_sum_0[grid]( + out, x, + x.stride(0), + M, N, + x.dtype == torch.float16, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + num_stages=4, + ) + # fmt: on + + return out diff --git a/pkgs/xformers/triton/utils.py b/pkgs/xformers/triton/utils.py new file mode 100644 index 0000000..7c2fd70 --- /dev/null +++ b/pkgs/xformers/triton/utils.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
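# A usage sketch for sum_2d_dim_0 above (device and dtype are assumptions). The
# kernel only pays off for 8 <= M <= 2048 and expects x to be contiguous along
# dim 1; outside that range it silently falls back to x.sum(dim=0).
import torch
from xformers.triton.sum_strided import sum_2d_dim_0

x = torch.randn(512, 4096, device="cuda", dtype=torch.float16)
out = sum_2d_dim_0(x)
torch.testing.assert_close(out, x.sum(dim=0), rtol=1e-2, atol=1e-2)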
+ +import logging +from typing import Optional + +import torch + +logger = logging.getLogger("xformers") + + +_gpu_is_old: Optional[bool] = None + + +def gpu_capabilities_older_than_70() -> bool: + """Return True if the GPU's compute capability is older than SM70.""" + global _gpu_is_old + if _gpu_is_old is None: + for i in range(torch.cuda.device_count()): + major, _ = torch.cuda.get_device_capability(f"cuda:{i}") + if major < 7: + _gpu_is_old = True + if _gpu_is_old is None: + _gpu_is_old = False + return _gpu_is_old + + +SUPPORTED_CUDA_DEVICES = ["V100", "A100", "T4"] + + +def get_current_cuda_device(): + current_device = str(torch.cuda.get_device_properties(torch.cuda.current_device())) + for device_str in SUPPORTED_CUDA_DEVICES: + if current_device.find(device_str) > 0: + return device_str + + logger.warning("Unsupported device, Triton code generation may fail") + return "P100" # default to an old GPU diff --git a/pkgs/xformers/triton/vararg_kernel.py b/pkgs/xformers/triton/vararg_kernel.py new file mode 100644 index 0000000..f994362 --- /dev/null +++ b/pkgs/xformers/triton/vararg_kernel.py @@ -0,0 +1,173 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import copy +import functools +import linecache +import sys +from typing import Any, Dict, List + +import triton + + +class _ForLoopUnroller(ast.NodeTransformer): + def __init__(self, target, inline_variables, loop_iter): + self.loop_iter = loop_iter + self.target = target + self.inline_variables = inline_variables + + def visit_Name(self, node): + if node.id != self.target: + return node + return ast.Name(str(self.loop_iter)) + + def visit_Subscript(self, node): + # Pattern-matching `value[slice]` + if ( + isinstance(node.slice, ast.Name) + and node.slice.id == self.target + and isinstance(node.value, ast.Name) + and node.value.id in self.inline_variables + ): + return ast.Name(f"{node.value.id}{self.loop_iter}") + return node + + +class _VisitorUnrollKernel(ast.NodeTransformer): + def __init__(self, N): + self.inline_variables = set() + self.N = N + + def visit_AnnAssign(self, node): + # Pattern-matching: + # var_name: "VAR_ARGS_ARRAY" + if ( + node.value is None + and node.simple == 1 + and isinstance(node.target, ast.Name) + and isinstance(node.annotation, ast.Constant) + and node.annotation.value == "VAR_ARGS_ARRAY" + ): + self.inline_variables.add(node.target.id) + return [] + if node.value is not None: + node.value = self.visit(node.value) + if node.annotation is not None: + node.annotation = self.visit(node.annotation) + if node.target is not None: + node.target = self.visit(node.target) + return node + + def visit_arguments(self, node): + # Replace `args` annotated with `VAR_ARGS_ARRAY` + new_args = [] + for arg in node.args: + if ( + arg.annotation is not None + and isinstance(arg.annotation, ast.Constant) + and arg.annotation.value == "VAR_ARGS_ARRAY" + ): + self.inline_variables.add(arg.arg) + new_args += [ast.arg(f"{arg.arg}{i}") for i in range(self.N)] + continue + new_args.append(arg) + if node.vararg is not None: + self.inline_variables.add(node.vararg.arg) + new_args += [ast.arg(f"{node.vararg.arg}{i}") for i in range(self.N)] + node.vararg = None + new_args += node.kwonlyargs + node.kwonlyargs = [] + node.args = new_args + return node + + def visit_For(self, node): + if ( + not isinstance(node.iter, ast.Call) + or node.iter.func.id != "range" + or 
len(node.iter.args) != 1 + or not isinstance(node.iter.args[0], ast.Call) + or node.iter.args[0].func.id != "len" + or len(node.iter.args[0].args) != 1 + or node.iter.args[0].args[0].id not in self.inline_variables + ): + node.body = [self.visit(x) for x in node.body] + return node + # We know we have to modify this loop + new_nodes = [] + for i in range(self.N): + unroller = _ForLoopUnroller( + target=node.target.id, + inline_variables=self.inline_variables, + loop_iter=i, + ) + for body in node.body: + body = copy.deepcopy(body) + new_node = ast.fix_missing_locations(unroller.visit(body)) + new_node = self.visit(new_node) + new_nodes.append(new_node) + return new_nodes + + +# Hackfix to get access to get source-code for +# `exec`-created functions - see https://stackoverflow.com/a/69668999 +_getlines_orig = None +_FILENAME_TO_SRC: Dict[str, str] = {} + + +def _monkey_patched_getlines(filename, module_globals=None): + if filename in _FILENAME_TO_SRC: + return _FILENAME_TO_SRC[filename] + else: + return _getlines_orig(filename, module_globals) # type: ignore + + +@functools.lru_cache(None) +def unroll_varargs(kernel, N: int): + """ + Specializes a triton kernel with variable number of inputs + to a specific number of inputs `N`. + NOTE: Because it's quite costly to call `triton.jit`, + we cache the returned value with `lru_cache` + """ + global _FILENAME_TO_SRC, _getlines_orig + + k = triton.JITFunction(kernel.fn) + parsed = ast.parse(k.src) + nodeVisitor = _VisitorUnrollKernel(N=N) + parsed = nodeVisitor.visit(parsed) + parsed = ast.fix_missing_locations(parsed) + + # NOTE: `ast.unparse` requires python 3.9+ + if (sys.version_info.major, sys.version_info.minor) <= (3, 8): + raise RuntimeError("Error: This functionality requires python 3.9 or above") + new_src = ast.unparse(parsed) # type: ignore + + # Now we want to `eval` the function, but we need all this + # boilerplate code to make sure triton can run `inspect.getsource` + + fn_filename = f"" + + # Create function given source + code = compile(new_src, fn_filename, "exec") + + _locals: Dict[str, Any] = {} + exec(code, kernel.fn.__globals__, _locals) + assert len(_locals) == 1, len(_locals) + fn = next(iter(_locals.values())) + # Patch `getlines` only the first time + if not _FILENAME_TO_SRC: + _getlines_orig = linecache.getlines + linecache.getlines = _monkey_patched_getlines + _FILENAME_TO_SRC[fn_filename] = new_src + + jitted_fn = triton.jit(fn) + jitted_fn.src = new_src + return jitted_fn + + +# Note: just import this to make mypy happy +# when annotating variables with `VAR_ARGS_ARRAY` +VAR_ARGS_ARRAY = List[Any] diff --git a/pkgs/xformers/utils.py b/pkgs/xformers/utils.py new file mode 100644 index 0000000..606f896 --- /dev/null +++ b/pkgs/xformers/utils.py @@ -0,0 +1,79 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
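+
+# Shared plumbing for the xFormers component registries: bulk-importing the
+# modules of a package, a decorator factory that registers subclasses under a
+# name, and a helper that projects a superset of options onto a specific
+# config dataclass (missing fields default to None).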
+ + +import importlib +import os +import sys +from collections import namedtuple +from dataclasses import fields +from typing import Any, Callable, Dict, List + +Item = namedtuple("Item", ["constructor", "config"]) + + +# credit: snippet used in ClassyVision (and probably other places) +def import_all_modules(root: str, base_module: str) -> List[str]: + modules: List[str] = [] + for file in os.listdir(root): + if file.endswith((".py", ".pyc")) and not file.startswith("_"): + module = file[: file.find(".py")] + if module not in sys.modules: + module_name = ".".join([base_module, module]) + importlib.import_module(module_name) + modules.append(module_name) + + return modules + + +def get_registry_decorator( + class_registry, name_registry, reference_class, default_config +) -> Callable[[str, Any], Callable[[Any], Any]]: + def register_item(name: str, config: Any = default_config): + """Registers a subclass. + + This decorator allows xFormers to instantiate a given subclass + from a configuration file, even if the class itself is not part of the + xFormers library.""" + + def register_cls(cls): + if name in class_registry: + raise ValueError("Cannot register duplicate item ({})".format(name)) + if not issubclass(cls, reference_class): + raise ValueError( + "Item ({}: {}) must extend the base class: {}".format( + name, cls.__name__, reference_class.__name__ + ) + ) + if cls.__name__ in name_registry: + raise ValueError( + "Cannot register item with duplicate class name ({})".format( + cls.__name__ + ) + ) + + class_registry[name] = Item(constructor=cls, config=config) + name_registry.add(cls.__name__) + return cls + + return register_cls + + return register_item + + +def generate_matching_config(superset: Dict[str, Any], config_class: Any) -> Any: + """Given a superset of the inputs and a reference config class, + return exactly the needed config""" + + # Extract the required fields + field_names = list(map(lambda x: x.name, fields(config_class))) + subset = {k: v for k, v in superset.items() if k in field_names} + + # The missing fields get Noned + for k in field_names: + if k not in subset.keys(): + subset[k] = None + + return config_class(**subset) diff --git a/pkgs/xformers/version.py b/pkgs/xformers/version.py new file mode 100644 index 0000000..1f8ed74 --- /dev/null +++ b/pkgs/xformers/version.py @@ -0,0 +1,2 @@ +# noqa: C801 +__version__ = "0.0.22" diff --git a/prefix_prefill.py b/prefix_prefill.py new file mode 100644 index 0000000..9a39e2b --- /dev/null +++ b/prefix_prefill.py @@ -0,0 +1,865 @@ +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + 
BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + SLIDING_WINDOW: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return + + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + 
stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + 
offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + 
cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = 
alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + k_scale: float = 1.0, + v_scale: float = 1.0, + alibi_slopes=None, + sliding_window=None): + + BLOCK = 128 if current_platform.has_device_capability(80) else 64 + NUM_WARPS = 8 + + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + if q.dtype is torch.float32: + BLOCK = BLOCK // 2 + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + import time 
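+        # Debug timing instrumentation around the non-alibi kernel launch:
+        # the elapsed time is measured below, but the diagnostic print is
+        # left commented out.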
+ ts_beg = time.time() + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, + num_warps=NUM_WARPS, + num_stages=1, + ) + elapsed = time.time() - ts_beg + #print(f'{elapsed}: {BLOCK=}, {Lk=}, {Lk_padded=}, {BLOCK=}, {sliding_window=}, {NUM_WARPS=}') + return diff --git a/vllm/__init__.py b/vllm/__init__.py new file mode 100644 index 0000000..8f477ea --- /dev/null +++ b/vllm/__init__.py @@ -0,0 +1,36 @@ +"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.llm import LLM +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.model_executor.models import ModelRegistry +from vllm.outputs import (CompletionOutput, EmbeddingOutput, + EmbeddingRequestOutput, RequestOutput) +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams + +from .version import __version__, __version_tuple__ + +__all__ = [ + "__version__", + "__version_tuple__", + "LLM", + "ModelRegistry", + "PromptType", + "TextPrompt", + "TokensPrompt", + "SamplingParams", + "RequestOutput", + "CompletionOutput", + "EmbeddingOutput", + "EmbeddingRequestOutput", + "LLMEngine", + "EngineArgs", + "AsyncLLMEngine", + "AsyncEngineArgs", + "initialize_ray_cluster", + "PoolingParams", +] diff --git a/vllm/__pycache__/__init__.cpython-310.pyc b/vllm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..1190aa9 Binary files /dev/null and b/vllm/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/__pycache__/_core_ext.cpython-310.pyc b/vllm/__pycache__/_core_ext.cpython-310.pyc new file mode 100644 index 0000000..aaa4f73 Binary files /dev/null and b/vllm/__pycache__/_core_ext.cpython-310.pyc differ diff --git a/vllm/__pycache__/_custom_ops.cpython-310.pyc b/vllm/__pycache__/_custom_ops.cpython-310.pyc new file mode 100644 index 0000000..4f0ece7 Binary files /dev/null and b/vllm/__pycache__/_custom_ops.cpython-310.pyc differ diff --git a/vllm/__pycache__/_ipex_ops.cpython-310.pyc b/vllm/__pycache__/_ipex_ops.cpython-310.pyc new file mode 100644 index 0000000..69c6c0c Binary files /dev/null and b/vllm/__pycache__/_ipex_ops.cpython-310.pyc differ diff --git a/vllm/__pycache__/beam_search.cpython-310.pyc b/vllm/__pycache__/beam_search.cpython-310.pyc new file mode 100644 index 0000000..cad0282 Binary files /dev/null and b/vllm/__pycache__/beam_search.cpython-310.pyc differ diff --git a/vllm/__pycache__/block.cpython-310.pyc b/vllm/__pycache__/block.cpython-310.pyc new file mode 100644 index 
0000000..c27aa2b Binary files /dev/null and b/vllm/__pycache__/block.cpython-310.pyc differ diff --git a/vllm/__pycache__/config.cpython-310.pyc b/vllm/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..e54d5ee Binary files /dev/null and b/vllm/__pycache__/config.cpython-310.pyc differ diff --git a/vllm/__pycache__/connections.cpython-310.pyc b/vllm/__pycache__/connections.cpython-310.pyc new file mode 100644 index 0000000..d662859 Binary files /dev/null and b/vllm/__pycache__/connections.cpython-310.pyc differ diff --git a/vllm/__pycache__/envs.cpython-310.pyc b/vllm/__pycache__/envs.cpython-310.pyc new file mode 100644 index 0000000..bde38c9 Binary files /dev/null and b/vllm/__pycache__/envs.cpython-310.pyc differ diff --git a/vllm/__pycache__/forward_context.cpython-310.pyc b/vllm/__pycache__/forward_context.cpython-310.pyc new file mode 100644 index 0000000..539b909 Binary files /dev/null and b/vllm/__pycache__/forward_context.cpython-310.pyc differ diff --git a/vllm/__pycache__/logger.cpython-310.pyc b/vllm/__pycache__/logger.cpython-310.pyc new file mode 100644 index 0000000..15c9df2 Binary files /dev/null and b/vllm/__pycache__/logger.cpython-310.pyc differ diff --git a/vllm/__pycache__/outputs.cpython-310.pyc b/vllm/__pycache__/outputs.cpython-310.pyc new file mode 100644 index 0000000..448d4fe Binary files /dev/null and b/vllm/__pycache__/outputs.cpython-310.pyc differ diff --git a/vllm/__pycache__/pooling_params.cpython-310.pyc b/vllm/__pycache__/pooling_params.cpython-310.pyc new file mode 100644 index 0000000..05d2402 Binary files /dev/null and b/vllm/__pycache__/pooling_params.cpython-310.pyc differ diff --git a/vllm/__pycache__/sampling_params.cpython-310.pyc b/vllm/__pycache__/sampling_params.cpython-310.pyc new file mode 100644 index 0000000..31c99bc Binary files /dev/null and b/vllm/__pycache__/sampling_params.cpython-310.pyc differ diff --git a/vllm/__pycache__/scalar_type.cpython-310.pyc b/vllm/__pycache__/scalar_type.cpython-310.pyc new file mode 100644 index 0000000..8aa5fca Binary files /dev/null and b/vllm/__pycache__/scalar_type.cpython-310.pyc differ diff --git a/vllm/__pycache__/scripts.cpython-310.pyc b/vllm/__pycache__/scripts.cpython-310.pyc new file mode 100644 index 0000000..45eb857 Binary files /dev/null and b/vllm/__pycache__/scripts.cpython-310.pyc differ diff --git a/vllm/__pycache__/sequence.cpython-310.pyc b/vllm/__pycache__/sequence.cpython-310.pyc new file mode 100644 index 0000000..ca6591b Binary files /dev/null and b/vllm/__pycache__/sequence.cpython-310.pyc differ diff --git a/vllm/__pycache__/tracing.cpython-310.pyc b/vllm/__pycache__/tracing.cpython-310.pyc new file mode 100644 index 0000000..b4dd2a9 Binary files /dev/null and b/vllm/__pycache__/tracing.cpython-310.pyc differ diff --git a/vllm/__pycache__/utils.cpython-310.pyc b/vllm/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..3b90b20 Binary files /dev/null and b/vllm/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/__pycache__/version.cpython-310.pyc b/vllm/__pycache__/version.cpython-310.pyc new file mode 100644 index 0000000..de6fe1d Binary files /dev/null and b/vllm/__pycache__/version.cpython-310.pyc differ diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py new file mode 100644 index 0000000..a27b864 --- /dev/null +++ b/vllm/_core_ext.py @@ -0,0 +1,278 @@ +import importlib.util +from enum import Enum +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union + +import torch + +from vllm.logger import init_logger + +logger = 
init_logger(__name__) +core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +if TYPE_CHECKING or not core_C_available: + # On platforms were we cannot use/build the C++ core extension (i.e. namely + # neuron and tpu), we define the mock ScalarType class here that partially + # mimics the C++ ScalarType class. + # + # We also use this provide type signatures to the Python LSP for the methods + # in the C++ ScalarType class. So these type signatures should be kept + # in sync with csrc/core/scalar_type.hpp + + from dataclasses import dataclass + + @dataclass(frozen=True) + class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + _finite_values_only: bool = False + """ + Private: if NANs are supported, used `has_infs()` instead. + """ + + nan_repr: int = NanRepr.IEEE_754.value + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + @property + def size_bits(self): + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + raise NotImplementedError + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + raise NotImplementedError + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + ... 
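+
+        # Illustrative example of the bias convention documented above: a
+        # GPTQ-style unsigned 4-bit type would be ScalarType.uint(4, bias=8),
+        # so a stored value of 0 decodes to -8 and a stored value of 15 to 7.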
+ + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + raise NotImplementedError + + def __repr__(self) -> str: + raise NotImplementedError + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + return cls(size_bits - 1, size_bits, bias if bias else 0, True) + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + return cls(size_bits, size_bits, bias if bias else 0, False) + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + return cls(exponent, mantissa, 0, True) + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: int) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). + """ + return cls(exponent, mantissa, 0, True, finite_values_only, + nan_repr) + +elif core_C_available: + try: + import vllm._core_C # noqa: F401 + except ImportError as e: + logger.warning("Failed to import from vllm._core_C with %r", e) + + ScalarType = torch.classes._core_C.ScalarType + + if (hasattr(torch, "_library") + and hasattr(torch._library, "register_fake_class")): + # Needed for dynamo support of ScalarType. 
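+        # register_fake_class provides a pure-Python stand-in for the C++
+        # torch.classes._core_C.ScalarType object, so torch.compile can trace
+        # code that queries ScalarType properties without instantiating the
+        # real extension class.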
+ @torch._library.register_fake_class("_core_C::ScalarType") + class FakeScalarType: + + def __init__(self, scalar_type): + self.ScalarType = scalar_type + + def bias_getter(self) -> int: + return self.ScalarType.bias + + def exponent_getter(self) -> int: + return self.ScalarType.exponent + + def mantissa_getter(self) -> int: + return self.ScalarType.mantissa + + def signed_getter(self) -> bool: + return self.ScalarType.signed + + def size_bits_getter(self) -> int: + return self.ScalarType.size_bits + + @property + def size_bits(self) -> int: + return self.ScalarType.size_bits + + def min(self) -> Union[int, float]: + return self.ScalarType.min() + + def max(self) -> Union[int, float]: + return self.ScalarType.max() + + def is_signed(self) -> bool: + return self.ScalarType.is_signed() + + def is_floating_point(self) -> bool: + return self.ScalarType.is_floating_point() + + def is_integer(self) -> bool: + return self.ScalarType.is_integer() + + def has_bias(self) -> bool: + return self.ScalarType.has_bias() + + def has_infs(self) -> bool: + return self.ScalarType.has_infs() + + def has_nans(self) -> bool: + return self.ScalarType.has_nans() + + def is_ieee_754(self) -> bool: + return self.ScalarType.is_ieee_754() + + def __str__(self) -> str: + return self.ScalarType.__str__() + + def __repr__(self) -> str: + return self.ScalarType.__repr__() + + def __len__(self) -> int: + return self.ScalarType.__len__() + + def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: + return torch.classes._core_C.ScalarType.__obj_flatten__( + self.ScalarType) + + @classmethod + def __obj_unflatten__( + cls, flat_type: Tuple[Tuple[str, Any], + ...]) -> 'ScalarType': + return cls( + torch.classes._core_C.ScalarType.__obj_unflatten__( + flat_type)) + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + return ScalarType.int_(size_bits, bias) + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + return ScalarType.uint(size_bits, bias) + + @classmethod + def float_IEEE754(cls, exponent: int, + mantissa: int) -> 'ScalarType': + return ScalarType.float_IEEE754(exponent, mantissa) + + @classmethod + def float_(cls, exponent: int, mantissa: int, + finite_values_only: bool, + nan_repr: int) -> 'ScalarType': + return ScalarType.float_(exponent, mantissa, + finite_values_only, nan_repr) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py new file mode 100644 index 0000000..ac4cce9 --- /dev/null +++ b/vllm/_custom_ops.py @@ -0,0 +1,1105 @@ +import contextlib +import functools +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Dict, Any + +import torch +import torch.library + +import vllm.envs as envs +from vllm._core_ext import ScalarType +from vllm.logger import init_logger +from vllm.platforms import current_platform +# import ixformer.inference.functions as ops +import ixformer.functions as ixf_F +from ixformer.distributed import _distributed as cdist +import torch.nn.functional as F + +logger = init_logger(__name__) + +supports_moe_ops = True + +if TYPE_CHECKING: + + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + try: + from torch.library import impl_abstract as register_fake + except: + def register_fake(fn): + return lambda name: fn + + +def hint_on_error(fn): + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + + except NotImplementedError as e: + msg = ( + "Error in calling custom op %s: %s\n" + "Not 
implemented or built, mostly likely because the current current device " + "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set " + "incorrectly while building)") + logger.error(msg, fn.__name__, e) + raise NotImplementedError(msg % (fn.__name__, e)) from e + except AttributeError as e: + msg = ( + "Error in calling custom op %s: %s\n" + "Possibly you have built or installed an obsolete version of vllm.\n" + "Please try a clean build and install of vllm," + "or remove old built files such as vllm/*cpython*.so and build/ ." + ) + logger.error(msg, fn.__name__, e) + raise e + + return wrapper + + +# activation ops +def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ixf_F.silu_and_mul(x, out) + + +def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ixf_F.gelu_and_mul(x, out) + + +def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ixf_F.gelu_tanh_and_mul(x, out) + + +def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(F.gelu(x,approximate="tanh")) + return out + + +def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(F.gelu(x,approximate="tanh")) + return out + + +def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + out.copy_(F.gelu(x,approximate="tanh")) + return out + + + +def paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes=None, + kv_cache_dtype=None, +): + return ixf_F.vllm_single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + raise NotImplementedError() + + +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + raise NotImplementedError() + + +# pos encoding ops +def rotary_embedding( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool, +) -> None: + ixf_F.vllm_rotary_embedding_neox(positions, query, key, head_size, + cos_sin_cache, is_neox) + + +def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ixf_F.vllm_batched_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox, rot_dim, + cos_sin_cache_offsets) + + +# layer norm ops +def rms_norm(out: 
torch.Tensor, input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> None: + ixf_F.rms_norm(input, weight, out, epsilon) + + +def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float, + residual_alpha: Optional[float] = 1) -> None: + ixf_F.fused_add_rms_norm(input, residual, weight, epsilon) + + +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: + """Advance a step on GPU for existing inputs for a multi-step runner""" + return ixf_F.advance_step_flashattn(num_seqs, num_queries, block_size, + input_tokens, + sampled_token_ids, + input_positions, + seq_lens, slot_mapping, + block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + raise NotImplementedError("FIX SOON") + + +# quantization ops +# awq +def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + raise NotImplementedError() + + +def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, + pack_factor, group_size: int = 128) -> torch.Tensor: + return ixf_F.quantized_linear(input, qweight, scales,"awq",32 // pack_factor,qzeros=qzeros,group_size=group_size) + + +# gptq +def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, + b_g_idx: torch.Tensor, use_exllama: bool, + bit: int) -> torch.Tensor: + batch = a.shape[0] + if batch <= 8: + return ixf_F.quantized_linear(a,b_q_weight,b_gptq_scales,"gptq",4,b_gptq_qzeros,None,group_size=128) + o_dtype_str = "fp16" if a.dtype == torch.half else "bf16" + deq_w = ixf_F.quantized_weight_dequant(b_q_weight,b_gptq_scales,"gptq",o_dtype_str,4,b_gptq_qzeros,group_size=128) + return torch.matmul(a,deq_w) + + +if hasattr(torch.ops._C, "gptq_gemm"): + + @register_fake("_C::gptq_gemm") + def _gptq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_gptq_qzeros: torch.Tensor, + b_gptq_scales: torch.Tensor, b_g_idx: torch.Tensor, + use_exllama: bool, bit: int) -> torch.Tensor: + return torch.empty((a.size(0), b_q_weight.size(1)), + dtype=a.dtype, + device=a.device) + + +def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, + bit: int) -> None: + return ixf_F.vllm_gptq_shuffle(q_weight,q_perm) + + +# marlin +def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + raise NotImplementedError() + + +# marlin_24 +def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, b_q_type: ScalarType, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + raise NotImplementedError() + + +if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): + + @register_fake("_C::gptq_marlin_24_gemm") + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, 
b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: int, + size_n: int, size_k: int) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::gptq_marlin_gemm") + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int, + n: int) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::marlin_qqq_gemm") + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, + size_k: int) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::marlin_gemm") + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, + size_k: int) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake("_C::awq_dequantize") + def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor, + zeros: torch.Tensor, split_k_iters: int, thx: int, + thy: int) -> torch.Tensor: + in_c = qweight.size(0) + qout_c = qweight.size(1) + out_c = qout_c * 8 + return torch.empty((in_c, out_c), + dtype=scales.dtype, + device=scales.device) + + @register_fake("_C::awq_gemm") + def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor, + qzeros: torch.Tensor, scales: torch.Tensor, + split_k_iters: int) -> torch.Tensor: + num_in_feats = input.size(0) + return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8), + dtype=input.dtype, + device=input.device).sum(0) + + @register_fake("_C::aqlm_gemm") + def _aqlm_gemm_fake(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: List[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + out_features = codes.size(0) * codebooks.size(2) + flat_input = input.reshape((-1, input.size(-1))) + flat_output = torch.empty((flat_input.size(0), out_features), + dtype=input.dtype, + device=input.device) + + output_sizes = list(input.shape) + output_sizes.pop() + output_sizes.append(-1) + return flat_output.reshape(tuple(output_sizes)) + + @register_fake("_C::aqlm_dequant") + def _aqlm_dequant_fake( + codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: List[int]) -> torch.Tensor: + in_features = codes.size(1) * 8 + out_features = codes.size(0) + return 
torch.empty((out_features, in_features), + dtype=codebooks.dtype, + device=codebooks.device) + + @register_fake("_C::fp8_marlin_gemm") + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, + size_k: int) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake("_C::machete_gemm") + def machete_gemm_fake( + a: torch.Tensor, + # Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, + ) -> torch.Tensor: + m = a.size(0) + n = b_q.size(1) + return torch.empty((m, n), device=a.device, dtype=a.dtype) + + @register_fake("_C::machete_prepack_B") + def machete_prepack_B_fake(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + return torch.empty_like(b_q_weight, + memory_format=torch.contiguous_format) + + @register_fake("_C::causal_conv1d_fwd") + def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + cu_seq_len: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.empty_like(x) + + @register_fake("_C::causal_conv1d_update") + def causal_conv1d_update_fake( + x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: + return torch.empty_like(x) + + @register_fake("_C::selective_scan_fwd") + def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor, + A: torch.Tensor, B: torch.Tensor, + C: torch.Tensor, D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, + cu_seq_len: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + ssm_states: Optional[torch.Tensor]) -> None: + return None + + +# cutlass +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return True + + +def cutlass_scaled_mm(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + ixf_F.w8a8(a, b.transpose(0,1), scale_a, scale_b, bias, output=out, out_dtype=out_dtype) + + return out + + +def cutlass_scaled_mm_azp(a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError() + + +# aqlm +def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor, + codebooks: torch.Tensor, scales: torch.Tensor, + codebook_partition_sizes: List[int], + bias: Optional[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError() + + +def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, + codebook_partition_sizes: List[int]) -> torch.Tensor: + raise NotImplementedError() + + +# 
gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError() + + +# gptq_marlin +def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError() + + +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + raise NotImplementedError() + + +def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * (num_bits // 2)), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.awq_marlin_repack(b_q_weight[e], size_k, + size_n, num_bits) + return output + + +def gptq_marlin_gemm(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False) -> torch.Tensor: + raise NotImplementedError() + + +# fp8 marlin +def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: int, size_n: int, + size_k: int) -> torch.Tensor: + raise NotImplementedError() + + +# machete +def machete_supported_schedules(b_type: ScalarType) -> List[str]: + raise NotImplementedError() + + +def machete_gemm( + a: torch.Tensor, + b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B + b_type: ScalarType, + b_scales: Optional[torch.Tensor] = None, + b_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + c: Optional[torch.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + raise NotImplementedError() + + +def machete_prepack_B(b_q_weight: torch.Tensor, + b_type: ScalarType) -> torch.Tensor: + raise NotImplementedError() + + +if hasattr(torch.ops._C, "permute_cols"): + + @register_fake("_C::permute_cols") + def _permute_cols_fake(a: torch.Tensor, + perm: torch.Tensor) -> torch.Tensor: + return torch.empty_like(a) + + +def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: + raise NotImplementedError() + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. 
+ use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + raise NotImplementedError() + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is + None), "azp must only be provided for asymmetric quantization." + ixf_F.static_scaled_int8_quant(output, input, scale) + return output, scale, None + + # dynamic-per-token quantization. + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ixf_F.dynamic_scaled_int8_quant(output, input, input_scales) + return output, input_scales, input_azp + + +# qqq ops +def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: int, size_n: int, size_k: int) -> torch.Tensor: + raise NotImplementedError() + + +# gguf +def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, + n: int) -> torch.Tensor: + raise NotImplementedError() + + +def ggml_mul_mat_vec_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + raise NotImplementedError() + + +def ggml_mul_mat_a8( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: int, +) -> torch.Tensor: + raise NotImplementedError() + + +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + conv_states: Optional[torch.Tensor], + query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + raise NotImplementedError() + + +def causal_conv1d_update( + x: torch.Tensor, conv_state: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], silu_activation: bool, + cache_seqlens: Optional[torch.Tensor], + conv_state_indices: Optional[torch.Tensor]) -> torch.Tensor: + raise NotImplementedError() + + +def selective_scan_fwd( + u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, B: torch.Tensor, + C: torch.Tensor, D_: Optional[torch.Tensor], + z_: Optional[torch.Tensor], delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, query_start_loc: Optional[torch.Tensor], + cache_indices: Optional[torch.Tensor], + has_initial_state: Optional[torch.Tensor], ssm_states: torch.Tensor): + raise NotImplementedError() + + +# moe +def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, + 
block_size: int, sorted_token_ids: torch.Tensor, + experts_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + ixf_F.vllm_moe_align_block_size(topk_ids, num_experts, block_size, + sorted_token_ids, experts_ids, + num_tokens_post_pad) + + +def invoke_fused_moe_kernel( + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, + top_k: int, + config: Dict[str, Any], + compute_type, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, +) -> None: + ixf_F.vllm_invoke_fused_moe_kernel( + A, + B, + C, + topk_weights, + topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + mul_routed_weight, + top_k, + config['BLOCK_SIZE_M'] + ) + + +def topk_softmax(topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indicies: torch.Tensor, + gating_output: float) -> None: + ixf_F.vllm_moe_topk_softmax(topk_weights, topk_ids, + token_expert_indicies, gating_output) + + +if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"): + + @register_fake("_moe_C::marlin_gemm_moe") + def marlin_gemm_moe_fake(a: torch.Tensor, b_q_weights: torch.Tensor, + sorted_ids: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, b_scales: torch.Tensor, + b_zero_points: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, + b_q_type: ScalarType, size_m: int, size_n: int, + size_k: int, is_k_full: bool, num_experts: int, + topk: int, moe_block_size: int, + replicate_input: bool, + apply_weights: bool) -> torch.Tensor: + return torch.empty((size_m, topk, size_n), + dtype=a.dtype, + device=a.device) + + +def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + slot_mapping = slot_mapping.to(torch.int32) + ixf_F.vllm_cache_ops_reshape_and_cache(key, value, key_cache, + value_cache, slot_mapping) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ixf_F.reshape_and_cache_flash(key, value, key_cache, + value_cache, slot_mapping, + kv_cache_dtype, k_scale, + v_scale) + +def reshape_and_cache_flashinfer( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, # for fp8 + v_scale: float, # for fp8 + kv_cache_format: str = "NHD", + key_cache_scales: torch.Tensor = None, # for int8 + value_cache_scales: torch.Tensor = None, # for int8 +) -> None: + ixf_F.paged_attention_cache_appended( + key, + value, + key_cache, + value_cache, + slot_mapping, + kv_cache_format, + key_cache_scales, + value_cache_scales, + ) + +def copy_blocks(key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor) -> None: + ixf_F.copy_blocks(key_caches, value_caches, block_mapping) + + +def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + ixf_F.swap_blocks(src, dst, block_mapping) + + +def convert_fp8(output: torch.Tensor, + input: torch.Tensor, + scale: float = 1.0, + kv_dtype: str = 
"fp8") -> None: + raise NotImplementedError() + + +def get_device_attribute(attribute: int, device: int) -> int: + raise NotImplementedError() + + +def get_max_shared_memory_per_block_device_attribute(device: int) -> int: + return 32 * 1024 + + +# custom ar +def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, + handles: List[str], offsets: List[int], rank: int, + full_nvlink: bool) -> int: + raise NotImplementedError() + + +def should_custom_ar(inp: torch.Tensor, max_size: int, world_size: int, + full_nvlink: bool) -> bool: + raise NotImplementedError() + + +def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + raise NotImplementedError() + + +def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, + out: torch.Tensor) -> None: + raise NotImplementedError() + + +def dispose(fa: int) -> None: + raise NotImplementedError() + + +def meta_size() -> int: + raise NotImplementedError() + + +def register_buffer(fa: int, t: torch.Tensor, handles: List[str], + offsets: List[int]) -> None: + raise NotImplementedError() + + +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[str], List[int]]: + raise NotImplementedError() + + +def register_graph_buffers(fa: int, handles: List[str], + offsets: List[List[int]]) -> None: + raise NotImplementedError() + + +# Add our new features here.. + +# broadcast +class Async_helper(): + # For now, the comm and the other kernels are in the same stream, so we can remove the stream wait.. + def wait(self,): + return True + + +def broadcast(tensor, src=0, group=None, async_op=False): + cdist.broadcast(tensor,src,group,async_op=True) + if async_op: + return Async_helper() + else: + pass + +# w8a16 +def linear_w8a16(x: torch.Tensor, qweight: torch.Tensor, scales:torch.Tensor, + group_size: int = -1, format: str = "TN")-> torch.Tensor: + return ixf_F.w8a16(x, qweight, scales, format="TN", group_size=group_size) + + +## lora sgmv / bgmv +def sbgmv_expand(x: torch.Tensor, + w_t_all: torch.Tensor, + y: torch.Tensor, + b_seq_start_loc: torch.Tensor = None, + seq_len_tensor: torch.Tensor = None, + lora_indices_tensor: torch.Tensor = None, + batches: int = -1, + max_seq_length: int = -1, + token_nums: int = -1, + add_input=True, + ): + ''' + x: inputs + w_t_all: lora weight + y: output + + y += x@wt_t_all + ''' + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert w_t_all.dtype in [ + torch.float16, + torch.bfloat16, + ] + + assert x.is_contiguous() + # assert y.is_contiguous() + if x.dtype == torch.float: + x = x.to(w_t_all.dtype) + + if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) + assert w_t_all.size(1) == 1 + w_t_all = w_t_all.squeeze(dim=1) + else: + assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) + assert w_t_all.is_contiguous() + + assert add_input == True + + lora_indices = lora_indices_tensor.cpu().tolist() + lora_num = w_t_all.shape[0] + + ## 单一lora model, 且所有request均使用lora + if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): + if lora_indices[0] != -1: + w_t = w_t_all[0] + y += torch.matmul(x, w_t.t()) + ## 多个lora model + else: + ## prefill + if batches != -1: + for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor): + if lora_id != -1: + xi = x[start: start+seq_len] + w_t = w_t_all[lora_id] + y[start:start+seq_len] += (xi @ w_t.t()) + ## decode + else: + batches = x.shape[0] + for i, lora_id in zip(range(batches), lora_indices): + if lora_id != -1: + xi = x[i].unsqueeze(0) + w_t = w_t_all[lora_id] + y[i] += (xi @ 
w_t.t()).squeeze(0) + + return y + + +def sbgmv_shrink(x: torch.Tensor, + w_t_all: torch.Tensor, + y: torch.Tensor, + b_seq_start_loc: torch.Tensor = None, + seq_len_tensor: torch.Tensor = None, + lora_indices_tensor: torch.Tensor = None, + batches: int = -1, + max_seq_length: int = -1, + token_nums: int = -1, + scale: float = 1.0,): + """ + x: inputs + w_t_all: lora weight + y: output + scale: float + + y = x@w_t_all * scale + """ + assert x.dtype == w_t_all.dtype + assert x.dtype in [torch.float16, torch.bfloat16] + assert x.is_contiguous() + assert y.is_contiguous() + + if w_t_all.ndim == 4: # shape:(lora_num,1,size,rank) + assert w_t_all.size(1) == 1 + w_t_all = w_t_all.squeeze(dim=1) + else: + assert w_t_all.ndim == 3 # shape:(lora_num,size,rank) + assert w_t_all.is_contiguous() + + lora_num = w_t_all.shape[0] + lora_indices = lora_indices_tensor.cpu().tolist() + + ## single LoRA model, and all requests use LoRA + if lora_num == 1 and all(x == lora_indices[0] for x in lora_indices): + if lora_indices[0] != -1: + w_t = w_t_all[0] + y = torch.matmul(x, w_t.t()) * scale + ## multiple LoRA models + else: + ## prefill + if batches != -1: + for i, lora_id, start, seq_len in zip(range(batches), lora_indices, b_seq_start_loc, seq_len_tensor): + if lora_id != -1: + xi = x[start: start+seq_len] + w_t = w_t_all[lora_id] + y[start:start+seq_len] = (xi @ w_t.t())* scale + ## decode + else: + batches = x.shape[0] + for i, lora_id in zip(range(batches), lora_indices): + if lora_id != -1: + xi = x[i].unsqueeze(0) + w_t = w_t_all[lora_id] + y[i] = (xi @ w_t.t()).squeeze(0) * scale + + return y + +# temporary fix for https://github.com/vllm-project/vllm/issues/5456 +# TODO: remove this in v0.6.0 +names_and_values = globals() +names_and_values_to_update = {} +# prepare variables to avoid dict size change during iteration +k, v, arg = None, None, None +fn_type = type(lambda x: x) +for k, v in names_and_values.items(): + # find functions that are defined in this file and have torch.Tensor + # in their annotations. `arg == "torch.Tensor"` is used to handle + # the case when users use `from __future__ import annotations` to turn type + # hints into strings.
+ if isinstance(v, fn_type) \ + and v.__code__.co_filename == __file__ \ + and any(arg is torch.Tensor or arg == "torch.Tensor" + for arg in v.__annotations__.values()): + names_and_values_to_update[k] = hint_on_error(v) + +names_and_values.update(names_and_values_to_update) +del names_and_values_to_update, names_and_values, v, k, fn_type diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py new file mode 100644 index 0000000..31fcc4c --- /dev/null +++ b/vllm/_ipex_ops.py @@ -0,0 +1,243 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +try: + import intel_extension_for_pytorch as ipex +except ImportError as e: + logger.warning("Import error msg: %s", e.msg) + + +class ipex_ops: + + @staticmethod + def _reshape_activation_tensor( + x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + num = x.size(0) + d = x.size(1) // 2 + x = x.reshape(num, 2, d) + x1, x2 = torch.chunk(x, chunks=2, dim=1) + x1 = x1.reshape(num, d) + x2 = x2.reshape(num, d) + return x1, x2 + + @staticmethod + def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.silu_and_mul(x, out) + + @staticmethod + def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_and_mul(x, out) + + @staticmethod + def gelu_fast(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_new(x: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.gelu(x) + + @staticmethod + def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: + ipex.llm.functional.gelu_quick(x, out) + + @staticmethod + def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + head_mapping = torch.arange( + 0, + num_kv_heads, + device=query.device, + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(num_queries_per_tokens).flatten() + # todo: ipex will refactor namespace + torch.xpu.paged_attention_v1( # type: ignore + out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + block_size: int, + max_context_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> None: + assert 
kv_cache_dtype == "auto" + num_heads = out.size(1) + num_queries_per_tokens = num_heads // num_kv_heads + head_mapping = torch.arange( + 0, + num_kv_heads, + dtype=torch.int32, + device=query.device, + ).view(num_kv_heads, + 1).repeat_interleave(num_queries_per_tokens).flatten() + # todo: ipex will refactor namespace + torch.xpu.paged_attention_v2( # type: ignore + out, + exp_sum, + max_logits, + tmp_out, + query.contiguous(), + key_cache.view_as(value_cache), + value_cache, + head_mapping, + block_tables, + context_lens, + scale, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def rotary_embedding( + positions: torch.Tensor, # [batch_size, seq_len] + query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size] + key: torch.Tensor, # [batch_size, seq_len, num_kv_heads*head_size] + head_size: int, + cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] + is_neox: bool, + ) -> None: + rot_dim = cos_sin_cache.size(1) + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim) + + @staticmethod + def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, + key: torch.Tensor, head_size: int, + cos_sin_cache: torch.Tensor, is_neox: bool, + rot_dim: int, + cos_sin_cache_offsets: torch.Tensor) -> None: + ipex.llm.functional.rotary_embedding_batched(positions, query, key, + head_size, cos_sin_cache, + is_neox, rot_dim, + cos_sin_cache_offsets) + + @staticmethod + def rms_norm(input: torch.Tensor, weight: torch.Tensor, + epsilon: float) -> torch.Tensor: + return ipex.llm.functional.rms_norm(input, weight, epsilon) + + @staticmethod + def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, + weight: torch.Tensor, epsilon: float) -> None: + tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None, + epsilon, True) + input.copy_(tmp) + + @staticmethod + def varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + seqlen_q: torch.Tensor, + seqlen_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + pdropout: float, + softmax_scale: float, + zero_tensors: bool, + is_causal: bool, + return_softmax: bool, + gen_: torch.Generator, + ) -> None: + ipex.llm.functional.varlen_attention(query.contiguous(), + key.contiguous(), + value.contiguous(), out, + seqlen_q.int(), seqlen_k.int(), + max_seqlen_q, max_seqlen_k, + pdropout, softmax_scale, + zero_tensors, is_causal, + return_softmax, gen_) + + @staticmethod + def reshape_and_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + assert kv_cache_dtype == "auto" + ipex.llm.modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def copy_blocks(key_caches: List[torch.Tensor], + value_caches: List[torch.Tensor], + block_mapping: torch.Tensor) -> None: + torch.xpu.copy_blocks( # type: ignore + key_caches, + value_caches, + block_mapping, + ) + + @staticmethod + def swap_blocks(src: torch.Tensor, dst: torch.Tensor, + block_mapping: torch.Tensor) -> None: + torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/adapter_commons/__pycache__/__init__.cpython-310.pyc b/vllm/adapter_commons/__pycache__/__init__.cpython-310.pyc new 
file mode 100644 index 0000000..e3d7395 Binary files /dev/null and b/vllm/adapter_commons/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/adapter_commons/__pycache__/layers.cpython-310.pyc b/vllm/adapter_commons/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000..58eba13 Binary files /dev/null and b/vllm/adapter_commons/__pycache__/layers.cpython-310.pyc differ diff --git a/vllm/adapter_commons/__pycache__/models.cpython-310.pyc b/vllm/adapter_commons/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000..c430e4e Binary files /dev/null and b/vllm/adapter_commons/__pycache__/models.cpython-310.pyc differ diff --git a/vllm/adapter_commons/__pycache__/request.cpython-310.pyc b/vllm/adapter_commons/__pycache__/request.cpython-310.pyc new file mode 100644 index 0000000..46b7e60 Binary files /dev/null and b/vllm/adapter_commons/__pycache__/request.cpython-310.pyc differ diff --git a/vllm/adapter_commons/__pycache__/utils.cpython-310.pyc b/vllm/adapter_commons/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..bd20645 Binary files /dev/null and b/vllm/adapter_commons/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/adapter_commons/__pycache__/worker_manager.cpython-310.pyc b/vllm/adapter_commons/__pycache__/worker_manager.cpython-310.pyc new file mode 100644 index 0000000..0637c3c Binary files /dev/null and b/vllm/adapter_commons/__pycache__/worker_manager.cpython-310.pyc differ diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 0000000..3ed6067 --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 0000000..a5c04ab --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Hashable, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: Hashable, value: Optional[T]): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. 
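+        # (A Dict[int, None] keyed by adapter id acts as an insertion-ordered
+        #  set with O(1) membership checks, and keeps the same dict-style
+        #  interface as the LRUCache used by cache-backed managers.)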
+ self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def capacity(self) -> int: + raise NotImplementedError + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + raise NotImplementedError + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 0000000..2bb17fd --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod + + +class AdapterRequest(ABC): + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self) -> int: + raise NotImplementedError + + def __post_init__(self) -> None: + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 0000000..1e9adca --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,90 @@ +from typing import Any, Callable, Dict, Optional, Set + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id) + + +## worker functions +def set_active_adapters_worker(requests: Set[Any], 
mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 0000000..83929e8 --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Set + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + raise NotImplementedError + + @abstractmethod + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + raise NotImplementedError + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_all_adapters(self) -> None: + raise NotImplementedError + + @abstractmethod + def list_adapters(self) -> Set[int]: + raise NotImplementedError diff --git a/vllm/assets/__init__.py b/vllm/assets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/assets/__pycache__/__init__.cpython-310.pyc b/vllm/assets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6f5c7c6 Binary files /dev/null and b/vllm/assets/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/assets/__pycache__/audio.cpython-310.pyc b/vllm/assets/__pycache__/audio.cpython-310.pyc new file mode 100644 index 0000000..4a4f7a8 Binary files /dev/null and b/vllm/assets/__pycache__/audio.cpython-310.pyc differ diff --git a/vllm/assets/__pycache__/base.cpython-310.pyc b/vllm/assets/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..a5a838c Binary files /dev/null and b/vllm/assets/__pycache__/base.cpython-310.pyc differ diff --git a/vllm/assets/__pycache__/image.cpython-310.pyc b/vllm/assets/__pycache__/image.cpython-310.pyc new file mode 100644 index 0000000..644befd Binary files /dev/null and b/vllm/assets/__pycache__/image.cpython-310.pyc differ diff --git 
a/vllm/assets/__pycache__/video.cpython-310.pyc b/vllm/assets/__pycache__/video.cpython-310.pyc new file mode 100644 index 0000000..6d2fbac Binary files /dev/null and b/vllm/assets/__pycache__/video.cpython-310.pyc differ diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py new file mode 100644 index 0000000..49bb6ae --- /dev/null +++ b/vllm/assets/audio.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass +from typing import Literal, Tuple +from urllib.parse import urljoin + +import librosa +import numpy as np + +from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL + +ASSET_DIR = "multimodal_asset" + + +@dataclass(frozen=True) +class AudioAsset: + name: Literal["winning_call", "mary_had_lamb"] + + @property + def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]: + + audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + y, sr = librosa.load(audio_path, sr=None) + assert isinstance(sr, int) + return y, sr + + @property + def url(self) -> str: + return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/assets/base.py b/vllm/assets/base.py new file mode 100644 index 0000000..f97e8c2 --- /dev/null +++ b/vllm/assets/base.py @@ -0,0 +1,39 @@ +from functools import lru_cache +from pathlib import Path +from typing import Optional + +import vllm.envs as envs +from vllm.connections import global_http_connection +from vllm.envs import VLLM_IMAGE_FETCH_TIMEOUT + +vLLM_S3_BUCKET_URL = "https://vllm-public-assets.s3.us-west-2.amazonaws.com" + + +def get_cache_dir() -> Path: + """Get the path to the cache for storing downloaded assets.""" + path = Path(envs.VLLM_ASSETS_CACHE) + path.mkdir(parents=True, exist_ok=True) + + return path + + +@lru_cache +def get_vllm_public_assets(filename: str, + s3_prefix: Optional[str] = None) -> Path: + """ + Download an asset file from ``s3://vllm-public-assets`` + and return the path to the downloaded file. + """ + asset_directory = get_cache_dir() / "vllm_public_assets" + asset_directory.mkdir(parents=True, exist_ok=True) + + asset_path = asset_directory / filename + if not asset_path.exists(): + if s3_prefix is not None: + filename = s3_prefix + "/" + filename + global_http_connection.download_file( + f"{vLLM_S3_BUCKET_URL}/{filename}", + asset_path, + timeout=VLLM_IMAGE_FETCH_TIMEOUT) + + return asset_path diff --git a/vllm/assets/image.py b/vllm/assets/image.py new file mode 100644 index 0000000..5eec78c --- /dev/null +++ b/vllm/assets/image.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass +from typing import Literal + +import torch +from PIL import Image + +from vllm.assets.base import get_vllm_public_assets + +VLM_IMAGES_DIR = "vision_model_images" + + +@dataclass(frozen=True) +class ImageAsset: + name: Literal["stop_sign", "cherry_blossom"] + + @property + def pil_image(self) -> Image.Image: + + image_path = get_vllm_public_assets(filename=f"{self.name}.jpg", + s3_prefix=VLM_IMAGES_DIR) + return Image.open(image_path) + + @property + def image_embeds(self) -> torch.Tensor: + """ + Image embeddings, only used for testing purposes with llava 1.5. 
+ """ + image_path = get_vllm_public_assets(filename=f"{self.name}.pt", + s3_prefix=VLM_IMAGES_DIR) + return torch.load(image_path) diff --git a/vllm/assets/video.py b/vllm/assets/video.py new file mode 100644 index 0000000..05e031a --- /dev/null +++ b/vllm/assets/video.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from functools import lru_cache +from typing import List, Literal + +import numpy as np +import numpy.typing as npt +from huggingface_hub import hf_hub_download +from PIL import Image + +from vllm.multimodal.utils import (sample_frames_from_video, + try_import_video_packages) + +from .base import get_cache_dir + + +@lru_cache +def download_video_asset(filename: str) -> str: + """ + Download and open an image from huggingface + repo: raushan-testing-hf/videos-test + """ + video_directory = get_cache_dir() / "video-eample-data" + video_directory.mkdir(parents=True, exist_ok=True) + + video_path = video_directory / filename + video_path_str = str(video_path) + if not video_path.exists(): + video_path_str = hf_hub_download( + repo_id="raushan-testing-hf/videos-test", + filename=filename, + repo_type="dataset", + cache_dir=video_directory, + ) + return video_path_str + + +def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray: + cv2 = try_import_video_packages() + + cap = cv2.VideoCapture(path) + if not cap.isOpened(): + raise ValueError(f"Could not open video file {path}") + + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + frames = [] + for i in range(total_frames): + ret, frame = cap.read() + if ret: + frames.append(frame) + cap.release() + + frames = np.stack(frames) + frames = sample_frames_from_video(frames, num_frames) + if len(frames) < num_frames: + raise ValueError(f"Could not read enough frames from video file {path}" + f" (expected {num_frames} frames, got {len(frames)})") + return frames + + +def video_to_pil_images_list(path: str, + num_frames: int = -1) -> List[Image.Image]: + cv2 = try_import_video_packages() + frames = video_to_ndarrays(path, num_frames) + return [ + Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + for frame in frames + ] + + +@dataclass(frozen=True) +class VideoAsset: + name: Literal["sample_demo_1.mp4"] + num_frames: int = -1 + + @property + def pil_images(self) -> List[Image.Image]: + video_path = download_video_asset(self.name) + ret = video_to_pil_images_list(video_path, self.num_frames) + return ret + + @property + def np_ndarrays(self) -> npt.NDArray: + video_path = download_video_asset(self.name) + ret = video_to_ndarrays(video_path, self.num_frames) + return ret diff --git a/vllm/attention/__init__.py b/vllm/attention/__init__.py new file mode 100644 index 0000000..2cd4ad3 --- /dev/null +++ b/vllm/attention/__init__.py @@ -0,0 +1,17 @@ +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.layer import Attention +from vllm.attention.selector import get_attn_backend + +__all__ = [ + "Attention", + "AttentionBackend", + "AttentionMetadata", + "AttentionType", + "AttentionMetadataBuilder", + "Attention", + "AttentionState", + "get_attn_backend", +] diff --git a/vllm/attention/__pycache__/__init__.cpython-310.pyc b/vllm/attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..080c3e8 Binary files /dev/null and b/vllm/attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/attention/__pycache__/layer.cpython-310.pyc 
b/vllm/attention/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000..29f52c1 Binary files /dev/null and b/vllm/attention/__pycache__/layer.cpython-310.pyc differ diff --git a/vllm/attention/__pycache__/selector.cpython-310.pyc b/vllm/attention/__pycache__/selector.cpython-310.pyc new file mode 100644 index 0000000..e6fdecb Binary files /dev/null and b/vllm/attention/__pycache__/selector.cpython-310.pyc differ diff --git a/vllm/attention/backends/__init__.py b/vllm/attention/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/backends/__pycache__/__init__.cpython-310.pyc b/vllm/attention/backends/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..19f3f54 Binary files /dev/null and b/vllm/attention/backends/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/abstract.cpython-310.pyc b/vllm/attention/backends/__pycache__/abstract.cpython-310.pyc new file mode 100644 index 0000000..bb16759 Binary files /dev/null and b/vllm/attention/backends/__pycache__/abstract.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-310.pyc b/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-310.pyc new file mode 100644 index 0000000..15f6fca Binary files /dev/null and b/vllm/attention/backends/__pycache__/blocksparse_attn.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/flash_attn.cpython-310.pyc b/vllm/attention/backends/__pycache__/flash_attn.cpython-310.pyc new file mode 100644 index 0000000..810cba2 Binary files /dev/null and b/vllm/attention/backends/__pycache__/flash_attn.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/flashinfer.cpython-310.pyc b/vllm/attention/backends/__pycache__/flashinfer.cpython-310.pyc new file mode 100644 index 0000000..efdfb57 Binary files /dev/null and b/vllm/attention/backends/__pycache__/flashinfer.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/ipex_attn.cpython-310.pyc b/vllm/attention/backends/__pycache__/ipex_attn.cpython-310.pyc new file mode 100644 index 0000000..9721229 Binary files /dev/null and b/vllm/attention/backends/__pycache__/ipex_attn.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/openvino.cpython-310.pyc b/vllm/attention/backends/__pycache__/openvino.cpython-310.pyc new file mode 100644 index 0000000..2b35a08 Binary files /dev/null and b/vllm/attention/backends/__pycache__/openvino.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/pallas.cpython-310.pyc b/vllm/attention/backends/__pycache__/pallas.cpython-310.pyc new file mode 100644 index 0000000..08e0554 Binary files /dev/null and b/vllm/attention/backends/__pycache__/pallas.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/placeholder_attn.cpython-310.pyc b/vllm/attention/backends/__pycache__/placeholder_attn.cpython-310.pyc new file mode 100644 index 0000000..5ce3aa7 Binary files /dev/null and b/vllm/attention/backends/__pycache__/placeholder_attn.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-310.pyc b/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-310.pyc new file mode 100644 index 0000000..d354027 Binary files /dev/null and b/vllm/attention/backends/__pycache__/rocm_flash_attn.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/torch_sdpa.cpython-310.pyc 
b/vllm/attention/backends/__pycache__/torch_sdpa.cpython-310.pyc new file mode 100644 index 0000000..d0524c2 Binary files /dev/null and b/vllm/attention/backends/__pycache__/torch_sdpa.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/utils.cpython-310.pyc b/vllm/attention/backends/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..a1e7d37 Binary files /dev/null and b/vllm/attention/backends/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/attention/backends/__pycache__/xformers.cpython-310.pyc b/vllm/attention/backends/__pycache__/xformers.cpython-310.pyc new file mode 100644 index 0000000..19ce106 Binary files /dev/null and b/vllm/attention/backends/__pycache__/xformers.cpython-310.pyc differ diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py new file mode 100644 index 0000000..2bc36ff --- /dev/null +++ b/vllm/attention/backends/abstract.py @@ -0,0 +1,232 @@ +from abc import ABC, abstractmethod +from contextlib import contextmanager +from dataclasses import dataclass, fields +from enum import Enum, auto +from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, + Tuple, Type, TypeVar) + +import torch + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + ModelRunnerInputBuilderBase) + + +class AttentionType(Enum): + DECODER = auto() # Decoder attention between previous layer Q/K/V + ENCODER = auto() # Encoder attention between previous layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_state_cls() -> Type["AttentionState"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise NotImplementedError + + @staticmethod + @abstractmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + raise NotImplementedError + + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: + raise NotImplementedError + + +@dataclass +class AttentionMetadata: + """Attention metadata for prefill and decode batched together.""" + # Total number of prefill requests. + num_prefills: int + # Number of prefill tokens. + num_prefill_tokens: int + # Number of decode tokens. 
Note that it is equivalent to the number of + # decode requests. + num_decode_tokens: int + # (num_tokens,). The indices of the token slots that input tokens will be + # stored into. E.g., if `slot_mapping` is [35, 2, 17] and the block size + # is 16, the three tokens are stored in the 3rd slot in block 2, 2nd slot + # in block 0, and 1st slot in block 1, respectively. + slot_mapping: torch.Tensor + + @property + @abstractmethod + def prefill_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run prefill + attention.""" + pass + + @property + @abstractmethod + def decode_metadata(self) -> Optional["AttentionMetadata"]: + """Return the attention metadata that's required to run decode + attention.""" + pass + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + """Similar to dataclasses.asdict, but avoids deepcopying.""" + if skip_fields is None: + skip_fields = set() + # Note that if we add dataclasses as fields, they will need + # similar handling. + return { + field.name: getattr(self, field.name) + for field in fields(self) if field.name not in skip_fields + } + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionState(ABC, Generic[T]): + """Holds attention backend-specific objects reused during the + lifetime of the model runner.""" + + @abstractmethod + def __init__(self, runner: "ModelRunnerBase"): + ... + + @abstractmethod + @contextmanager + def graph_capture(self, max_batch_size: int): + """Context manager used when capturing CUDA graphs.""" + yield + + @abstractmethod + def graph_clone(self, batch_size: int) -> "AttentionState[T]": + """Clone attention state to save in CUDA graph metadata.""" + ... + + @abstractmethod + def graph_capture_get_metadata_for_batch( + self, + batch_size: int, + is_encoder_decoder_model: bool = False) -> T: + """Get attention metadata for CUDA graph capture of batch_size.""" + ... + + @abstractmethod + def get_graph_input_buffers( + self, + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + """Get attention-specific input buffers for CUDA graph capture.""" + ... + + @abstractmethod + def prepare_graph_input_buffers( + self, + input_buffers: Dict[str, Any], + attn_metadata: T, + is_encoder_decoder_model: bool = False) -> None: + """In-place modify input buffers dict for CUDA graph replay.""" + ... + + @abstractmethod + def begin_forward(self, model_input: "ModelRunnerInputBase") -> None: + """Prepare state for forward pass.""" + ... 
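As a side note, the short sketch below is not part of this patch or of vLLM; the helper name `decode_slot_mapping` is made up for illustration only. It shows how the flat `slot_mapping` indices described in `AttentionMetadata` above decompose into (block index, in-block offset) pairs, reusing the `torch` and `typing` imports already present at the top of this file.

def decode_slot_mapping(slot_mapping: torch.Tensor,
                        block_size: int) -> List[Tuple[int, int]]:
    # Each flat slot index s addresses block s // block_size at offset
    # s % block_size within that block.
    blocks = torch.div(slot_mapping, block_size, rounding_mode="floor")
    offsets = slot_mapping % block_size
    return list(zip(blocks.tolist(), offsets.tolist()))

# e.g. decode_slot_mapping(torch.tensor([35, 2, 17]), block_size=16)
# -> [(2, 3), (0, 2), (1, 1)]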
+ + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None: + raise NotImplementedError + + @abstractmethod + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> T: + """Build attention metadata with on-device tensors.""" + raise NotImplementedError + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: T, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py new file mode 100644 index 0000000..c216d19 --- /dev/null +++ b/vllm/attention/backends/blocksparse_attn.py @@ -0,0 +1,444 @@ +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.blocksparse_attention.interface import ( + LocalStridedBlockSparseAttn, get_head_sliding_step) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) + + +@dataclass +class BlocksparseParams: + max_seqlen: int + + # Num q heads per tensor-parallel rank/partition + num_heads: int # per TP partition + # Num kv heads per tensor-parallel rank/partition + num_kv_heads: int + + # block size used for blocksparse attention. + # This is the block_size used in `local_blocks`, `vert_stride`. + block_size: int + + # Number of blocks for local attention, i.e., number of + # local attended tokens / `sparse_block_size` + local_blocks: int + + # Attend to one block per every `vert_stride` blocks. + # Controlling the sparsity + vert_stride: int + """ + If to use the same vertical stride offset for all heads, + i.e., attend to the same block of tokens on all heads. + By default, it is False, i.e., attention on the non-local + blocks depends on the `head_idx`, that is on + blocks satisfying + `(block_idx + head_idx * head_sliding_step + 1) % vert_stride == 0` + where `head_sliding_step=max(1, int(vert_stride / num_total_heads))`, + `block_idx = position_id // sparse_block_size`. + See `..ops.blocksparse_attention.utils:get_sparse_attn_mask` + for more detail. + """ + homo_head: bool = False + + # If within a group, the kv offsets that each q attends is the same or no. 
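+    # (Illustration of the formula documented for homo_head above: with 16
+    #  total heads and vert_stride=64, head_sliding_step = max(1, int(64 / 16))
+    #  = 4, so head h attends the non-local blocks where
+    #  (block_idx + 4 * h + 1) % 64 == 0; with homo_head=True every head
+    #  attends the same blocks.)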
+ homo_head_group: bool = False + + # Decided by homo_head and homo_head group + head_sliding_step: int = field(init=False) + + # range of q heads to for a TP rank + active_head_range: Tuple = field(init=False) + + def __post_init__(self): + assert self.block_size > 0 + assert self.local_blocks >= 0 + assert self.vert_stride >= 1 + assert self.num_heads % self.num_kv_heads == 0 + + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + total_heads = tp_size * self.num_heads + total_kv_heads = tp_size * self.num_kv_heads + + if self.homo_head: + self.head_sliding_step = 0 + elif self.homo_head_group: + head_sliding_step = get_head_sliding_step(total_kv_heads, + self.vert_stride) + # negative indicates sliding along kv heads, i.e., homo q group + self.head_sliding_step = -head_sliding_step + else: + self.head_sliding_step = get_head_sliding_step( + total_heads, self.vert_stride) + + self.active_head_range = ( + tp_rank * self.num_heads, + (tp_rank + 1) * self.num_heads, + ) + + +class BlocksparseFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: + return BlocksparseFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return BlocksparseFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]: + return BlocksparseFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class BlocksparseFlashAttentionMetadata(AttentionMetadata): + """A copy of Metadata for FlashAttentionBackend, + to avoid having to install flash_attn. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. 
+ max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # Max number of query tokens for among request in the batch. + max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional[ + "BlocksparseFlashAttentionMetadata"] = None + + @property + def prefill_metadata( + self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = BlocksparseFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + 
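+            # [Editor's note - illustrative, not part of the original patch]
+            # Prefill requests are ordered first in the batch, so this decode view
+            # takes the tail slices: e.g. with num_prefills=2 and three decode
+            # requests, seq_lens_tensor[2:] and block_tables[2:] are the decode rows.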
use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class BlocksparseFlashAttentionMetadataBuilder( + CommonMetadataBuilder[BlocksparseFlashAttentionMetadata]): + + _metadata_cls = BlocksparseFlashAttentionMetadata + + +class BlocksparseFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + assert blocksparse_params is not None + assert alibi_slopes is None, ValueError( + "Alibi not support for blocksparse flash attention.") + assert sliding_window is None, ValueError( + "sliding_window is invalid for blocksparse attention.") + assert logits_soft_cap is None, ValueError( + "logits_soft_cap is invalid for blocksparse attention.") + + if "num_heads" not in blocksparse_params: + blocksparse_params["num_heads"] = num_heads + if "num_kv_heads" not in blocksparse_params: + blocksparse_params["num_kv_heads"] = num_kv_heads or num_heads + self.blocksparse_params = BlocksparseParams(**blocksparse_params) + self.kv_cache_dtype = kv_cache_dtype + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.alibi_slopes = alibi_slopes + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + self.local_blocks = self.blocksparse_params.local_blocks + self.vert_stride = self.blocksparse_params.vert_stride + self.sparse_block_size = self.blocksparse_params.block_size + self.head_sliding_step = self.blocksparse_params.head_sliding_step + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + total_num_heads = num_heads * self.tp_size + self.bs_attn = LocalStridedBlockSparseAttn( + total_num_heads, + self.blocksparse_params.max_seqlen, + self.blocksparse_params.local_blocks, + self.blocksparse_params.vert_stride, + self.blocksparse_params.block_size, + homo_head=self.blocksparse_params.homo_head, + active_head_range=self.blocksparse_params.active_head_range, + ) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: BlocksparseFlashAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. 
+ + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + if prefill_meta := attn_metadata.prefill_metadata: + + # Prompt run. + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + + assert kv_cache.numel() == 0 \ + or prefill_meta.block_tables is None \ + or prefill_meta.block_tables.numel() == 0, \ + "Does not support prefix-enabled attention." + + output = self.bs_attn( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + sm_scale=self.scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + self.blocksparse_params.max_seqlen, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + tp_rank=self.tp_rank, + blocksparse_local_blocks=self.local_blocks, + blocksparse_vert_stride=self.vert_stride, + blocksparse_block_size=self.sparse_block_size, + blocksparse_head_sliding_step=self.head_sliding_step, + ) + + # Reshape the output tensor. 
+ return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py new file mode 100644 index 0000000..c5a90a2 --- /dev/null +++ b/vllm/attention/backends/flash_attn.py @@ -0,0 +1,951 @@ + +"""Attention layer with FlashAttention.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState, + compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.forward_context import get_forward_context +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + +from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func +from ixformer.contrib.vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + softmax_scale: Optional[float] = None, + causal: bool = False, + window_size: Optional[List[int]] = None, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + use_sqrt_alibi: Optional[bool] = False +) -> torch.Tensor: + # custom op does not support tuple input + real_window_size: Tuple[int, int] + if window_size is None: + real_window_size = (-1, -1) + else: + assert len(window_size) == 2 + real_window_size = (window_size[0], window_size[1]) + return _flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + softmax_scale=softmax_scale, + causal=causal, + window_size=real_window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + block_table=block_table, + out=out, + sqrt_alibi=use_sqrt_alibi, + ) + + +def flash_attn_with_kvcache( + decode_query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + seq_lens_cpu_tensors: torch.Tensor, + max_context_len: int, + cache_seqlens: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + alibi_slopes: Optional[torch.Tensor] = None, + softcap: float = 0.0, + out: Optional[torch.Tensor] = None, + use_sqrt_alibi: bool = False +) -> torch.Tensor: + return _flash_attn_with_kvcache( + decode_query, + key_cache, + value_cache, + cache_seqlens=cache_seqlens, + block_table=block_table, + softmax_scale=softmax_scale, + causal=causal, + alibi_slopes=alibi_slopes, + softcap=softcap, + max_context_len=max_context_len, + cache_seqlens_cpu=seq_lens_cpu_tensors, + out=out, + use_sqrt_alibi=use_sqrt_alibi + ) + + +def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + """Inductor cannot deal with inplace operations on views. 
+ See https://github.com/pytorch/pytorch/issues/131192 + and https://github.com/pytorch/pytorch/issues/130174 + This is a workaround to hide the view operation from the inductor. + """ + return ops.reshape_and_cache_flash( + key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, + k_scale, v_scale) + + +class FlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 80, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash-attn" + + @staticmethod + def get_impl_cls() -> Type["FlashAttentionImpl"]: + return FlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]: + return FlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class FlashAttentionMetadata(AttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). 
The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["FlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["FlashAttentionMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence on encoder-decoder + encoder_seq_start_loc: Optional[torch.Tensor] = None + + # Our impl need info fields... + seq_lens_cpu_tensors: Optional[torch.Tensor] = None + encoder_seq_lens_cpu_tensor: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = FlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_query_len=0, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_start_loc=self.encoder_seq_start_loc, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + seq_lens_cpu_tensors = torch.tensor(self.seq_lens[self.num_prefills:],dtype=torch.int32,device="cpu") + encoder_seq_lens_cpu_tensor = torch.tensor(self.encoder_seq_lens,dtype=torch.int32,device="cpu") if self.encoder_seq_lens is not None else None + max_seq_len = self.seq_lens_tensor.max().item() + + self._cached_decode_metadata = FlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_decode_query_len=self.max_decode_query_len, + max_query_len=max_seq_len, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=self.query_start_loc[self.num_prefills:] + if self.query_start_loc is not None else None, + seq_start_loc=self.seq_start_loc[self.num_prefills:] + if self.seq_start_loc is not None else None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + seq_lens_cpu_tensors=seq_lens_cpu_tensors, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + encoder_seq_lens_cpu_tensor=encoder_seq_lens_cpu_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables + ) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + if turn_prefills_into_decodes: + # When Mutli-Step is enabled with Chunked-Prefill, prefills and + # decodes are scheduled together. In the first step, all the + # prefills turn into decodes. This update reflects that + # conversion. 
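+            # [Editor's note - illustrative, not part of the original patch]
+            # E.g. with num_prefills=2 and num_decode_tokens=3 (num_seqs=5), the
+            # update below yields num_decode_tokens=5 and num_prefills=0, and each
+            # former prefill now contributes a single query token per step
+            # (max_query_len=1).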
+ assert self.num_decode_tokens + self.num_prefills == num_seqs + self.num_decode_tokens += self.num_prefills + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.max_prefill_seq_len = 0 + self.max_query_len = 1 + + self.slot_mapping = self.slot_mapping[:num_seqs] + else: + assert self.seq_lens is not None + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class FlashAttentionMetadataBuilder( + AttentionMetadataBuilder[FlashAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + self.has_prefix_cache_hit = False + + self.input_builder = input_builder + self.runner = input_builder.runner + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool, prefix_cache_hit: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. 
+ # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if prefix_cache_hit: + # NOTE(woosuk): For flash-attn, the block table should + # include the entries for the incoming prefill tokens. + block_table = block_tables[seq_id] + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + prefix_cache_hit = any([ + inter_data.prefix_cache_hit + for inter_data in self.input_builder.inter_data_list + ]) + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled, + prefix_cache_hit) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + num_seqs = len(seq_lens) + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + return FlashAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class FlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. 
+ + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + use_sqrt_alibi: bool = None + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention yet, we will support soon") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.use_sqrt_alibi = use_sqrt_alibi + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + # TODO will support on next week. + self.sliding_window = None + + support_head_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. 
+ assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + output = unified_flash_attention( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + attn_type=attn_type, + use_sqrt_alibi=self.use_sqrt_alibi, + ) + + return output + + +def unified_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, + use_sqrt_alibi: bool = False +) -> torch.Tensor: + + current_metadata = get_forward_context() + assert current_metadata is not None + assert isinstance(current_metadata, FlashAttentionMetadata) + attn_metadata: FlashAttentionMetadata = current_metadata + + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + if key is not None: + assert value is not None + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + else: + assert value is None + + if attn_type != AttentionType.ENCODER and kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping.flatten() + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping.flatten() + ops.reshape_and_cache_flash( + key, + value, + key_cache, + value_cache, + updated_slot_mapping, + kv_cache_dtype, + k_scale, + v_scale, + ) + + if attn_type == AttentionType.ENCODER: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + elif attn_type == AttentionType.DECODER: + # Decoder self-attention supports chunked prefill. 
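+        # [Editor's note - illustrative, not part of the original patch]
+        # With chunked prefill the batch is one flattened 1D query: e.g. two
+        # partial prompts of 60 and 40 tokens plus 8 decode requests give
+        # num_prefill_tokens=100 and num_decode_tokens=8, and the slicing
+        # further below uses query[:100] for prefill and query[100:] for decode.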
+ num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + else: # attn_type == AttentionType.ENCODER_DECODER + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + if attn_metadata.num_encoder_tokens is not None: + num_encoder_tokens = attn_metadata.num_encoder_tokens + else: + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + + # Query for decode. KV is not needed because it is already cached. + output = torch.empty_like(query) + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_encoder_tokens] + value = value[:num_encoder_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + # prefill_output: Optional[torch.Tensor] = None + # decode_output: Optional[torch.Tensor] = None + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.encoder_seq_start_loc if attn_type == AttentionType.ENCODER else prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc if attn_type == AttentionType.DECODER else prefill_meta.encoder_seq_start_loc, + max_seqlen_q=prefill_meta.max_encoder_seq_len if attn_type == AttentionType.ENCODER else prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len if attn_type == AttentionType.DECODER else prefill_meta.max_encoder_seq_len, + softmax_scale=softmax_scale, + causal=attn_type == AttentionType.DECODER, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + out=output[:num_prefill_tokens], + use_sqrt_alibi=use_sqrt_alibi + ) + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=attn_type == AttentionType.DECODER, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + out=output[:num_prefill_tokens], + use_sqrt_alibi=use_sqrt_alibi + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. 
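+        # [Editor's note - illustrative, not part of the original patch]
+        # E.g. two speculative-decode requests proposing 3 tokens each give
+        # query lengths [3, 3], query_start_loc=[0, 3, 6] and
+        # max_decode_query_len=3, taking the varlen path below; plain
+        # single-token decoding ([1, 1]) falls through to flash_attn_with_kvcache.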
+ assert decode_meta.max_decode_query_len is not None + if decode_meta.max_decode_query_len > 1: + flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_k=decode_meta.seq_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=0.0, + block_table=decode_meta.block_tables, + out=output[num_prefill_tokens:], + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + flash_attn_with_kvcache( + decode_query.unsqueeze(1), + key_cache, + value_cache, + seq_lens_cpu_tensors=decode_meta.seq_lens_cpu_tensors if attn_type == AttentionType.DECODER else decode_meta.encoder_seq_lens_cpu_tensor, + max_context_len=decode_meta.max_query_len if attn_type == AttentionType.DECODER else decode_meta.max_encoder_seq_len, + block_table=decode_meta.block_tables if attn_type == AttentionType.DECODER else decode_meta.cross_block_tables, + cache_seqlens=decode_meta.seq_lens_tensor if attn_type == AttentionType.DECODER else decode_meta.encoder_seq_lens_tensor, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=0.0, + out=output[num_prefill_tokens:].unsqueeze(1), + use_sqrt_alibi=use_sqrt_alibi + ).squeeze(1) + + # TODO mv this to flash_attn_with_kvcache when supported. + if logits_soft_cap != 0.0: + output[num_prefill_tokens:] = logits_soft_cap * torch.tanh(output[num_prefill_tokens:] / logits_soft_cap) + + return output.view(-1, num_heads * head_size) + + +# @unified_flash_attention.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py new file mode 100644 index 0000000..be5e7fd --- /dev/null +++ b/vllm/attention/backends/flashinfer.py @@ -0,0 +1,901 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type + +try: + from flashinfer import BatchDecodeWithPagedKVCacheWrapper + from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper + from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper + + from ixformer.contrib.vllm_flash_attn import flash_attn_varlen_func + FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 +except ImportError: + BatchDecodeWithPagedKVCacheWrapper = None + CUDAGraphBatchDecodeWithPagedKVCacheWrapper = None + BatchPrefillWithPagedKVCacheWrapper = None + FLASHINFER_WORKSPACE_BUFFER_SIZE = 0 + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder, + AttentionState, AttentionType) +from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, + compute_slot_mapping_start_idx, + is_block_tables_empty) +from vllm.attention.ops.paged_attn import PagedAttention +from vllm.forward_context import get_forward_context +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) + +if 
TYPE_CHECKING: + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) + + +class FlashInferBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "flashinfer" + + @staticmethod + def get_impl_cls() -> Type["FlashInferImpl"]: + return FlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return FlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["FlashInferMetadataBuilder"]: + return FlashInferMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["FlashInferState"]: + return FlashInferState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_blocks, 2, block_size, num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 128, 256] + + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + + +class FlashInferState(AttentionState): + + def __init__(self, runner): + self.runner = runner + self._is_graph_capturing = False + self._workspace_buffer = None + self._decode_wrapper = None + self._prefill_wrapper = None + + def _get_workspace_buffer(self): + if self._workspace_buffer is None: + self._workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.runner.device) + return self._workspace_buffer + + def _get_prefill_wrapper(self): + if self._prefill_wrapper is None: + self._prefill_wrapper = None + # Unsupported BatchPrefillWithPagedKVCacheWrapper Now + # self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( + # self._get_workspace_buffer(), "NHD") + return self._prefill_wrapper + + def _get_decode_wrapper(self): + if self._decode_wrapper is None: + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = num_qo_heads // num_kv_heads > 4 + self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self._get_workspace_buffer(), + "NHD", + use_tensor_cores=use_tensor_cores) + return self._decode_wrapper + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_decode_wrapper = None + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + self._graph_decode_workspace_buffer = self._get_workspace_buffer() + self._graph_indices_buffer = torch.empty( + max_batch_size * self.runner.cache_config.num_gpu_blocks, + dtype=torch.int32, + device=self.runner.device) + 
self._graph_indptr_buffer = torch.empty(max_batch_size + 1, + dtype=torch.int32, + device=self.runner.device) + self._graph_last_page_len_buffer = torch.empty( + max_batch_size, dtype=torch.int32, device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + del self._graph_decode_workspace_buffer + del self._graph_indices_buffer + del self._graph_indptr_buffer + del self._graph_last_page_len_buffer + del self._graph_decode_wrapper + + def graph_clone(self, batch_size: int): + assert self._is_graph_capturing + state = self.__class__(self.runner) + state._workspace_buffer = self._graph_decode_workspace_buffer + state._decode_wrapper = self._graph_decode_wrapper + state._prefill_wrapper = self._get_prefill_wrapper() + return state + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] + _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] + + num_qo_heads = (self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config)) + num_kv_heads = self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config) + use_tensor_cores = num_qo_heads // num_kv_heads > 4 + self._graph_decode_wrapper = \ + CUDAGraphBatchDecodeWithPagedKVCacheWrapper( + self._graph_decode_workspace_buffer, _indptr_buffer, + self._graph_indices_buffer, _last_page_len_buffer, "NHD", + use_tensor_cores) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + paged_kv_indptr_tensor_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + paged_kv_indices_tensor_host = torch.arange(0, + batch_size, + dtype=torch.int32) + paged_kv_last_page_len_tensor_host = torch.full((batch_size, ), + self.runner.block_size, + dtype=torch.int32) + query_start_loc_host = torch.arange(0, + batch_size + 1, + dtype=torch.int32) + + attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + slot_mapping=self._graph_slot_mapping[:batch_size], + num_prefill_tokens=0, + num_decode_tokens=batch_size, + max_prefill_seq_len=0, + block_tables=self._graph_block_tables, + paged_kv_indptr=paged_kv_indptr_tensor_host, + paged_kv_indices=paged_kv_indices_tensor_host, + paged_kv_last_page_len=paged_kv_last_page_len_tensor_host, + num_qo_heads=num_qo_heads, + num_kv_heads=num_kv_heads, + head_dim=self.runner.model_config.get_head_size(), + page_size=self.runner.block_size, + seq_start_loc=None, + query_start_loc=query_start_loc_host, + device=self.runner.device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=True, + decode_wrapper=self._graph_decode_wrapper, + prefill_wrapper=None) + attn_metadata.begin_forward() + return attn_metadata + + def get_graph_input_buffers(self, + attn_metadata, + is_encoder_decoder_model: bool = False): + return { + "slot_mapping": attn_metadata.slot_mapping, + } + + def prepare_graph_input_buffers(self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False): + return + + def begin_forward(self, model_input): + assert not self._is_graph_capturing + state = self + if model_input.attn_metadata.use_cuda_graph: + batch_size = 
model_input.input_tokens.shape[0] + state = (self.runner.graph_runners[model_input.virtual_engine] + [batch_size].attn_state) + model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper( + ) + model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper() + model_input.attn_metadata.begin_forward() + + +@dataclass +class FlashInferMetadata(AttentionMetadata): + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + max_decode_seq_len: int + + use_cuda_graph: bool = True + + prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None + decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None + + # Metadata for the prefill stage + seq_start_loc: Optional[torch.Tensor] = None + query_start_loc: Optional[torch.Tensor] = None + block_tables: Optional[torch.Tensor] = None + + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + + # An example for paged_kv_indices, paged_kv_indptr: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + # The indptr of the paged kv cache, shape: [batch_size + 1] + paged_kv_indptr: Optional[torch.Tensor] = None + # The page indices of the paged kv cache + paged_kv_indices: Optional[torch.Tensor] = None + # The number of entries in the last page of each request in + # the paged kv cache, shape: [batch_size] + paged_kv_last_page_len: Optional[torch.Tensor] = None + # The number of query/output heads + num_qo_heads: Optional[int] = None + # The number of key/value heads + num_kv_heads: Optional[int] = None + # The dimension of the attention heads + head_dim: Optional[int] = None + # Block size of vllm + page_size: Optional[int] = None + # The data type of the paged kv cache + data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None + device: torch.device = torch.device("cuda") + is_profile_run: bool = False + + def __post_init__(self): + # Refer to + # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157 + supported_head_sizes = FlashInferBackend.get_supported_head_sizes() + if self.head_dim is not None and self.head_dim \ + not in supported_head_sizes: + raise ValueError( + f"Only {supported_head_sizes} are supported for head_dim,", + f"received {self.head_dim}.") + + def begin_forward(self): + if self.num_prefill_tokens > 0: + if self.paged_kv_indices is None: + return + + # assert self.prefill_wrapper is not None + assert self.query_start_loc is not None + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None + batch_size = self.query_start_loc.shape[0] - 1 + assert batch_size >= 0 + # We will use flash attention for profiling to + # determine the number of blocks. Therefore, + # we don't need to prepare the input for flashinfer for profile run. 
+ if not self.is_profile_run: + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + + # Unsupported BatchPrefillWithPagedKVCacheWrapper Now + # self.prefill_wrapper.end_forward() + # self.prefill_wrapper.begin_forward( + # self.query_start_loc, self.paged_kv_indptr, + # self.paged_kv_indices, self.paged_kv_last_page_len, + # self.num_qo_heads, self.num_kv_heads, self.head_dim, + # self.page_size) + else: + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) + + assert self.decode_wrapper is not None + self.decode_wrapper.end_forward() + self.decode_wrapper.begin_forward( + self.paged_kv_indptr, + self.paged_kv_indices, + self.paged_kv_last_page_len, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + self.page_size, + # Disable flashinfer's pos encoding and use vllm's rope. + pos_encoding_mode="NONE", + # kv-cache data type. + data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) + + def asdict_zerocopy(self, + skip_fields: Optional[Set[str]] = None + ) -> Dict[str, Any]: + if skip_fields is None: + skip_fields = set() + # We need to skip the prefill/decode_wrapper field since it cannot be + # broadcasted with nccl when TP is enabled. + skip_fields.add('prefill_wrapper') + skip_fields.add('decode_wrapper') + return super().asdict_zerocopy(skip_fields) + + @property + def prefill_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["FlashInferMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + return None + + return self + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with flashinfer yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. 
For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + + +class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + # Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout + # for the precise definition of the following fields. + # An example: + # request 1, page indices [0, 5, 8] + # request 2, page indices [1, 6, 7] + # request 3, page indices [3, 4] + # paged_kv_indices is a concatenation of page indices of all requests: + # [0, 5, 8, 1, 6, 7, 3, 4] + # paged_kv_indptr is used to index into paged_kv_indices: + # [0, 3, 6, 8] + self.paged_kv_indices: List[int] = [] + # 0 at the beginning of paged_kv_indptr indicates the start of the + # first request’s page indices in the paged_kv_indices list. + self.paged_kv_indptr: List[int] = [0] + # paged_kv_last_page_len is the length of the last page of each request + self.paged_kv_last_page_len: List[int] = [] + self.total_blocks = 0 + self.is_profile_run: bool = False + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + 2. block table. + 3. slot mapping. + """ + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. 
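+            # Three cases below: reuse the prefix-cached block numbers, slice
+            # the scheduler's block table to the current sliding-window length
+            # for chunked-prefill or decode requests, or leave it empty for a
+            # regular prompt with no prefix-cache hit.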
+ # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + is_profile_run = is_block_tables_empty(block_tables) + + # Compute slot mapping. + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + # It is not necessary to add paged_kv_indices, paged_kv_indptr, + # and paged_kv_last_page_len for profile run because we will + # create dummy inputs. + if is_profile_run: + self.is_profile_run = is_profile_run + return + + block_table = block_tables[seq_id] + self._update_paged_kv_tensors(block_table, seq_len) + + def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): + # Get the number of valid blocks based on sequence length. + # If seq_len = 16, block_size = 16, + # block_table_bound is 1 with 1 valid block. + # If seq_len = 15, block_size = 16, + # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) + block_table_bound = seq_len // self.block_size + 1 \ + if seq_len % self.block_size != 0 \ + else seq_len // self.block_size + self.paged_kv_indices.extend(block_table[:block_table_bound]) + self.paged_kv_indptr.append(self.paged_kv_indptr[-1] + + block_table_bound) + + last_page_len = seq_len % self.block_size + if last_page_len == 0: + last_page_len = self.block_size + self.paged_kv_last_page_len.append(last_page_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] + for i, block_table in enumerate(self.block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. 
+ input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + + last_paged_kv_indptr = self.paged_kv_indptr[-1] + self.paged_kv_indptr.extend([last_paged_kv_indptr] * + cuda_graph_pad_size) + self.paged_kv_last_page_len.extend([0] * cuda_graph_pad_size) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + + assert device is not None + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) + paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, + device="cpu", + dtype=torch.int) + paged_kv_indptr_tensor = torch.tensor(self.paged_kv_indptr, + device="cpu", + dtype=torch.int) + paged_kv_last_page_len_tensor = torch.tensor( + self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) + else: + paged_kv_indices_tensor = None + paged_kv_indptr_tensor = None + paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None + + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + + return FlashInferMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max(seq_lens), + block_tables=block_tables, + paged_kv_indptr=paged_kv_indptr_tensor, + paged_kv_indices=paged_kv_indices_tensor, + paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, + num_qo_heads=self.runner.model_config.get_num_attention_heads( + self.runner.parallel_config), + num_kv_heads=self.runner.model_config.get_num_kv_heads( + self.runner.parallel_config), + head_dim=self.runner.model_config.get_head_size(), + page_size=self.block_size, + seq_start_loc=seq_start_loc, + query_start_loc=query_start_loc, + device=device, + data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, + use_cuda_graph=use_captured_graph, + is_profile_run=self.is_profile_run) + + +class FlashInferImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: 
Optional[float] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is not None: + raise ValueError("Sliding window is not supported in FlashInfer.") + self.sliding_window = (-1, -1) + self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: FlashInferMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashInfer.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + + # return torch.ops.vllm.unified_flash_infer( + return unified_flash_infer( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + ) + + +# @torch.library.custom_op("vllm::unified_flash_infer", +# mutates_args=["kv_cache"]) +def unified_flash_infer( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + + current_metadata = get_forward_context() + assert current_metadata is not None + assert isinstance(current_metadata, FlashInferMetadata) + attn_metadata: FlashInferMetadata = current_metadata + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if attn_metadata.num_prefill_tokens > 0: + assert attn_metadata.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + if attn_metadata.num_decode_tokens > 0: + assert attn_metadata.num_prefill_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flashinfer( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + query = query.contiguous() # Flashinfer requires query to be contiguous + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. 
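+        # In this port the flag below is hard-coded to True, so prefill always
+        # takes the flash_attn_varlen_func path even when a kv_cache exists;
+        # the FlashInfer prefill_wrapper branch is kept only for reference.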
+ use_infer_inferface = True + if use_infer_inferface or kv_cache.numel() == 0: + output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + output = prefill_meta.prefill_wrapper.forward( + query, kv_cache, logits_soft_cap=logits_soft_cap, causal=True) + else: + assert attn_metadata.decode_metadata is not None + assert attn_metadata.decode_metadata.decode_wrapper is not None + output = attn_metadata.decode_metadata.decode_wrapper.forward( + query, + kv_cache, + sm_scale=softmax_scale, + logits_soft_cap=logits_soft_cap, + k_scale=k_scale, + v_scale=v_scale) + return output.view(num_tokens, hidden_size) + + +# @unified_flash_infer.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query).contiguous() diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py new file mode 100644 index 0000000..7398732 --- /dev/null +++ b/vllm/attention/backends/ipex_attn.py @@ -0,0 +1,381 @@ +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) + +_PARTITION_SIZE = 512 + + +class IpexAttnBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "ipex-attn" + + @staticmethod + def get_impl_cls() -> Type["IpexAttnBackendImpl"]: + return IpexAttnBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["IpexAttnMetadata"]: + return IpexAttnMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + from vllm._ipex_ops import ipex_ops as ops + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +@dataclass +class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for IpexAttnBackend. 
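+
+    Carries the slot mapping and per-sequence length information consumed by
+    the IPEX varlen (prompt) and paged-attention (decode) kernels. A batch is
+    either all prompts or all decodes; chunked prefill is not supported.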
+ """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + seqlen_q: Optional[torch.Tensor] + max_seqlen: Optional[int] + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + + @property + def prefill_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["IpexAttnMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + +class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("IPEX backend does not support logits_soft_cap.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "IPEX backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + + def split_kv_cache( + self, + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 1 + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IpexAttnMetadata, # type: ignore + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with IPEX varlen_attention and PagedAttention. 
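+
+        Prompt batches are computed with ipex_ops.varlen_attention (adding an
+        ALiBi or sliding-window bias when configured); decode batches use the
+        PagedAttention v1/v2 kernels chosen by the heuristic below.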
+ + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert k_scale == 1.0 and v_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = self.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + ipex_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping.flatten(), + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + if attn_metadata.is_prompt: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or attn_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=1) + + if attn_metadata.attn_bias is None: + if self.alibi_slopes is not None: + att_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + att_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, None, dtype=query.dtype) + attn_metadata.attn_bias = att_masks + + output = torch.empty( + (num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + ipex_ops.varlen_attention(query, + key, + value, + output, + attn_metadata.seqlen_q, + attn_metadata.seqlen_q, + attn_metadata.max_seqlen, + attn_metadata.max_seqlen, + pdropout=0.0, + softmax_scale=self.scale, + zero_tensors=False, + is_causal=True, + return_softmax=False, + gen_=None) + else: + # prefix-enabled attention + raise RuntimeError( + "IPEX backend doesn't support prefix decoding.") + + else: + # Decoding run. + max_seq_len = attn_metadata.max_decode_seq_len + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory + # shortage. + use_v1 = (max_seq_len <= 8192 and + (max_num_partitions == 1 or num_seqs * num_heads > 512)) + if use_v1: + # Run PagedAttention V1. 
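+                # V1 computes each sequence's attention in a single pass,
+                # while V2 (the else branch) splits the work into
+                # _PARTITION_SIZE chunks and reduces the partial results via
+                # tmp_output / exp_sums / max_logits.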
+ ipex_ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ipex_ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + attn_metadata.block_tables, + attn_metadata.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype, + device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py new file mode 100644 index 0000000..8b36230 --- /dev/null +++ b/vllm/attention/backends/openvino.py @@ -0,0 +1,130 @@ +from dataclasses import dataclass +from typing import List, Tuple, Type + +import openvino as ov +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.backends.utils import CommonAttentionState + + +def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor, + src_offset: int, dst_offset: int) -> None: + + def create_roi_tensor( + tensor: ov.Tensor, + block_number: int, + ) -> ov.Tensor: + roi_begin = ov.runtime.Coordinate([0, 0, 0, 0]) + roi_end = ov.runtime.Coordinate(tensor.get_shape()) + + roi_begin[0] = block_number + roi_end[0] = block_number + 1 + + if isinstance(tensor, ov.Tensor): + return ov.Tensor(tensor, roi_begin, roi_end) + else: + return ov.RemoteTensor(tensor, roi_begin, roi_end) + + src_roi_tensor = \ + create_roi_tensor(src_tensor, src_offset) + dst_roi_tensor = \ + create_roi_tensor(dst_tensor, dst_offset) + 
src_roi_tensor.copy_to(dst_roi_tensor) + + +class OpenVINOAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "openvino" + + @staticmethod + def get_impl_cls(): + # OpenVINO implements PagedAttention as part of the Optimum + # exported model + raise NotImplementedError + + @staticmethod + def make_metadata(*args, **kwargs) -> "AttentionMetadata": + raise NotImplementedError + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def make_openvino_metadata(*args, **kwargs) -> "OpenVINOAttentionMetadata": + return OpenVINOAttentionMetadata(*args, **kwargs) + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, num_kv_heads, block_size, head_size) + + @staticmethod + def swap_blocks( + src_tensor: ov.Tensor, + dst_tensor: ov.Tensor, + src_to_dists: List[Tuple[int, int]], + ) -> None: + for src, dst in src_to_dists: + copy_cache_block(src_tensor, dst_tensor, src, dst) + + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + src_to_dists: List[Tuple[int, int]], + ) -> None: + for src, dst in src_to_dists: + for key_cache, value_cache in kv_caches: + copy_cache_block(key_cache, key_cache, src, dst) + copy_cache_block(value_cache, value_cache, src, dst) + + +@dataclass +class OpenVINOAttentionMetadata: + """Metadata for OpenVINOAttentionBackend. + + Basic terms used below: + - batch_size_in_sequences - total number of sequences to execute​ + - prompt_lens – per sequence size number of scheduled tokens​ + - batch_size_in_tokens = sum(prompt_lens)​ + - max_context_len = max(context_lens)​ + - max_num_blocks = div_up(max_context_len / BLOCK_SIZE)​ + - num_blocks – total number of blocks in block_indices​ + """ + + # Describes past KV cache size for each sequence within a batch + # Shape: [batch_size_in_sequences] + # Type: i32​ + past_lens: torch.Tensor + + # Describes start indices of input / speculative tokens from + # current sequences within a batch sequence​ + # Shape: [batch_size_in_sequences + 1]​ + # Type: i32 + subsequence_begins: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # indices along 0th dimension in key_cache and value_cache inputs​ + # Shape: [num_blocks] + # Type: i32​ + block_indices: torch.Tensor + + # Describes block tables for each sequence within a batch​ - + # for i-th element, it is an index in block_indices with the + # first block belonging to i-th sequence​ + # Shape: [batch_size_in_sequences + 1] + # Type: i32​ + block_indices_begins: torch.Tensor + + # Describes max context length + # Shape: scalar + # Type: i32 + max_context_len: torch.Tensor diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py new file mode 100644 index 0000000..8671660 --- /dev/null +++ b/vllm/attention/backends/pallas.py @@ -0,0 +1,260 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +import torch_xla.experimental.custom_kernel # Required to register custom ops. 
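+# The import above registers the torch.ops.xla kernels used in this backend
+# (flash_attention, paged_attention and dynamo_set_buffer_donor_).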
+ +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState + + +class PallasAttentionBackend(AttentionBackend): + + @staticmethod + def get_impl_cls() -> Type["PallasAttentionBackendImpl"]: + return PallasAttentionBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PallasMetadata"]: + return PallasMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (num_kv_heads, num_blocks, block_size, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + raise RuntimeError("swap_blocks is not used for the TPU backend.") + + @torch.compile(backend="openxla") + @staticmethod + def copy_blocks( + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + src_to_dists: Tuple[torch.Tensor, torch.Tensor], + ) -> None: + src_indices, dst_indices = src_to_dists + for k_cache, v_cache in kv_caches: + torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True) + k_cache[:, dst_indices] = k_cache[:, src_indices] + torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True) + v_cache[:, dst_indices] = v_cache[:, src_indices] + + +@dataclass +class PallasMetadata(AttentionMetadata): + + # Currently, input sequences can only contain all prefills + # or all decoding. + block_tables: Optional[torch.Tensor] = None + context_lens: Optional[torch.Tensor] = None + + @property + def prefill_metadata(self) -> Optional["PallasMetadata"]: + if self.num_prefills == 0: + return None + + assert self.num_decode_tokens == 0 + assert self.block_tables is None + assert self.context_lens is None + return self + + @property + def decode_metadata(self) -> Optional["PallasMetadata"]: + if self.num_decode_tokens == 0: + return None + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.block_tables is not None + assert self.context_lens is not None + return self + + +class PallasAttentionBackendImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if head_size % 128 != 0: + raise NotImplementedError("Head size must be a multiple of 128.") + if alibi_slopes is not None: + raise NotImplementedError("Alibi slopes is not supported.") + if sliding_window is not None: + raise NotImplementedError("Sliding window is not supported.") + if kv_cache_dtype != "auto": + raise NotImplementedError("FP8 KV cache dtype is not supported.") + if blocksparse_params is not None: + raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") + + if torch_xla.tpu.version() < 4: + raise NotImplementedError("TPU version must be 4 or higher.") + + 
self.megacore_mode = None + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + + if (("lite" not in tpu_type) and ("v6" not in tpu_type)): + if self.num_kv_heads % 2 == 0: + self.megacore_mode = "kv_head" + else: + # NOTE(woosuk): If the batch size is not a multiple of 2, the + # megacore mode will be None. + self.megacore_mode = "batch" + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor], + attn_metadata: PallasMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with Pallas attention. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size] + kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size] + NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor + with shape [0] for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + assert k_scale == 1.0 and v_scale == 1.0 + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + batch_size, seq_len, hidden_size = query.shape + query = query.view(batch_size, seq_len, self.num_heads, self.head_size) + key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) + value = value.view(batch_size, seq_len, self.num_kv_heads, + self.head_size) + + if kv_cache[0].numel() > 0: + slot_mapping = attn_metadata.slot_mapping + key_cache, value_cache = kv_cache + write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping) + + query = query * self.scale + if attn_metadata.num_prefills > 0: + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) + key = key.view(batch_size, seq_len, self.num_heads, + self.head_size) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention requires [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Decoding run. + assert kv_cache[0].numel() > 0 + + pages_per_compute_block = 16 # TODO(woosuk): Tune this value. + if self.megacore_mode == "batch" and batch_size % 2 != 0: + megacore_mode = None + else: + megacore_mode = self.megacore_mode + + # NOTE(woosuk): A temporary workaround to avoid the error: + # "xla::paged_attention() Expected a value of type 'str' for + # argument 'megacore_mode' but instead found type 'NoneType'." 
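+            # Hence megacore_mode is only forwarded when it is a string; both
+            # branches otherwise invoke the same paged_attention kernel.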
+ if megacore_mode is not None: + output = torch.ops.xla.paged_attention( + query.squeeze(dim=1), + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + megacore_mode=megacore_mode, + ) + else: + output = torch.ops.xla.paged_attention( + query.squeeze(dim=1), + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + pages_per_compute_block, + ) + + # Reshape the output tensor. + return output.reshape(batch_size, seq_len, hidden_size) + + +def write_to_kv_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True) + + key = key.flatten(0, 2) + value = value.flatten(0, 2) + key_cache = key_cache.flatten(0, 2) + value_cache = value_cache.flatten(0, 2) + key_cache.index_copy_(0, slot_mapping, key) + value_cache.index_copy_(0, slot_mapping, value) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py new file mode 100644 index 0000000..3987986 --- /dev/null +++ b/vllm/attention/backends/placeholder_attn.py @@ -0,0 +1,321 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Type + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) +from vllm.attention.backends.utils import CommonAttentionState + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + +# Placeholder attention backend for models like Mamba and embedding models that +# lack attention. + + +class PlaceholderAttentionBackend(AttentionBackend): + """Placeholder backend for when no attention is needed.""" + + @staticmethod + def get_name() -> str: + return "placeholder-attn" + + @staticmethod + def get_impl_cls() -> Type["PlaceholderAttentionImpl"]: + return PlaceholderAttentionImpl + + @staticmethod + def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]: + return PlaceholderAttentionMetadataBuilder + + @staticmethod + def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]: + return PlaceholderAttentionMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (1, 1, 1, 1, 1) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + return + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + return + + +@dataclass +class PlaceholderAttentionMetadata(AttentionMetadata): + """Attention metadata for prefill and decode batched together.""" + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # Maximum query length in the batch. + max_query_len: Optional[int] + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] + + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. 
+ max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + _cached_prefill_metadata: Optional["PlaceholderAttentionMetadata"] = None + _cached_decode_metadata: Optional["PlaceholderAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.seq_start_loc is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_prefill_metadata = PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_decode_query_len=0, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=block_tables, + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.seq_lens_tensor is not None + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + self._cached_decode_metadata = PlaceholderAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_decode_query_len=self.max_decode_query_len, + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=block_tables, + 
use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + +class PlaceholderAttentionMetadataBuilder( + AttentionMetadataBuilder[PlaceholderAttentionMetadata]): + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + """Add a sequence group to the metadata. Specifically update/append + 1. context length. + """ + is_prompt = inter_data.is_prompt + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. + """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + logits_soft_cap = getattr(self.runner.model_config.hf_config, + "attn_logit_softcapping", None) + if logits_soft_cap is not None: + raise ValueError( + "Please use Flashinfer backend for models with logits_soft_cap" + " (i.e., Gemma-2). Otherwise, the output might be wrong." 
+ " Set Flashinfer backend by " + "export VLLM_ATTENTION_BACKEND=FLASHINFER.") + + max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + max_decode_query_len = max(decode_query_lens) + else: + max_decode_query_len = 1 + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + num_decode_tokens = batch_size + + assert max_query_len > 0, ("query_lens: {}".format(query_lens)) + + context_lens_tensor = torch.tensor(self.context_lens, + dtype=torch.int, + device=device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=device) + query_lens_tensor = torch.tensor(query_lens, + dtype=torch.long, + device=device) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + # Placeholders + slot_mapping = torch.empty(0) + block_tables = torch.empty(0) + + return PlaceholderAttentionMetadata( + num_prefills=self.num_prefills, + slot_mapping=slot_mapping, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_decode_query_len=max_decode_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class PlaceholderAttentionImpl(AttentionImpl): + + def __init__(self, *args, **kwargs) -> None: + return + + def forward(self, *args, **kwargs) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py new file mode 100644 index 0000000..682eac5 --- /dev/null +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -0,0 +1,668 @@ +"""Attention layer ROCm GPUs.""" +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type + +import torch + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata + +logger = init_logger(__name__) + +_PARTITION_SIZE_ROCM = 512 +_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName + + +class ROCmFlashAttentionBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "rocm-flash-attn" + + @staticmethod + def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: + return ROCmFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return ROCmFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> 
Type["ROCmFlashAttentionMetadataBuilder"]: + return ROCmFlashAttentionMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for FlashAttentionBackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # NOTE(sang): Definition of context_len, query_len, and seq_len. + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. + max_decode_seq_len: int + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] + + # Max number of query tokens among request in the batch. 
+ max_decode_query_len: Optional[int] = None + + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None + _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None + + @property + def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + return self._cached_prefill_metadata + + assert self.seq_lens is not None + assert self.seq_lens_tensor is not None + assert self.query_start_loc is not None + assert self.context_lens_tensor is not None + assert self.block_tables is not None + assert self.seq_start_loc is not None + + self._cached_prefill_metadata = ROCmFlashAttentionMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=self.slot_mapping[:self.num_prefill_tokens], + seq_lens=self.seq_lens[:self.num_prefills], + seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=self.query_start_loc[:self.num_prefills + 1], + seq_start_loc=self.seq_start_loc[:self.num_prefills + 1], + context_lens_tensor=self.context_lens_tensor[:self.num_prefills], + block_tables=self.block_tables[:self.num_prefills], + use_cuda_graph=False, + ) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + return self._cached_decode_metadata + assert self.block_tables is not None + assert self.seq_lens_tensor is not None + + self._cached_decode_metadata = ROCmFlashAttentionMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=self.slot_mapping[self.num_prefill_tokens:], + seq_lens=None, + seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], + max_query_len=None, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self.block_tables[self.num_prefills:], + use_cuda_graph=self.use_cuda_graph, + ) + return self._cached_decode_metadata + + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): + """ + Update metadata in-place to advance one decode step. + """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. 
For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + assert self.num_prefills == 0 + assert self.num_prefill_tokens == 0 + assert self.num_decode_tokens == num_seqs + assert self.slot_mapping.shape == (num_seqs, ) + + assert self.seq_lens is not None + assert len(self.seq_lens) == num_seqs + assert self.seq_lens_tensor is not None + assert self.seq_lens_tensor.shape == (num_seqs, ) + assert self.max_query_len == 1 + assert self.max_prefill_seq_len == 0 + assert self.max_decode_seq_len == max(self.seq_lens) + + assert self.query_start_loc is not None + assert self.query_start_loc.shape == (num_queries + 1, ) + assert self.seq_start_loc is not None + assert self.seq_start_loc.shape == (num_seqs + 1, ) + + assert self.context_lens_tensor is not None + assert self.context_lens_tensor.shape == (num_queries, ) + + assert self.block_tables is not None + assert self.block_tables.shape[0] == num_seqs + + # Update query lengths. Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + self.seq_lens[i] += 1 + self.max_decode_seq_len = max(self.seq_lens) + + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) + + +class ROCmFlashAttentionMetadataBuilder( + CommonMetadataBuilder[ROCmFlashAttentionMetadata]): + + _metadata_cls = ROCmFlashAttentionMetadata + + +def _make_alibi_bias(alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: Optional[List[int]], + make_attn_mask: bool = True) -> List[torch.Tensor]: + attn_biases = [] + if seq_lens: + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat( + (num_heads, 1, 1)).to(alibi_slopes.device) + bias.mul_(alibi_slopes[:, None, None]) + if make_attn_mask: + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to( + alibi_slopes.device) + attn_biases.append((bias + inf_mask).to(dtype)) + else: + attn_biases.append(bias.to(dtype)) + + return attn_biases + + +class ROCmFlashAttentionImpl(AttentionImpl): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prompt_tokens -------------->| + |<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->| + + Otherwise, the layout is as follows: + |<------------------ num_generation_tokens (M) ----------------->| + |<--generation_0-->|..........|<--generation_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. 
+ + |<----- num_prefill_tokens ---->|<------- num_decode_tokens ----------->| + |<-prompt_0->|...|<-prompt_N-1->|<-generation_0->|...|<-generation_M-1->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "ROCmFlashAttention does not support attention logits soft " + "capping.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + + self.use_naive_attn = False + # NOTE: Allow for switching between Triton and CK. Defaulting to triton. + self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN + if self.use_triton_flash_attn: + from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 + triton_attention) + self.attn_func = triton_attention + logger.debug("Using Triton FA in ROCmBackend") + if self.sliding_window != (-1, -1): + logger.warning("ROCm Triton FA does not currently support " + "sliding window attention. If using half " + "precision, please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") + else: + # if not using triton, navi3x/navi21/navi10 do not use flash-attn + # either + if not current_platform.has_device_capability(90): + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + self.attn_func = _sdpa_attention + logger.debug("Using naive attention in ROCmBackend") + + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return (x[:, :, + None, :].expand(tokens, n_kv_heads, n_rep, + head_dim).reshape(tokens, n_kv_heads * n_rep, + head_dim)) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: ROCmFlashAttentionMetadata, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention and PagedAttention. 
+ + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "ROCmFlashAttentionImpl") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if kv_cache.numel() > 0: + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + PagedAttention.write_to_paged_cache( + key, + value, + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + assert prefill_meta.seq_lens is not None + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # triton attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + attn_masks = None + if self.use_triton_flash_attn: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=False) # type: ignore + out, _ = self.attn_func( + query, + key, + value, + None, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prefill_seq_len, + prefill_meta.max_prefill_seq_len, + True, + self.scale, + attn_masks[0][None] + if attn_masks is not None else None, + ) + elif self.use_naive_attn: + if self.num_kv_heads != self.num_heads: + # Interleave for MQA workaround. 
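+ # The naive SDPA path below has no grouped-query support, so the KV
+ # heads are expanded up front to match the query heads: with e.g.
+ # 32 query heads and 8 KV heads, num_queries_per_kv == 4 and each
+ # KV head is repeated 4 times along the head dimension.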
+ key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, + query.dtype, + attn_metadata.seq_lens, + make_attn_mask=True) # type: ignore + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + # sdpa math backend attention + out = self.attn_func( + query, + key, + value, + prefill_meta.seq_lens, + num_tokens, + self.num_heads, + self.head_size, + self.scale, + attn_masks, + ) + else: + out = self.attn_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) + + # common code for prefill + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + else: + # prefix-enabled attention + output[:num_prefill_tokens] = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window[0], + k_scale, + v_scale, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = decode_query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // self.num_kv_heads + use_custom = _use_rocm_custom_paged_attention( + decode_query.dtype, head_size, block_size, gqa_ratio, + decode_meta.max_decode_seq_len) + if use_custom: + max_seq_len = decode_meta.max_decode_seq_len + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_rocm( + output[num_prefill_tokens:], + exp_sums, + max_logits, + tmp_output, + decode_query, + key_cache, + value_cache, + self.num_kv_heads, + self.scale, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + block_size, + max_seq_len, + self.alibi_slopes, + self.kv_cache_dtype, + k_scale, + v_scale, + ) + else: + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + decode_meta.block_tables, + decode_meta.seq_lens_tensor, + decode_meta.max_decode_seq_len, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + ) + + # Reshape the output tensor. 
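+ # output is (num_tokens, num_heads, head_size); the view below restores
+ # the flattened (num_tokens, num_heads * head_size) layout of the input.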
+ return output.view(num_tokens, hidden_size) + + +def _sdpa_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + seq_lens: List[int], + num_tokens: int, + num_heads: int, + head_size: int, + scale: float, + attn_masks: Optional[List[torch.Tensor]] = None, +) -> torch.Tensor: + start = 0 + output = torch.empty((num_tokens, num_heads, head_size), + dtype=query.dtype, + device=query.device) + + for i, seq_len in enumerate(seq_lens): + end = start + seq_len + with torch.backends.cuda.sdp_kernel(enable_math=True, + enable_flash=False, + enable_mem_efficient=False): + sub_out = torch.nn.functional.scaled_dot_product_attention( + query[:, start:end, :], + key[:, start:end, :], + value[:, start:end, :], + dropout_p=0.0, + is_causal=attn_masks is None, + attn_mask=attn_masks[i] if attn_masks else None, + scale=scale).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + return output + + +def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py new file mode 100644 index 0000000..ef8d576 --- /dev/null +++ b/vllm/attention/backends/torch_sdpa.py @@ -0,0 +1,547 @@ +""" Attention layer with torch scaled_dot_product_attention + and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.utils import is_cpu + +if is_cpu(): + try: + from vllm.attention.ops.ipex_attn import PagedAttention + except ImportError: + from vllm.attention.ops.paged_attn import PagedAttention +else: + from vllm.attention.ops.paged_attn import PagedAttention + + +class TorchSDPABackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch-sdpa" + + @staticmethod + def get_impl_cls() -> Type["TorchSDPABackendImpl"]: + return TorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return TorchSDPAMetadata + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for TorchSDPABackend. 
+ """ + # Currently, input sequences can only contain all prompts + # or all decoding. True if all sequences are prompts. + is_prompt: bool + slot_mapping: torch.Tensor + seq_lens: Optional[List[int]] + + # Begin encoder attn & enc/dec cross-attn fields... + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[torch.Tensor]] = None + self.encoder_attn_bias: Optional[List[torch.Tensor]] = None + self.cross_attn_bias: Optional[List[torch.Tensor]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. + ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_decode_tokens == 0: + assert self.num_prefills > 0 + return self + + return None + + @property + def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: + # Currently chunked prefill is not supported + if self.num_prefills > 0: + assert self.num_decode_tokens == 0 + return None + + return self + + def get_seq_lens( + self, + attn_type: AttentionType, + ): + ''' + Extract appropriate sequence lengths from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate sequence lengths tensor for query + * Appropriate sequence lengths tensor for key & value + ''' + + if attn_type == AttentionType.DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.seq_lens + elif attn_type == AttentionType.ENCODER: + seq_lens_q = self.encoder_seq_lens + seq_lens_kv = self.encoder_seq_lens + elif attn_type == AttentionType.ENCODER_DECODER: + seq_lens_q = self.seq_lens + seq_lens_kv = self.encoder_seq_lens + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + return seq_lens_q, seq_lens_kv + + def get_attn_bias( + self, + attn_type: AttentionType, + ) -> Optional[List[torch.Tensor]]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. 
+ + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if attn_type == AttentionType.DECODER: + return self.attn_bias + elif attn_type == AttentionType.ENCODER: + return self.encoder_attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + return self.cross_attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def set_attn_bias( + self, + attn_bias: List[torch.Tensor], + attn_type: AttentionType, + ) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if attn_type == AttentionType.DECODER: + self.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + self.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + self.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + def get_seq_len_block_table_args( + self, + attn_type: AttentionType, + ) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. + + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + return (self.seq_lens_tensor, self.max_decode_seq_len, + self.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + self.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (self.encoder_seq_lens_tensor, self.max_encoder_seq_len, + None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") + self.num_heads = num_heads + 
self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + supported_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in supported_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {supported_head_sizes}.") + if kv_cache_dtype != "auto": + raise NotImplementedError( + "Torch SDPA backend does not support FP8 KV cache. " + "Please use xFormers backend instead.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: TorchSDPAMetadata, # type: ignore + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with torch SDPA and PagedAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert k_scale == 1.0 and v_scale == 1.0 + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. 
for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + k_scale, v_scale) + + if attn_type != AttentionType.ENCODER: + # Decoder self-attention supports chunked prefill. + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + else: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + + if attn_type == AttentionType.DECODER: + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + assert attn_metadata.seq_lens is not None + if (kv_cache.numel() == 0 + or prefill_meta.block_tables.numel() == 0): + output = self._run_sdpa_forward(query, + key, + value, + prefill_meta, + attn_type=attn_type) + else: + # prefix-enabled attention + raise RuntimeError( + "Torch SDPA backend doesn't support prefix decoding.") + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = decode_meta.get_seq_len_block_table_args(attn_type) + + output = PagedAttention.forward_decode( + query, + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + ) + + # Reshape the output tensor. 
+ return output.view(-1, self.num_heads * self.head_size) + + def _run_sdpa_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: TorchSDPAMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ): + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + attn_masks = attn_metadata.get_attn_bias(attn_type) + if attn_masks is None: + if self.alibi_slopes is not None: + attn_masks = _make_alibi_bias( + self.alibi_slopes, query.dtype, + attn_metadata.seq_lens) # type: ignore + elif self.sliding_window is not None: + assert attn_metadata.seq_lens is not None + attn_masks = _make_sliding_window_bias( + attn_metadata.seq_lens, self.sliding_window, + query.dtype) # type: ignore + else: + seq_lens, _ = attn_metadata.get_seq_lens(attn_type) + attn_masks = [None] * len(seq_lens) + attn_metadata.set_attn_bias(attn_masks, attn_type) + + output = torch.empty_like(query) + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + causal_attn = (attn_type == AttentionType.DECODER) + + seq_lens_q, seq_lens_kv = attn_metadata.get_seq_lens(attn_type) + start_q, start_kv = 0, 0 + for seq_len_q, seq_len_kv, mask in zip(seq_lens_q, seq_lens_kv, + attn_masks): + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + sub_out = scaled_dot_product_attention( + query[None, :, start_q:end_q, :], + key[None, :, start_kv:end_kv, :], + value[None, :, start_kv:end_kv, :], + attn_mask=mask, + dropout_p=0.0, + is_causal=causal_attn and not self.need_mask, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start_q:end_q, :, :] = sub_out + start_q, start_kv = end_q, end_kv + return output + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. 
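+ # bias[None, :] - bias[:, None] yields a (seq_len, seq_len) matrix of
+ # relative positions j - i, e.g. for seq_len == 3:
+ #   [[ 0,  1,  2],
+ #    [-1,  0,  1],
+ #    [-2, -1,  0]]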
+ bias = bias[None, :] - bias[:, None] + + num_heads = alibi_slopes.shape[0] + bias = bias[None, :].repeat((num_heads, 1, 1)) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) + inf_mask = torch.empty( + (1, seq_len, seq_len), + dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) + attn_biases.append((bias + inf_mask).to(dtype)) + + return attn_biases + + +def _make_sliding_window_bias( + seq_lens: List[int], + window_size: Optional[int], + dtype: torch.dtype, +) -> List[torch.Tensor]: + attn_biases: List[torch.Tensor] = [] + for seq_len in seq_lens: + tensor = torch.full( + (1, seq_len, seq_len), + dtype=dtype, + fill_value=1, + ) + shift = 0 + mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore + if window_size is not None: + mask = torch.triu(mask, diagonal=shift - window_size + 1) + mask = torch.log(mask) + attn_biases.append(mask.to(dtype)) + + return attn_biases diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py new file mode 100644 index 0000000..53e3a53 --- /dev/null +++ b/vllm/attention/backends/utils.py @@ -0,0 +1,442 @@ +"""Attention backend utils""" +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, List, Type, TypeVar, Union + +import numpy as np +import torch + +from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder, + AttentionState) +from vllm.utils import async_tensor_h2d, make_tensor_with_pad + +if TYPE_CHECKING: + from vllm.worker.model_runner_base import ModelRunnerBase + +# Error string(s) for encoder/decoder +# unsupported attention scenarios +STR_NOT_IMPL_ENC_DEC_ROCM_HIP = ("ROCm/HIP is not currently supported " + "with encoder/decoder models.") + +PAD_SLOT_ID = -1 + +# Switch to numpy implementation of compute_slot_mapping +# if we have at least this many elements. Could be tuned further. +_COMPUTE_SLOT_MAPPING_NUMPY_NUMEL = 256 + +if TYPE_CHECKING: + from vllm.worker.model_runner import ModelInputForGPUBuilder + + +def is_block_tables_empty(block_tables: Union[None, Dict]): + """ + Check if block_tables is None or a dictionary with all None values. + """ + if block_tables is None: + return True + return (isinstance(block_tables, dict) + and all(value is None for value in block_tables.values())) + + +def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, + context_len: int, sliding_window: int, + use_v2_block_manager: bool): + """ + Compute the start index of slot mapping. + """ + start_idx = 0 + if is_prompt and sliding_window is not None: + assert use_v2_block_manager or context_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention in V1 block manager") + # When prefill, we use it to not write slots to kv cache + # to save memory. 
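+ # E.g. with query_len == 10 and sliding_window == 8, start_idx == 2:
+ # the first two prompt tokens fall outside the window and their slots
+ # are never written to the KV cache.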
+ start_idx = max(0, query_len - sliding_window) + return start_idx + + +def _compute_slot_mapping_python(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + for i in range(range_start, range_end): + block_number = block_table[i // block_size] + block_offset = i % block_size + slot = block_number * block_size + block_offset + slot_mapping.append(slot) + + +def _compute_slot_mapping_numpy(slot_mapping: List[int], + block_table: List[int], range_start: int, + range_end: int, block_size: int): + block_table_array = np.array(block_table) + idx = np.arange(range_start, range_end) + block_offset = idx % block_size + idx //= block_size + seq_slot_mapping_array = block_table_array[idx] + seq_slot_mapping_array *= block_size + seq_slot_mapping_array += block_offset + slot_mapping.extend(seq_slot_mapping_array) + + +def compute_slot_mapping(is_profile_run: bool, slot_mapping: List[int], + seq_id: int, seq_len: int, context_len: int, + start_idx: int, block_size: int, + block_tables: Dict[int, List[int]]): + """ + Compute slot mapping. + """ + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + slot_mapping.extend([PAD_SLOT_ID] * seq_len) + return + + # Mask the [0, start_idx) tokens of the prompt with + # PAD_SLOT_ID, where start_idx is max(0, seq_len - + # sliding_window). For example, if the prompt len is 10, + # sliding window is 8, and block size is 4, the first two + # tokens are masked and the slot mapping will be + # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + padding_mask_len = max(0, start_idx - context_len) + slot_mapping.extend([PAD_SLOT_ID] * padding_mask_len) + + range_start = max(start_idx, context_len) + range_end = seq_len + numel = range_end - range_start + block_table = block_tables[seq_id] + + # numpy implementation will be faster than python if we have + # many elements, otherwise it will be slower. 
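+ # Both helpers compute slot = block_number * block_size + block_offset.
+ # E.g. with block_size == 16 and block_table == [7, 3], tokens 14..17
+ # map to slots [126, 127, 48, 49].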
+ if numel < _COMPUTE_SLOT_MAPPING_NUMPY_NUMEL: + _compute_slot_mapping_python(slot_mapping, block_table, range_start, + range_end, block_size) + else: + _compute_slot_mapping_numpy(slot_mapping, block_table, range_start, + range_end, block_size) + + +TAttentionMetadata = TypeVar("TAttentionMetadata", bound='AttentionMetadata') + + +class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): + + _metadata_cls: Type[TAttentionMetadata] + + def __init__(self, input_builder: "ModelInputForGPUBuilder"): + self.slot_mapping: List[int] = [] + self.prefill_seq_lens: List[int] = [] + self.context_lens: List[int] = [] + self.block_tables: List[List[int]] = [] + self.curr_seq_lens: List[int] = [] + self.num_prefills = 0 + self.num_prefill_tokens = 0 + self.num_decode_tokens = 0 + + self.input_builder = input_builder + self.runner = input_builder.runner + + self.sliding_window = input_builder.sliding_window + self.block_size = input_builder.block_size + self.use_v2_block_manager = ( + input_builder.scheduler_config.use_v2_block_manager) + + def _add_seq_group( + self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup", + chunked_prefill_enabled: bool): + is_prompt = inter_data.is_prompt + block_tables = inter_data.block_tables + computed_block_nums = inter_data.computed_block_nums + + for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, + curr_sliding_window_block) in zip( + inter_data.seq_ids, [len(t) for t in inter_data.input_tokens], + inter_data.orig_seq_lens, inter_data.seq_lens, + inter_data.query_lens, inter_data.context_lens, + inter_data.curr_sliding_window_blocks): + self.context_lens.append(context_len) + if is_prompt: + self.num_prefills += 1 + self.num_prefill_tokens += token_len + self.prefill_seq_lens.append(seq_len) + else: + assert query_len == 1, ( + "seq_len: {}, context_len: {}, query_len: {}".format( + seq_len, context_len, query_len)) + self.num_decode_tokens += query_len + self.curr_seq_lens.append(curr_seq_len) + + # Compute block table. + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. + block_table = [] + if inter_data.prefix_cache_hit: + block_table = computed_block_nums + elif ((chunked_prefill_enabled or not is_prompt) + and block_tables is not None): + block_table = block_tables[seq_id][-curr_sliding_window_block:] + self.block_tables.append(block_table) + + # Compute slot mapping. + is_profile_run = is_block_tables_empty(block_tables) + start_idx = compute_slot_mapping_start_idx( + is_prompt, query_len, context_len, self.sliding_window, + self.use_v2_block_manager) + compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id, + seq_len, context_len, start_idx, + self.block_size, inter_data.block_tables) + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int): + """Build attention metadata with on-device tensors. + + Args: + seq_lens: The maybe padded sequence lengths of the input sequences. + query_lens: The query lengths of the input sequences. + cuda_graph_pad_size: The padding size for cuda graph. + -1 if cuda graph is not used. + batch_size: The maybe padded batch size. 
+ """ + for inter_data in self.input_builder.inter_data_list: + self._add_seq_group(inter_data, + self.input_builder.chunked_prefill_enabled) + + device = self.runner.device + use_captured_graph = cuda_graph_pad_size != -1 + + max_query_len = max(query_lens) + max_prefill_seq_len = max(self.prefill_seq_lens, default=0) + max_decode_seq_len = max(self.curr_seq_lens, default=0) + num_decode_tokens = self.num_decode_tokens + + if use_captured_graph: + self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) + self.block_tables.extend([] * cuda_graph_pad_size) + num_decode_tokens = batch_size + + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.runner.graph_block_tables[:batch_size] + for i, block_table in enumerate(self.block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.from_numpy(input_block_tables).to( + device, non_blocking=True) + else: + block_tables = make_tensor_with_pad( + self.block_tables, + pad=0, + dtype=torch.int, + device=device, + ) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + assert device is not None + context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int, + device, self.runner.pin_memory) + seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, + self.runner.pin_memory) + query_lens_tensor = async_tensor_h2d(query_lens, torch.long, device, + self.runner.pin_memory) + slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long, + device, self.runner.pin_memory) + query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device=device) + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + torch.cumsum(query_lens_tensor, + dim=0, + dtype=query_start_loc.dtype, + out=query_start_loc[1:]) + + return self._metadata_cls( # type: ignore + num_prefills=self.num_prefills, + slot_mapping=slot_mapping_tensor, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=num_decode_tokens, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_prefill_seq_len=max_prefill_seq_len, + max_decode_seq_len=max_decode_seq_len, + query_start_loc=query_start_loc, + seq_start_loc=seq_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + ) + + +class CommonAttentionState(AttentionState): + + def __init__(self, runner: "ModelRunnerBase"): + self.runner = runner + self._is_graph_capturing = False + + @contextmanager + def graph_capture(self, max_batch_size: int): + self._is_graph_capturing = True + self._graph_slot_mapping = torch.full((max_batch_size, ), + PAD_SLOT_ID, + dtype=torch.long, + device=self.runner.device) + self._graph_seq_lens = torch.ones(max_batch_size, + dtype=torch.int32, + device=self.runner.device) + self._graph_block_tables = torch.from_numpy( + self.runner.graph_block_tables).to(device=self.runner.device) + yield + self._is_graph_capturing = False + del self._graph_slot_mapping + del self._graph_seq_lens + del self._graph_block_tables + + def graph_clone(self, batch_size: int) -> "CommonAttentionState": + assert self._is_graph_capturing + return self.__class__(self.runner) + + def graph_capture_get_metadata_for_batch( + self, batch_size: int, is_encoder_decoder_model: bool = False): + assert self._is_graph_capturing + 
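+ # Decode-only metadata for the captured batch: the pre-allocated graph
+ # buffers are sliced to batch_size, num_decode_tokens == batch_size, and
+ # all prefill fields are zeroed (cuda-graph is enabled for decoding only).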
attn_metadata = self.runner.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=self._graph_slot_mapping[:batch_size], + seq_lens=None, + seq_lens_tensor=self._graph_seq_lens[:batch_size], + max_query_len=1, + max_decode_query_len=1, + max_prefill_seq_len=0, + max_decode_seq_len=self.runner.max_seq_len_to_capture, + query_start_loc=None, + seq_start_loc=None, + context_lens_tensor=None, + block_tables=self._graph_block_tables[:batch_size], + use_cuda_graph=True, + ) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._update_captured_metadata_for_enc_dec_model( + batch_size=batch_size, attn_metadata=attn_metadata) + + return attn_metadata + + def get_graph_input_buffers( + self, + attn_metadata, + is_encoder_decoder_model: bool = False) -> Dict[str, Any]: + input_buffers = { + "slot_mapping": attn_metadata.slot_mapping, + "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, + "block_tables": attn_metadata.decode_metadata.block_tables, + } + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._add_additonal_input_buffers_for_enc_dec_model( + attn_metadata=attn_metadata, input_buffers=input_buffers) + return input_buffers + + def prepare_graph_input_buffers( + self, + input_buffers, + attn_metadata, + is_encoder_decoder_model: bool = False) -> None: + input_buffers["seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) + input_buffers["block_tables"].copy_( + attn_metadata.decode_metadata.block_tables, non_blocking=True) + if is_encoder_decoder_model: + # The encoder decoder model works only with XFormers backend. + # Assert the same. + assert self.runner.attn_backend.get_name() == "xformers", \ + f"Expected attn_backend name to be 'xformers', but "\ + f" got '{self.runner.attn_backend.get_name()}'" + self._prepare_input_buffers_for_enc_dec_model( + attn_metadata, input_buffers) + + def begin_forward(self, model_input) -> None: + return + + def _update_captured_metadata_for_enc_dec_model(self, batch_size: int, + attn_metadata): + """ + Updates the attention metadata parameters for CUDA graph capture in an + encoder-decoder model. + + This method modifies attention-related tensors and metadata required + for CUDA graph capture in encoder-decoder models. Specifically, it + updates the cross-attention and encoder sequence tensors in the + AttentionMetadata object. + """ + # During decode phase the cross_slot_mapping will be empty. Hence set + # an empty tensor for CUDA Graph capture. 
+ attn_metadata.cross_slot_mapping = torch.tensor( + [], dtype=torch.int).cuda() + attn_metadata.cross_block_tables = torch.full( + (batch_size, self.runner.get_max_block_per_batch()), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens = torch.full((batch_size, ), + 1, + dtype=torch.int).cuda() + attn_metadata.encoder_seq_lens_tensor = torch.full( + (batch_size, ), 1, dtype=torch.int).cuda() + attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture + + def _add_additonal_input_buffers_for_enc_dec_model( + self, attn_metadata, input_buffers: Dict[str, Any]): + """ + Saves additional input buffers specific to the encoder-decoder model + from the attention metadata. + + This method extracts and stores encoder-decoder related input buffers + from the `attn_metadata` into the `input_buffers` dictionary. The + buffers include encoder sequence lengths, cross-slot mappings, and + cross-block tables, which are essential for the encoder-decoder model + during CUDA graph replay. + """ + input_buffers["encoder_seq_lens_tensor"] = ( + attn_metadata.decode_metadata.encoder_seq_lens_tensor) + input_buffers["cross_slot_mapping"] = ( + attn_metadata.decode_metadata.cross_slot_mapping) + input_buffers["cross_block_tables"] = ( + attn_metadata.decode_metadata.cross_block_tables) + + def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata, + input_buffers: Dict[str, + Any]): + """ + Populates input buffers with data from the encoder-decoder model's + attention metadata. + + This method fills the input buffers with encoder-decoder specific + tensors. It copies data from the `attn_metadata` and keyword arguments + (`kwargs`) into corresponding buffers in the `input_buffers` dictionary. + The copied data includes attention-related metadata as well as input + IDs and positional information for the encoder. 
+ """ + input_buffers["encoder_seq_lens_tensor"].copy_( + attn_metadata.decode_metadata.encoder_seq_lens_tensor, + non_blocking=True) + input_buffers["cross_slot_mapping"].copy_( + attn_metadata.decode_metadata.cross_slot_mapping, + non_blocking=True) + input_buffers["cross_block_tables"].copy_( + attn_metadata.decode_metadata.cross_block_tables, + non_blocking=True) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py new file mode 100644 index 0000000..7c7dbfa --- /dev/null +++ b/vllm/attention/backends/xformers.py @@ -0,0 +1,822 @@ +"""Attention layer with xFormers and PagedAttention.""" +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type + +import torch +# from xformers import ops as xops +from ixformer.contrib.xformers import ops as xops +from xformers.ops.fmha.attn_bias import (AttentionBias, + BlockDiagonalMask,) +from ixformer.contrib.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) + +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionMetadata, AttentionType) +from vllm.attention.backends.utils import (CommonAttentionState, + CommonMetadataBuilder) +from vllm.attention.ops.paged_attn import (PagedAttention, + PagedAttentionMetadata) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class XFormersBackend(AttentionBackend): + + @staticmethod + def get_name() -> str: + return "xformers" + + @staticmethod + def get_impl_cls() -> Type["XFormersImpl"]: + return XFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + return XFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["XFormersMetadataBuilder"]: + return XFormersMetadataBuilder + + @staticmethod + def get_state_cls() -> Type["CommonAttentionState"]: + return CommonAttentionState + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return PagedAttention.get_kv_cache_shape(num_blocks, block_size, + num_kv_heads, head_size) + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + ) -> None: + PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + PagedAttention.copy_blocks(kv_caches, src_to_dists) + + +@dataclass +class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): + """Metadata for XFormersbackend. + + NOTE: Any python object stored here is not updated when it is + cuda-graph replayed. If you have values that need to be changed + dynamically, it should be stored in tensor. The tensor has to be + updated from `CUDAGraphRunner.forward` API. + """ + + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # seq_lens stored as a tensor. + seq_lens_tensor: Optional[torch.Tensor] + + # FIXME: It is for flash attn. + # Maximum sequence length among prefill batch. 0 if there are decoding + # requests only. + max_prefill_seq_len: int + # Maximum sequence length among decode batch. 0 if there are prefill + # requests only. 
+ max_decode_seq_len: int + + # Whether or not if cuda graph is enabled. + # Cuda-graph is currently enabled for decoding only. + # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. + use_cuda_graph: bool + + # (batch_size,). The sequence length per sequence. Sequence length means + # the computed tokens + new tokens None if it is a decoding. + seq_lens: Optional[List[int]] = None + + # FIXME: It is for flash attn. + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. + seq_start_loc: Optional[torch.Tensor] = None + + # (batch_size,) A tensor of context lengths (tokens that are computed + # so far). + context_lens_tensor: Optional[torch.Tensor] = None + + # Maximum query length in the batch. None for decoding. + max_query_len: Optional[int] = None + + # Max number of query tokens among request in the batch. + max_decode_query_len: Optional[int] = None + + # (batch_size + 1,). The cumulative subquery lengths of the sequences in + # the batch, used to index into subquery. E.g., if the subquery length + # is [4, 6], it is [0, 4, 10]. + query_start_loc: Optional[torch.Tensor] = None + + # Self-attention prefill/decode metadata cache + _cached_prefill_metadata: Optional["XFormersMetadata"] = None + _cached_decode_metadata: Optional["XFormersMetadata"] = None + + # Begin encoder attn & enc/dec cross-attn fields... + + # Encoder sequence lengths representation + encoder_seq_lens: Optional[List[int]] = None + encoder_seq_lens_tensor: Optional[torch.Tensor] = None + + # Maximum sequence length among encoder sequences + max_encoder_seq_len: Optional[int] = None + + # Number of tokens input to encoder + num_encoder_tokens: Optional[int] = None + + # Cross-attention memory-mapping data structures: slot mapping + # and block tables + cross_slot_mapping: Optional[torch.Tensor] = None + cross_block_tables: Optional[torch.Tensor] = None + + def __post_init__(self): + # Set during the execution of the first attention op. + # It is a list because it is needed to set per prompt + # when alibi slopes is used. It is because of the limitation + # from xformer API. + # will not appear in the __repr__ and __init__ + self.attn_bias: Optional[List[AttentionBias]] = None + self.encoder_attn_bias: Optional[List[AttentionBias]] = None + self.cross_attn_bias: Optional[List[AttentionBias]] = None + + @property + def is_all_encoder_attn_metadata_set(self): + ''' + All attention metadata required for encoder attention is set. + ''' + return ((self.encoder_seq_lens is not None) + and (self.encoder_seq_lens_tensor is not None) + and (self.max_encoder_seq_len is not None)) + + @property + def is_all_cross_attn_metadata_set(self): + ''' + All attention metadata required for enc/dec cross-attention is set. + + Superset of encoder attention required metadata. 
+ ''' + return (self.is_all_encoder_attn_metadata_set + and (self.cross_slot_mapping is not None) + and (self.cross_block_tables is not None)) + + @property + def prefill_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_prefills == 0: + return None + + if self._cached_prefill_metadata is not None: + # Recover cached prefill-phase attention + # metadata structure + return self._cached_prefill_metadata + + assert ((self.seq_lens is not None) + or (self.encoder_seq_lens is not None)) + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + query_start_loc = (None if self.query_start_loc is None else + self.query_start_loc[:self.num_prefills + 1]) + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[:self.num_prefill_tokens]) + seq_lens = (None if self.seq_lens is None else + self.seq_lens[:self.num_prefills]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[:self.num_prefills]) + context_lens_tensor = (None if self.context_lens_tensor is None else + self.context_lens_tensor[:self.num_prefills]) + block_tables = (None if self.block_tables is None else + self.block_tables[:self.num_prefills]) + + # Construct & cache prefill-phase attention metadata structure + self._cached_prefill_metadata = XFormersMetadata( + num_prefills=self.num_prefills, + num_prefill_tokens=self.num_prefill_tokens, + num_decode_tokens=0, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=self.max_query_len, + max_prefill_seq_len=self.max_prefill_seq_len, + max_decode_seq_len=0, + query_start_loc=query_start_loc, + context_lens_tensor=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + # Begin encoder & cross attn fields below... + encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_prefill_metadata + + @property + def decode_metadata(self) -> Optional["XFormersMetadata"]: + if self.num_decode_tokens == 0: + return None + + if self._cached_decode_metadata is not None: + # Recover cached decode-phase attention + # metadata structure + return self._cached_decode_metadata + assert ((self.seq_lens_tensor is not None) + or (self.encoder_seq_lens_tensor is not None)) + + # Compute some attn_metadata fields which default to None + slot_mapping = (None if self.slot_mapping is None else + self.slot_mapping[self.num_prefill_tokens:]) + seq_lens_tensor = (None if self.seq_lens_tensor is None else + self.seq_lens_tensor[self.num_prefills:]) + block_tables = (None if self.block_tables is None else + self.block_tables[self.num_prefills:]) + + # Construct & cache decode-phase attention metadata structure + self._cached_decode_metadata = XFormersMetadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=self.num_decode_tokens, + slot_mapping=slot_mapping, + seq_lens_tensor=seq_lens_tensor, + max_prefill_seq_len=0, + max_decode_seq_len=self.max_decode_seq_len, + block_tables=block_tables, + use_cuda_graph=self.use_cuda_graph, + # Begin encoder & cross attn fields below... 
+ encoder_seq_lens=self.encoder_seq_lens, + encoder_seq_lens_tensor=self.encoder_seq_lens_tensor, + max_encoder_seq_len=self.max_encoder_seq_len, + cross_slot_mapping=self.cross_slot_mapping, + cross_block_tables=self.cross_block_tables) + return self._cached_decode_metadata + + +def _get_attn_bias( + attn_metadata: XFormersMetadata, + attn_type: AttentionType, +) -> Optional[AttentionBias]: + ''' + Extract appropriate attention bias from attention metadata + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + * Appropriate attention bias value given the attention type + ''' + + if attn_type == AttentionType.DECODER: + return attn_metadata.attn_bias + elif attn_type == AttentionType.ENCODER: + return attn_metadata.encoder_attn_bias + else: + # attn_type == AttentionType.ENCODER_DECODER + return attn_metadata.cross_attn_bias + + +def _set_attn_bias( + attn_metadata: XFormersMetadata, + attn_bias: List[Optional[AttentionBias]], + attn_type: AttentionType, +) -> None: + ''' + Update appropriate attention bias field of attention metadata, + according to attention type. + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention + * attn_bias: The desired attention bias value + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + ''' + + if attn_type == AttentionType.DECODER: + attn_metadata.attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER: + attn_metadata.encoder_attn_bias = attn_bias + elif attn_type == AttentionType.ENCODER_DECODER: + attn_metadata.cross_attn_bias = attn_bias + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +def _get_seq_len_block_table_args( + attn_metadata: XFormersMetadata, + is_prompt: bool, + attn_type: AttentionType, +) -> tuple: + ''' + The particular choice of sequence-length- and block-table-related + attributes which should be extracted from attn_metadata is dependent + on the type of attention operation. 
+ + Decoder attn -> select entirely decoder self-attention-related fields + Encoder/decoder cross-attn -> select encoder sequence lengths & + cross-attn block-tables fields + Encoder attn -> select encoder sequence lengths fields & no block tables + + Arguments: + + * attn_metadata: Attention metadata structure associated with attention op + * is_prompt: True if prefill, False otherwise + * attn_type: encoder attention, decoder self-attention, + encoder/decoder cross-attention + + Returns: + + * Appropriate sequence-lengths tensor + * Appropriate max sequence-length scalar + * Appropriate block tables (or None) + ''' + + if attn_type == AttentionType.DECODER: + # Decoder self-attention + # Choose max_seq_len based on whether we are in prompt_run + if is_prompt: + max_seq_len = attn_metadata.max_prefill_seq_len + else: + max_seq_len = attn_metadata.max_decode_seq_len + return (attn_metadata.seq_lens_tensor, max_seq_len, + attn_metadata.block_tables) + elif attn_type == AttentionType.ENCODER_DECODER: + # Enc/dec cross-attention KVs match encoder sequence length; + # cross-attention utilizes special "cross" block tables + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_block_tables) + elif attn_type == AttentionType.ENCODER: + # No block tables associated with encoder attention + return (attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, None) + else: + raise AttributeError(f"Invalid attention type {str(attn_type)}") + + +class XFormersMetadataBuilder(CommonMetadataBuilder[XFormersMetadata]): + + _metadata_cls = XFormersMetadata + + +class XFormersImpl(AttentionImpl[XFormersMetadata]): + """ + If the input tensors contain prompt tokens, the layout is as follows: + |<--------------- num_prefill_tokens ----------------->| + |<--prefill_0-->|<--prefill_1-->|...|<--prefill_N-1--->| + + Otherwise, the layout is as follows: + |<----------------- num_decode_tokens ------------------>| + |<--decode_0-->|..........|<--decode_M-1-->|<--padding-->| + + Generation tokens can contain padding when cuda-graph is used. + Currently, prompt tokens don't contain any padding. + + The prompts might have different lengths, while the generation tokens + always have length 1. + + If chunked prefill is enabled, prefill tokens and decode tokens can be + batched together in a flattened 1D query. + + |<----- num_prefill_tokens ---->|<------- num_decode_tokens --------->| + |<-prefill_0->|...|<-prefill_N-1->|<--decode_0-->|...|<--decode_M-1-->| + + Currently, cuda graph is disabled for chunked prefill, meaning there's no + padding between prefill and decode tokens. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[List[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + self.kv_cache_dtype = kv_cache_dtype + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + suppored_head_sizes = PagedAttention.get_supported_head_sizes() + if head_size not in suppored_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by PagedAttention. " + f"Supported head sizes are: {suppored_head_sizes}.") + self.head_mapping = torch.repeat_interleave( + torch.arange(self.num_kv_heads, dtype=torch.int32), + self.num_queries_per_kv) + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: "XFormersMetadata", + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Forward pass with xFormers and PagedAttention. + + For decoder-only models: query, key and value must be non-None. + + For encoder/decoder models: + * XFormersImpl.forward() may be invoked for both self- and cross- + attention layers. + * For self-attention: query, key and value must be non-None. + * For cross-attention: + * Query must be non-None + * During prefill, key and value must be non-None; key and value + get cached for use during decode. + * During decode, key and value may be None, since: + (1) key and value tensors were cached during prefill, and + (2) cross-attention key and value tensors do not grow during + decode + + A note on how the attn_type (attention type enum) argument impacts + attention forward() behavior: + + * DECODER: normal decoder-only behavior; + use decoder self-attention block table + * ENCODER: no KV caching; pass encoder sequence + attributes (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) to kernel, in lieu of decoder + sequence attributes (seq_lens/seq_lens_tensor/max_seq_len) + * ENCODER_DECODER: cross-attention behavior; + use cross-attention block table for caching KVs derived + from encoder hidden states; since KV sequence lengths + will match encoder sequence lengths, pass encoder sequence + attributes to kernel (encoder_seq_lens/encoder_seq_lens_tensor/ + max_encoder_seq_len) + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size] + NOTE: kv_cache will be an empty tensor with shape [0] + for profiling run. + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. 
Defaults to decoder self-attention, + which is the vLLM default generally + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + # Check that appropriate attention metadata attributes are + # selected for the desired attention type + if (attn_type == AttentionType.ENCODER + and (not attn_metadata.is_all_encoder_attn_metadata_set)): + raise AttributeError("Encoder attention requires setting " + "encoder metadata attributes.") + elif (attn_type == AttentionType.ENCODER_DECODER + and (not attn_metadata.is_all_cross_attn_metadata_set)): + raise AttributeError("Encoder/decoder cross-attention " + "requires setting cross-attention " + "metadata attributes.") + + query = query.view(-1, self.num_heads, self.head_size) + if key is not None: + assert value is not None + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + else: + assert value is None + + # Self-attention vs. cross-attention will impact + # which KV cache memory-mapping & which + # seqlen datastructures we utilize + + if (attn_type != AttentionType.ENCODER and kv_cache.numel() > 0): + # KV-cache during decoder-self- or + # encoder-decoder-cross-attention, but not + # during encoder attention. + # + # Even if there are no new key/value pairs to cache, + # we still need to break out key_cache and value_cache + # i.e. for later use by paged attention + key_cache, value_cache = PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_size) + + if (key is not None) and (value is not None): + + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + # During cross-attention decode, key & value will be None, + # preventing this IF-statement branch from running + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + PagedAttention.write_to_paged_cache(key, value, key_cache, + value_cache, + updated_slot_mapping, + self.kv_cache_dtype, + k_scale, v_scale) + + if attn_type == AttentionType.ENCODER: + # Encoder attention - chunked prefill is not applicable; + # derive token-count from query shape & and treat them + # as 100% prefill tokens + assert attn_metadata.num_encoder_tokens is not None + num_prefill_tokens = attn_metadata.num_encoder_tokens + num_encoder_tokens = attn_metadata.num_encoder_tokens + num_decode_tokens = 0 + elif attn_type == AttentionType.DECODER: + # Decoder self-attention supports chunked prefill. 
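+            # Illustrative note: with chunked prefill, prefill and decode
+            # tokens may share this flattened batch, so key/value must carry
+            # one row per token of both kinds; the assertions below check for
+            # exactly num_prefill_tokens + num_decode_tokens rows.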
+ num_prefill_tokens = attn_metadata.num_prefill_tokens + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + # Only enforce this shape-constraint for decoder + # self-attention + assert key.shape[0] == num_prefill_tokens + num_decode_tokens + assert value.shape[0] == num_prefill_tokens + num_decode_tokens + else: # attn_type == AttentionType.ENCODER_DECODER + # Encoder/decoder cross-attention requires no chunked + # prefill (100% prefill or 100% decode tokens, no mix) + num_prefill_tokens = attn_metadata.num_prefill_tokens + if attn_metadata.num_encoder_tokens is not None: + num_encoder_tokens = attn_metadata.num_encoder_tokens + else: + num_encoder_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + output = torch.empty_like(query) + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. + query = query[:num_prefill_tokens] + if key is not None and value is not None: + key = key[:num_encoder_tokens] + value = value[:num_encoder_tokens] + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0: + # normal attention. + # block tables are empty if the prompt does not have a cached + # prefix. + out = self._run_memory_efficient_xformers_forward( + query, key, value, prefill_meta, attn_type=attn_type) + assert out.shape == output[:num_prefill_tokens].shape + output[:num_prefill_tokens] = out + else: + + assert prefill_meta.query_start_loc is not None + assert prefill_meta.max_query_len is not None + + # prefix-enabled attention + # TODO(Hai) this triton kernel has regression issue (broke) to + # deal with different data types between KV and FP8 KV cache, + # to be addressed separately. + out = PagedAttention.forward_prefix( + query, + key, + value, + self.kv_cache_dtype, + key_cache, + value_cache, + prefill_meta.block_tables, + prefill_meta.query_start_loc, + prefill_meta.seq_lens_tensor, + prefill_meta.context_lens_tensor, + prefill_meta.max_query_len, + self.alibi_slopes, + self.sliding_window, + k_scale, + v_scale, + ) + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out + + if decode_meta := attn_metadata.decode_metadata: + + ( + seq_lens_arg, + max_seq_len_arg, + block_tables_arg, + ) = _get_seq_len_block_table_args(decode_meta, False, attn_type) + + output[num_prefill_tokens:] = PagedAttention.forward_decode( + decode_query, + key_cache, + value_cache, + block_tables_arg, + seq_lens_arg, + max_seq_len_arg, + self.kv_cache_dtype, + self.head_mapping, + self.scale, + self.alibi_slopes, + k_scale, + v_scale, + ) + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + def _run_memory_efficient_xformers_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: XFormersMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + """Attention for 1D query of multiple prompts. Multiple prompt + tokens are flattened in to `query` input. + + See https://facebookresearch.github.io/xformers/components/ops.html + for API spec. 
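+
+        Only the prefill portion of the flattened batch reaches this helper;
+        decode tokens are handled separately by PagedAttention.forward_decode.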
+ + Args: + output: shape = [num_prefill_tokens, num_heads, head_size] + query: shape = [num_prefill_tokens, num_heads, head_size] + key: shape = [num_prefill_tokens, num_kv_heads, head_size] + value: shape = [num_prefill_tokens, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + attn_type: Select attention type, between encoder attention, + decoder self-attention, or encoder/decoder cross- + attention. Defaults to decoder self-attention, + which is the vLLM default generally + """ + + original_query = query + # if self.num_kv_heads != self.num_heads: + # # GQA/MQA requires the shape [B, M, G, H, K]. + # # Note that the output also has the same shape (which is different + # # from a spec from the doc). + # query = query.view(query.shape[0], self.num_kv_heads, + # self.num_queries_per_kv, query.shape[-1]) + # print(f"5555555555555 q shape {query.shape}") + # key = key[:, :, + # None, :].expand(key.shape[0], self.num_kv_heads, + # self.num_queries_per_kv, key.shape[-1]) + # value = value[:, :, + # None, :].expand(value.shape[0], self.num_kv_heads, + # self.num_queries_per_kv, + # value.shape[-1]) + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + attn_bias = _get_attn_bias(attn_metadata, attn_type) + if attn_bias is None: + if self.alibi_slopes is None: + if (attn_type == AttentionType.ENCODER_DECODER): + assert attn_metadata.seq_lens is not None + assert attn_metadata.encoder_seq_lens is not None + + # Default enc/dec cross-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.seq_lens, attn_metadata.encoder_seq_lens) + elif attn_type == AttentionType.ENCODER: + assert attn_metadata.encoder_seq_lens is not None + + # Default encoder self-attention mask is non-causal + attn_bias = BlockDiagonalMask.from_seqlens( + attn_metadata.encoder_seq_lens) + else: + assert attn_metadata.seq_lens is not None + + # Default decoder self-attention mask is causal + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + attn_bias = [attn_bias] + else: + assert attn_metadata.seq_lens is not None + attn_bias = _make_alibi_bias(self.alibi_slopes, + self.num_kv_heads, query.dtype, + attn_metadata.seq_lens) + + _set_attn_bias(attn_metadata, attn_bias, attn_type) + + # No alibi slopes. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + self.attn_op = xops.fmha.flash.FwOp() + if self.alibi_slopes is None: + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=attn_bias[0], + p=0.0, + scale=self.scale, + op = self.attn_op + ) + return out.view_as(original_query) + + # Attention with alibi slopes. + # FIXME(woosuk): Because xformers does not support dynamic sequence + # lengths with custom attention bias, we process each prompt one by + # one. This is inefficient, especially when we have many short prompts. 
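+        # Illustrative example (assumed values): with seq_lens = [3, 5],
+        # iteration 0 attends rows 0..2 of the flattened query with
+        # attn_bias[0], iteration 1 attends rows 3..7 with attn_bias[1],
+        # and each result is copied back into the matching output slice.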
+ assert attn_metadata.seq_lens is not None + output = torch.empty_like(original_query) + start = 0 + for i, seq_len in enumerate(attn_metadata.seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=self.scale, + ) + # TODO(woosuk): Unnecessary copy. Optimize. + output[start:end].copy_(out.view_as(original_query[start:end])) + start += seq_len + return output + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + dtype: torch.dtype, + seq_lens: List[int], +) -> List[AttentionBias]: + attn_biases: List[AttentionBias] = [] + for seq_len in seq_lens: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(seq_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + # Calculate a matrix where each element represents ith element- jth + # element. + bias = bias[None, :] - bias[:, None] + + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + 1, # batch size + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_biases.append(LowerTriangularMaskWithTensorBias(bias)) + + return attn_biases \ No newline at end of file diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py new file mode 100644 index 0000000..5ac21e4 --- /dev/null +++ b/vllm/attention/layer.py @@ -0,0 +1,116 @@ +"""Attention layer.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + +from vllm.attention import AttentionMetadata, AttentionType +from vllm.attention.selector import get_attn_backend +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + prefix: str = "", + use_sqrt_alibi: Optional[bool] = False + ) -> None: + super().__init__() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + block_size = cache_config.block_size + sliding_window = cache_config.sliding_window + is_attention_free = cache_config.is_attention_free + else: + kv_cache_dtype = "auto" + block_size = 16 + sliding_window = None + is_attention_free = False + if num_kv_heads is None: + num_kv_heads = num_heads + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. 
For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._k_scale = 1.0 + self._v_scale = 1.0 + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None: + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + # During model initialization, the default dtype is set as the model + # weight and activation dtype. + dtype = torch.get_default_dtype() + attn_backend = get_attn_backend(head_size, sliding_window, dtype, + kv_cache_dtype, block_size, + is_attention_free, blocksparse_params + is not None) + impl_cls = attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: Optional[torch.Tensor], + attn_metadata: AttentionMetadata, + attn_type: AttentionType = AttentionType.DECODER, + ) -> torch.Tensor: + + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type) + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s diff --git a/vllm/attention/ops/__init__.py b/vllm/attention/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/ops/__pycache__/__init__.cpython-310.pyc b/vllm/attention/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..591955c Binary files /dev/null and b/vllm/attention/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/attention/ops/__pycache__/ipex_attn.cpython-310.pyc b/vllm/attention/ops/__pycache__/ipex_attn.cpython-310.pyc new file mode 100644 index 0000000..bb52c49 Binary files /dev/null and b/vllm/attention/ops/__pycache__/ipex_attn.cpython-310.pyc differ diff --git a/vllm/attention/ops/__pycache__/paged_attn.cpython-310.pyc b/vllm/attention/ops/__pycache__/paged_attn.cpython-310.pyc new file mode 100644 index 0000000..ddc8a86 Binary files /dev/null and b/vllm/attention/ops/__pycache__/paged_attn.cpython-310.pyc differ diff --git a/vllm/attention/ops/__pycache__/prefix_prefill.cpython-310.pyc b/vllm/attention/ops/__pycache__/prefix_prefill.cpython-310.pyc new file mode 100644 index 0000000..bb82b22 Binary files /dev/null and b/vllm/attention/ops/__pycache__/prefix_prefill.cpython-310.pyc differ diff --git a/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-310.pyc b/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-310.pyc new file mode 100644 index 0000000..33c30d5 Binary files /dev/null and 
b/vllm/attention/ops/__pycache__/triton_flash_attention.cpython-310.pyc differ diff --git a/vllm/attention/ops/blocksparse_attention/__init__.py b/vllm/attention/ops/blocksparse_attention/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-310.pyc b/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..ec2a31a Binary files /dev/null and b/vllm/attention/ops/blocksparse_attention/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-310.pyc b/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-310.pyc new file mode 100644 index 0000000..71fa12d Binary files /dev/null and b/vllm/attention/ops/blocksparse_attention/__pycache__/blocksparse_attention_kernel.cpython-310.pyc differ diff --git a/vllm/attention/ops/blocksparse_attention/__pycache__/interface.cpython-310.pyc b/vllm/attention/ops/blocksparse_attention/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000..e01d624 Binary files /dev/null and b/vllm/attention/ops/blocksparse_attention/__pycache__/interface.cpython-310.pyc differ diff --git a/vllm/attention/ops/blocksparse_attention/__pycache__/utils.cpython-310.pyc b/vllm/attention/ops/blocksparse_attention/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..306f346 Binary files /dev/null and b/vllm/attention/ops/blocksparse_attention/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py new file mode 100644 index 0000000..ec1c37c --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -0,0 +1,423 @@ +import torch +import triton +import triton.language as tl + + +def blocksparse_flash_attn_varlen_fwd( + q, + k, + v, # (#tokens, n_heads, head_size) + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + sparse_layout, + *, + block_size=64, + q_block_size=None, + max_seqlen=None): + # split q to blocks + + assert isinstance(sparse_layout, (list, tuple)) + + _, n_heads, head_size = q.shape + batch_size = cu_seqlens_k.size(0) - 1 + q_block_size = q_block_size or block_size + + assert q.dim() == k.dim() == v.dim() == 3 + assert q.size(1) % k.size(1) == 0 + assert q.size(2) == k.size(2) + # TODO(linxihui): allow k, v to have different head_size + assert k.shape == v.shape + assert cu_seqlens_k.dim() == 1 + + q_k_ratio = q.size(1) // k.size(1) + + if cu_seqlens_q is None: + if q.size(0) == batch_size: # decoding only + cu_seqlens_q = torch.arange( + 0, + batch_size + 1, + dtype=cu_seqlens_k.dtype, + device=cu_seqlens_k.device, + ) + elif q.size(0) == k.size(0): + cu_seqlens_q = cu_seqlens_k + else: + raise ValueError("cu_seqlens_q must be specified\ + if it mix of prefilling and decoding.") + else: + assert cu_seqlens_k.size(0) == cu_seqlens_q.size(0) + + # switch to use cpu to avoid too many kernel launches when iterated over + q_lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).cpu() + k_lens = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).cpu() + + assert torch.logical_or(q_lens == 1, k_lens == q_lens).all(), ( + "length of q should either be 1 (decoding) or same as k (prefilling).") + + if max_seqlen: + assert k_lens.max() <= max_seqlen + + n_blocks = (q_lens + q_block_size - 1) // q_block_size + + q_batch_ids = 
torch.tensor( + [i for i, n in enumerate(n_blocks) for _ in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + q_start_sids = torch.tensor( + [i * q_block_size for n in n_blocks for i in range(n)], + dtype=cu_seqlens_q.dtype, + device=cu_seqlens_q.device, + ) + + out = q.new_empty(q.shape) + cu_seqlens_q = cu_seqlens_q.contiguous() + cu_seqlens_k = cu_seqlens_k.contiguous() + + layout_crow_indices, layout_col_indices = sparse_layout + block_d = triton.next_power_of_2(head_size) + + decoding_only = (q_lens == 1).all().item() + grid = (len(q_start_sids), n_heads, 1) + + _fwd_kernel_batch_inference[grid]( + q, + k, + v, + out, + sm_scale, + cu_seqlens_q[:-1], + cu_seqlens_q[1:], + cu_seqlens_k[:-1], + cu_seqlens_k[1:], + q_batch_ids, + q_start_sids, + 0, + *q.stride(), + 0, + *k.stride(), + 0, + *v.stride(), + 0, + *out.stride(), + layout_crow_indices, + layout_col_indices, + *layout_crow_indices.stride(), + *layout_col_indices.stride(), + q_k_ratio, + HAS_BATCH_DIM=False, + D_HEAD=head_size, + BLOCK_M=q_block_size, + BLOCK_N=block_size, + BLOCK_D=block_d, + BLOCK_M_LOADING=(16 if decoding_only else + q_block_size), # smaller for decoding + EVEN_D=block_d == head_size, + num_warps=1 if decoding_only else 4, + num_stages=3) + + return out + + +@triton.jit +def _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + LAST_K_BLOCK: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + BLOCK_N: tl.constexpr, + D_HEAD: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + k_block_id = tl.load(layout_col_ptr + off_h * layout_col_stride_h + + k_block_col_idx * layout_col_stride_m).to(tl.int32) + start_n = k_block_id * BLOCK_N + if LAST_K_BLOCK: + if EVEN_D: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=offs_n[None, :] + start_n < k_seqlen, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kt, + mask=(offs_n[None, :] + start_n < k_seqlen) & + (offs_d[:, None] < D_HEAD), + ) + else: + if EVEN_D: + k = tl.load(k_ptrs + start_n * stride_kt) + else: + k = tl.load(k_ptrs + start_n * stride_kt, + mask=offs_d[:, None] < D_HEAD) + + qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + # the following is needed only when LAST_K_BLOCK or BLOCK_M < BLOCK_N + if LAST_K_BLOCK | M_LT_N: + qk += tl.where( + offs_m[:, None] + past_len >= (start_n + offs_n[None, :]), + 0, + float("-inf"), + ) + + # flash-attn2 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + # update m_i + m_i = m_ij + l_i = l_i * alpha + l_ij + + p = p.to(Q.dtype.element_ty) + # update acc + if LAST_K_BLOCK: + if EVEN_D: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=offs_n[:, None] + start_n < k_seqlen, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vt, + mask=(offs_n[:, None] + start_n < k_seqlen) & + (offs_d[None, :] < D_HEAD), + ) + else: + if EVEN_D: + v = tl.load(v_ptrs + start_n * stride_vt) + else: + v = tl.load(v_ptrs + start_n * stride_vt, + mask=offs_d[None, :] < D_HEAD) + + acc += tl.dot(p, v) + + return acc, l_i, m_i + + +@triton.heuristics({ + "M_LT_N": + lambda kwargs: kwargs["BLOCK_M"] < kwargs["BLOCK_N"], +}) +@triton.jit +def _fwd_kernel_batch_inference( + Q, + K, + V, + Out, + sm_scale, + q_batch_starts, + 
q_batch_ends, + k_batch_starts, + k_batch_ends, + q_batch_ids, + q_start_sids, + stride_qb, + stride_qt, + stride_qh, + stride_qd, + stride_kb, + stride_kt, + stride_kh, + stride_kd, + stride_vb, + stride_vt, + stride_vh, + stride_vd, + stride_ob, + stride_ot, + stride_oh, + stride_od, + layout_crow_ptr, + layout_col_ptr, + layout_crow_stride_h, + layout_crow_stride_m, + layout_col_stride_h, + layout_col_stride_m, + q_k_ratio, + HAS_BATCH_DIM: tl.constexpr, + D_HEAD: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_M_LOADING: tl.constexpr, + EVEN_D: tl.constexpr, + M_LT_N: tl.constexpr, +): + """ + NOTATION: + pid: position id + sid: storage id + sbid: storage block id + pbid: position block id + offs_m, offs_n: storage offsets of m-dim(q, row) and n-dim(k, col) + + TODO(linxihui): + Optimize grouped-attn + """ + off_zm = tl.program_id(0) + off_h = tl.program_id(1) + + off_h_for_kv = off_h // q_k_ratio + + if HAS_BATCH_DIM: + off_z = tl.program_id(2) + Q += off_z * stride_qb + K += off_z * stride_kb + V += off_z * stride_vb + Out += off_z * stride_ob + start_m = off_zm + q_start_sid = start_m * BLOCK_M # always 0 for decoding + else: + off_z = tl.load(q_batch_ids + off_zm).to(tl.int32) # [0, 0, 0, 1] + q_start_sid = tl.load(q_start_sids + off_zm) + start_m = q_start_sid // BLOCK_M # q_sbid + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M_LOADING) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_D) + + q_cu_start = tl.load(q_batch_starts + off_z).to(tl.int32) + q_seqlen = tl.load(q_batch_ends + off_z).to(tl.int32) - q_cu_start + k_cu_start = tl.load(k_batch_starts + off_z).to(tl.int32) + k_seqlen = tl.load(k_batch_ends + off_z).to(tl.int32) - k_cu_start + past_len = k_seqlen - q_seqlen + + Q += q_cu_start * stride_qt + off_h * stride_qh + K += k_cu_start * stride_kt + off_h_for_kv * stride_kh + V += k_cu_start * stride_vt + off_h_for_kv * stride_vh + Out += q_cu_start * stride_ot + off_h * stride_oh + + q_pbid = (past_len + q_start_sid) // BLOCK_M + + if EVEN_D: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=offs_m[:, None] < q_seqlen, + ) + else: + q = tl.load( + Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + other=0, + ) + + sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + + q_pbid * layout_crow_stride_m) + + # TODO(linxihui): load at once, with any Triton version + # that supports `tl.split`, e.g., Triton 3.0 + k_block_start = tl.load(sparse_crow_ptr).to(tl.int32) + k_block_end = tl.load(sparse_crow_ptr + 1).to(tl.int32) + + m_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M_LOADING], dtype=tl.float32) + acc = tl.zeros([BLOCK_M_LOADING, BLOCK_D], dtype=tl.float32) + + k_ptrs = K + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd + + sm_scale *= ( + 1.44269504 # 1/log2 as we use base2 for exponential and logarithm + ) + + for k_block_col_idx in range(k_block_start, k_block_end - 1): + acc, l_i, m_i = _fwd_kernel_inner( + acc, + l_i, + m_i, + q, + Q, + k_block_col_idx, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + False, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + acc, l_i, m_i = _fwd_kernel_inner( + 
acc, + l_i, + m_i, + q, + Q, + k_block_end - 1, + layout_col_ptr, + layout_col_stride_h, + layout_col_stride_m, + k_ptrs, + v_ptrs, + off_h, + offs_m, + offs_n, + offs_d, + stride_kt, + stride_vt, + sm_scale, + k_seqlen, + past_len, + True, + BLOCK_M_LOADING, + BLOCK_N, + D_HEAD, + EVEN_D, + M_LT_N, + ) + + # flash-attn 2 + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + + # write output + if EVEN_D: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=offs_m[:, None] < q_seqlen, + ) + else: + tl.store( + Out + offs_m[:, None] * stride_ot + offs_d[None, :] * stride_od, + acc, + mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), + ) diff --git a/vllm/attention/ops/blocksparse_attention/interface.py b/vllm/attention/ops/blocksparse_attention/interface.py new file mode 100644 index 0000000..1ead541 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/interface.py @@ -0,0 +1,238 @@ +import math + +import torch + +from vllm.platforms import current_platform +from vllm.utils import is_cpu, is_hip + +from .utils import (dense_to_crow_col, get_head_sliding_step, + get_sparse_attn_mask) + +IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80) + +if IS_COMPUTE_8_OR_ABOVE: + from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd + + +class LocalStridedBlockSparseAttn(torch.nn.Module): + + def __init__( + self, + n_heads, + max_seqlen, + local_blocks, + vert_stride, + block_size, + device=None, + dtype=None, + homo_head=False, + active_head_range=None, + q_block_size=None, + use_spda=None, + ): + super().__init__() + if use_spda is None: + use_spda = is_hip() or is_cpu() or not \ + IS_COMPUTE_8_OR_ABOVE + device = device or (torch.cuda.current_device() + if current_platform.is_cuda_alike() else "cpu") + device = torch.device(device) + # NOTE: vllm CPU backend support BF16 instead of FP16. + dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE + or device.type == "cpu" else torch.half) + + self.n_heads = n_heads + self.max_seqlen = max_seqlen + self.local_blocks = local_blocks + self.vert_stride = vert_stride + self.use_spda = use_spda + self.dtype = dtype + self.device = device + self.block_size = block_size + self.q_block_size = q_block_size + self.homo_head = homo_head + self.active_head_range = active_head_range + self.head_sliding_step = get_head_sliding_step(n_heads, vert_stride, + homo_head) + + sparse_layout, sparse_pattern, self.dense_attn_mask = ( + self.get_attn_pattern(dtype, device)) + + if q_block_size is not None and q_block_size != block_size: + if q_block_size > block_size: + assert q_block_size % block_size == 0 + blocks_to_merge = q_block_size // block_size + shape = sparse_pattern.shape + sparse_pattern = sparse_pattern.view(shape[0], -1, + blocks_to_merge, + shape[-1]) + sparse_pattern = sparse_pattern.sum(2) + sparse_layout = dense_to_crow_col(sparse_pattern) + else: + raise ValueError( + "Does not support smaller q_block_size. It will be slower." 
+ ) + + self.sparse_layout = sparse_layout + + def get_attn_pattern(self, dtype, device): + sparse_layout, sparse_pattern, dense_attn_mask = get_sparse_attn_mask( + self.n_heads, + self.max_seqlen, + self.max_seqlen, + dtype, + device, + block_size=self.block_size, + local_blocks=self.local_blocks, + vert_stride=self.vert_stride, + homo_head=self.homo_head, + return_dense=self.use_spda, + dense_mask_type="bias", + ) + if (not self.homo_head) and (self.active_head_range is not None): + assert isinstance(self.active_head_range, tuple) + assert (len(self.active_head_range) == 2) + h_start, h_end = self.active_head_range + sparse_layout = tuple(x[h_start:h_end] for x in sparse_layout) + if self.use_spda: + dense_attn_mask = dense_attn_mask[h_start:h_end] + return sparse_layout, sparse_pattern, dense_attn_mask + + def varlen_attn(self, + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=None, + sm_scale=None): + """ + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), + indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify is when q is a mix of + prefilling and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert ( + IS_COMPUTE_8_OR_ABOVE + ), "Requires compute capability of 8 or above (Ampere or newer) to use \ + Triton kernel." + + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + + return blocksparse_flash_attn_varlen_fwd( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q, + sm_scale, + self.sparse_layout, + block_size=self.block_size, + q_block_size=self.q_block_size, + max_seqlen=self.max_seqlen, + ) + + @staticmethod + def transpose_and_pad(x, cu_seqlens, maxlen, head_repeats=1): + """ + :param x: (total_tokens, n_heads, head_size) + :return: (batch, n_heads, length, head_size) + """ + x_padded = x.new_empty( + len(cu_seqlens) - 1, x.size(1), head_repeats, maxlen, x.size(2)) + cu_seqlens = cu_seqlens.cpu() + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x_padded[i, :, :, :e - s].copy_(x[s:e].transpose(0, + 1).unsqueeze(1)) + return x_padded.flatten(1, 2) + + @staticmethod + def transpose_and_unpad(x_padded, cu_seqlens): + """ + :param x_padded: (batch, n_heads, length, head_size) + :return: (total_tokens, n_heads, head_size) + """ + cu_seqlens = cu_seqlens.cpu() + total_n_tokens = cu_seqlens[-1] + x = x_padded.new_empty(total_n_tokens, x_padded.size(1), + x_padded.size(3)) + for i, (s, e) in enumerate(zip(cu_seqlens[:-1], cu_seqlens[1:])): + x[s:e].copy_(x_padded[i, :, :e - s].transpose(0, 1)) + return x + + def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """For CPU, V100 or other older GPUs. + NOTE: torch SPDA supports nested tensor, + but seems extremely slow. Choose to pad instead. + """ + assert (cu_seqlens_q is None or + (cu_seqlens_q + == cu_seqlens_k).all()), "Can only handle prompt with SPDA." + assert q.size(0) == k.size(0), "can only handle prompt with SPDA." 
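+        # Rough sketch of this fallback path (descriptive only): each
+        # sequence is padded to the batch max length, KV heads are repeated
+        # q_k_ratio times to match the query heads, the precomputed dense
+        # bias mask is applied, and the padded SDPA output is unpadded back
+        # to (total_tokens, n_heads, head_size).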
+ + assert q.size(1) % k.size(1) == 0 + q_k_ratio = q.size(1) // k.size(1) + sm_scale = sm_scale or 1.0 / math.sqrt(q.size(-1)) + cu_seqlens = cu_seqlens_k.cpu() + maxlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + if (self.dense_attn_mask.dtype != q.dtype + or self.dense_attn_mask.device != q.device): + _, _, self.dense_attn_mask = self.get_attn_pattern( + q.dtype, q.device) + attn_mask = self.dense_attn_mask[None, :, :maxlen, :maxlen] + + q2 = self.transpose_and_pad(q, cu_seqlens, maxlen, 1) + k2, v2 = [ + self.transpose_and_pad(x, cu_seqlens, maxlen, q_k_ratio) + for x in [k, v] + ] + spda_output = torch.nn.functional.scaled_dot_product_attention( + q2, k2, v2, attn_mask=attn_mask, scale=sm_scale) + return self.transpose_and_unpad(spda_output, cu_seqlens) + + def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None): + """Dispatch to `varlen_attn` (Ampere or newer) or + `self.spda`(cpu, Volta, Turing or older)based on + the type of device used and cuda compute capability. + + q, k, v: shape = (num_tokens, num_heads_q/kv, head_size). + Support grouped attention, with `q[:, i*r:(i*r + r)]` + is correspondent to `k[:, i]`, where `r` is the q/k ratio. + cu_seqlens_k: shape=(batch_size + 1,), indicating segment of samples, + e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i + cu_seqlens_q: shape=(batch_size + 1, ). + Default None: same as cu_seqlens_k for prefilling or + [0, 1, .., batch_size] for decoding. + The only case you need to specify + is when q is a mix of prefilling + and decoding. + sm_scale: softmax scale, default to 1/sqrt(head_size). + + return: tensor of shape as q. + """ + assert k.dim() == 3 + if self.use_spda: + return self.spda( + q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale, + ) + return self.varlen_attn(q, + k, + v, + cu_seqlens_k, + cu_seqlens_q=cu_seqlens_q, + sm_scale=sm_scale) diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py new file mode 100644 index 0000000..78d7522 --- /dev/null +++ b/vllm/attention/ops/blocksparse_attention/utils.py @@ -0,0 +1,242 @@ +# Helper functions for 3D sparse pattern +# These function are not optimized and very inefficient. +# Avoid calling them too frequent or use a cache mechanism. + +from functools import lru_cache + +import numpy as np +import torch +import triton + + +class csr_matrix: + """Simple implementation of CSR matrix conversion without scipy. + This replaced scipy.sparse.csr_matrix() previously used.""" + + def __init__(self, input_array): + if not isinstance(input_array, np.ndarray): + raise ValueError("Input must be a NumPy array") + + self.shape = input_array.shape + rows, cols = self.shape + data = [] + indices = [] + indptr = [0] + + for i in range(rows): + for j in range(cols): + if input_array[i, j]: + data.append(input_array[i, j]) + indices.append(j) + indptr.append(len(indices)) + + self.data = np.array(data) + self.indices = np.array(indices) + self.indptr = np.array(indptr) + + +def dense_to_crow_col(x: torch.Tensor): + """Turning a 2D/3D torch tensor (x) to CSR rows/cols indexing. 
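+
+    Worked example (illustrative): for the 2D input [[1, 0], [1, 1]] the
+    returned crow indices are [0, 1, 3] and the col indices are [0, 0, 1];
+    batched 3D inputs are converted per matrix and stacked.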
+ NOTE: col_indices padded -1 + """ + device = x.device + pad = -1 + dim = x.dim() + assert x.dim() in (2, 3) + if x.dim() == 2: + x = x[None] + x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x] + crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x]) + cols = [torch.from_numpy(xi.indices) for xi in x] + max_cols = max(len(xi) for xi in cols) + cols = [ + torch.cat([xi, pad + xi.new_zeros(max_cols - xi.shape[0])]) + for xi in cols + ] + cols = torch.vstack(cols) + if dim == 2: + crows = crows[0] + cols = cols[0] + return crows.to(device), cols.to(device) + + +def crow_col_to_dense(crows: torch.Tensor, + cols: torch.Tensor, + dtype: torch.dtype = torch.float16): + dim = crows.dim() + if dim == 1: + crows = crows[None] + cols = cols[None] + device = crows.device + crows, cols = crows.cpu(), cols.cpu() # faster in cpu + shape = (crows.shape[0], crows.shape[1] - 1, cols.max() + 1) + x = torch.zeros(shape, dtype=dtype) + for i in range(shape[0]): + for j in range(shape[1]): + x[i, j, cols[i, crows[i, j]:crows[i, j + 1]]] = 1 + if dim == 1: + x = x[0] + return x.to(device) + + +def dense_to_ccol_row(x: torch.Tensor): + """Similar, but to CSC format""" + x = x.transpose(-2, -1) + return dense_to_crow_col(x) + + +def ccol_row_to_dense(ccol: torch.Tensor, + rows: torch.Tensor, + dtype: torch.dtype = torch.float16): + return crow_col_to_dense(ccol, rows, dtype).permute(0, 2, 1).contiguous() + + +def _get_sparse_attn_mask_homo_head( + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 128, + local_blocks: int = 4, + vert_stride: int = 4, + return_dense: bool = False, +): + """ + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be + OOM if it is too big) if `return_dense==True`, + otherwise, None + """ + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[:, None] + k_pos = torch.arange(num_blocks)[None] + mask_vert_strided = (torch.arange(num_blocks) + 1) % vert_stride == 0 + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = (dense_to_crow_col( + block_mask_dense[-num_blocks_q:].contiguous())) + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[-q_len:, :max_seqlen] * causal_mask + return ( + block_mask_dense_output, + block_mask_dense, + mask_dense, + ) + else: + return ( + block_mask_dense_output, + block_mask_dense, + None, + ) + + +def binary_mask_to_bias(mask_dense: torch.Tensor): + mask_dense = 1 - mask_dense + mask_dense.masked_fill_(mask_dense.bool(), -torch.inf) + return mask_dense + + +def get_head_sliding_step(n_heads: int, + vert_stride: int, + homo_head: bool = False): + if homo_head: + return 0 + return max(1, int(vert_stride / n_heads)) + + +@lru_cache +def get_sparse_attn_mask( + n_heads: int, + q_len: int, + max_seqlen: int, + dtype: torch.dtype, + device: torch.device, + block_size: int = 64, + local_blocks: int = 4, + vert_stride: int = 4, + homo_head: bool = True, + return_dense: bool = False, + dense_mask_type: str = "binary", +): + """ + :param dense_mask_type: "binary" (0 for skip token, 1 for others) + or "bias" (-inf 
for skip token, 0 or others) + :return: a tuple of 3: + - tuple of crow_indices, col_indices representation + of CSR format. + - block dense mask + - all token dense mask (be aware that it can be OOM if it + is too big) if `return_dense==True`, otherwise, None + """ + assert dense_mask_type in ("binary", "bias") + if homo_head: + with torch.no_grad(): + (crow, col), block_mask_dense, mask_dense = ( + _get_sparse_attn_mask_homo_head( + q_len, + max_seqlen, + dtype, + device, + block_size, + local_blocks, + vert_stride, + return_dense, + )) + crow = crow[None].expand(n_heads, crow.shape[0]) + col = col[None].expand(n_heads, col.shape[0]) + if return_dense: + mask_dense = mask_dense[None].expand(n_heads, + *mask_dense.shape) + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + return (crow, col), block_mask_dense, mask_dense + + with torch.no_grad(): + num_blocks = triton.cdiv(max_seqlen, block_size) + q_pos = torch.arange(num_blocks)[None, :, None] + k_pos = torch.arange(num_blocks)[None, None] + head_sliding_step = get_head_sliding_step(n_heads, vert_stride) + mask_vert_strided = [ + (torch.arange(num_blocks) + h * head_sliding_step + 1) % + vert_stride == 0 for h in range(n_heads) + ] + mask_vert_strided = torch.vstack(mask_vert_strided).unsqueeze(1) + block_mask_dense = (((q_pos >= k_pos) + & ((q_pos - k_pos < local_blocks) + | mask_vert_strided)).to(device).to(dtype)) + num_blocks_q = triton.cdiv(q_len, block_size) + block_mask_dense_output = block_mask_dense[:, -num_blocks_q:] + if return_dense: + mask_dense = torch.kron( + block_mask_dense, + block_mask_dense.new_ones((block_size, block_size)), + ) + causal_mask = torch.tril(torch.ones( + max_seqlen, max_seqlen)).type_as(mask_dense)[-q_len:] + mask_dense = mask_dense[..., -q_len:, :max_seqlen] * causal_mask[None] + if dense_mask_type == "bias": + mask_dense = binary_mask_to_bias(mask_dense) + + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + mask_dense, + ) + else: + return ( + dense_to_crow_col(block_mask_dense_output), + block_mask_dense, + None, + ) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py new file mode 100644 index 0000000..6b270ff --- /dev/null +++ b/vllm/attention/ops/ipex_attn.py @@ -0,0 +1,123 @@ +from typing import Dict, List, Optional, Tuple + +import intel_extension_for_pytorch.llm.modules as ipex_modules +import torch + +from vllm import _custom_ops as ops + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, 
key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + *args, + ) -> torch.Tensor: + output = torch.empty_like(query) + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + prompt_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_subquery_len: int, + alibi_slopes: Optional[torch.Tensor], + *args, + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + *args, + ) -> None: + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py new file mode 100644 index 0000000..1741dd1 --- /dev/null +++ b/vllm/attention/ops/paged_attn.py @@ -0,0 +1,242 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 + + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. 
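+    # Illustrative sizing note: a sequence of seq_len tokens spans
+    # ceil(seq_len / block_size) physical blocks, so that is how many real
+    # entries its row holds before any cuda-graph padding.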
+ block_tables: Optional[torch.Tensor] + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + output = torch.empty_like(query) + block_size = value_cache.shape[3] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_seq_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + use_v1 = True + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + ) + else: + # Run PagedAttention V2. 
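+            # Sketch of the V2 path (descriptive only): the KV sequence is
+            # split into partitions of _PARTITION_SIZE tokens, partial
+            # attention for each partition is written to tmp_output together
+            # with per-partition exp_sums/max_logits, and the partitions are
+            # then reduced into the final output. E.g. with
+            # max_seq_len = 2048 and _PARTITION_SIZE = 512,
+            # max_num_partitions = 4. Note that use_v1 is forced to True
+            # above, so this branch is not reached in this build.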
+ assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache_dtype: str, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + query_start_loc: torch.Tensor, + seq_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_query_len: int, + alibi_slopes: Optional[torch.Tensor], + sliding_window: Optional[int], + k_scale: float, + v_scale: float, + ) -> torch.Tensor: + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + kv_cache_dtype, + key_cache, + value_cache, + block_tables, + # query_start_loc is (batch_size + 1,) + query_start_loc[:-1], + seq_lens_tensor, + context_lens, + max_query_len, + k_scale, + v_scale, + alibi_slopes, + sliding_window, + ) + return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py new file mode 100644 index 0000000..9a39e2b --- /dev/null +++ b/vllm/attention/ops/prefix_prefill.py @@ -0,0 +1,865 @@ +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch +import triton +import triton.language as tl + +from vllm.platforms import current_platform + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + 
SLIDING_WINDOW: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + + # start position inside of the query + # generally, N goes over kv, while M goes over query_len + block_start_loc = BLOCK_M * start_m + + # initialize offsets + # [N]; starts at 0 + offs_n = tl.arange(0, BLOCK_N) + # [D]; starts at 0 + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + # [M]; starts at current position in query + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # [M,D] + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, + 0).to(tl.int1) # [D] + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len), + other=0.0) # [M,D] + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # [M] + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # [M] + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], + dtype=tl.float32) # [M,D] + + # compute query against context (no causal mask here) + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) # [N] + # [D,N] + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + # [N,D] + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + if SLIDING_WINDOW > 0: + # (cur_batch_ctx_len + offs_m[:, None]) are the positions of + # Q entries in sequence + # (start_n + offs_n[None, :]) are the positions of + # KV entries in sequence + # So the condition makes sure each entry in Q only attends + # to KV entries not more than SLIDING_WINDOW away. + # + # We can't use -inf here, because the + # sliding window may lead to the entire row being masked. + # This then makes m_ij contain -inf, which causes NaNs in + # exp(). 
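+                # Illustrative example: with SLIDING_WINDOW = 4096, a query
+                # at absolute position p only keeps scores for KV positions
+                # greater than p - 4096; older positions are set to -10000,
+                # which contributes (almost) nothing to the softmax whenever
+                # the row also has in-window entries.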
+ qk = tl.where((cur_batch_ctx_len + offs_m[:, None]) - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, + -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) # [M] + p = tl.exp(qk - m_ij[:, None]) # [M,N] + l_ij = tl.sum(p, 1) # [M] + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) # [M] + alpha = tl.exp(m_i - m_i_new) # [M] + beta = tl.exp(m_ij - m_i_new) # [M] + l_i_new = alpha * l_i + beta * l_ij # [M] + + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) # [N,D] + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + # block_mask is 0 when we're already past the current query length + block_mask = tl.where(block_start_loc < cur_batch_query_len, 1, 0) + + # compute query against itself (with causal mask) + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_query_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + # apply causal mask + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + if SLIDING_WINDOW > 0: + qk = tl.where( + offs_m[:, None] - + (start_n + offs_n[None, :]) < SLIDING_WINDOW, qk, -10000) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_query_len), + other=0.0) + p = p.to(v.dtype) + + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_query_len)) + return + + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + 
stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + 
offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + k_scale, + v_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, # head size + BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL_PADDED) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + dim_mask = tl.where( + tl.arange(0, BLOCK_DMODEL_PADDED) < BLOCK_DMODEL, 1, 0).to(tl.int1) + + q = tl.load(Q + off_q, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + 
cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k_load = tl.load(K_cache + off_k, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < cur_batch_ctx_len), + other=0.0) # [D,N] + + if k_load.dtype.is_fp8(): + k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + else: + k = k_load + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v_load = tl.load(V_cache + off_v, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < cur_batch_ctx_len), + other=0.0) + if v_load.dtype.is_fp8(): + v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + else: + v = v_load + p = p.to(v.dtype) + + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=dim_mask[:, None] & + ((start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = 
alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=dim_mask[None, :] & + ((start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len), + other=0.0) + p = p.to(v.dtype) + + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=dim_mask[None, :] & + (offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)) + return + + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + kv_cache_dtype: str, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + k_scale: float = 1.0, + v_scale: float = 1.0, + alibi_slopes=None, + sliding_window=None): + + BLOCK = 128 if current_platform.has_device_capability(80) else 64 + NUM_WARPS = 8 + + # need to reduce num. blocks when using fp32 + # due to increased use of GPU shared memory + if q.dtype is torch.float32: + BLOCK = BLOCK // 2 + + # Conversion of FP8 Tensor from uint8 storage to + # appropriate torch.dtype for interpretation by Triton + if "fp8" in kv_cache_dtype: + assert (k_cache.dtype == torch.uint8) + assert (v_cache.dtype == torch.uint8) + + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + target_dtype = torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + target_dtype = torch.float8_e5m2 + else: + raise ValueError("Unsupported FP8 dtype:", kv_cache_dtype) + + k_cache = k_cache.view(target_dtype) + v_cache = v_cache.view(target_dtype) + + if (k_cache.dtype == torch.uint8 + or v_cache.dtype == torch.uint8 and kv_cache_dtype == "auto"): + raise ValueError("kv_cache_dtype='auto' unsupported for\ + FP8 KV Cache prefill kernel") + + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + # round up Lk to a power of 2 - this is required for Triton block size + Lk_padded = triton.next_power_of_2(Lk) + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + # 0 means "disable" + if sliding_window is None or sliding_window <= 0: + sliding_window = 0 + + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + num_warps=NUM_WARPS, + num_stages=1, + ) + return + + import time 
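+        # The timing below is debug instrumentation around the generic
+        # (non-alibi) kernel launch; the elapsed-time print at the end is
+        # commented out, so it has no effect on normal runs.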
+ ts_beg = time.time() + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + k_scale, + v_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + k_cache.shape[4], + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_DMODEL_PADDED=Lk_padded, + BLOCK_N=BLOCK, + SLIDING_WINDOW=sliding_window, + num_warps=NUM_WARPS, + num_stages=1, + ) + elapsed = time.time() - ts_beg + #print(f'{elapsed}: {BLOCK=}, {Lk=}, {Lk_padded=}, {BLOCK=}, {sliding_window=}, {NUM_WARPS=}') + return diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py new file mode 100644 index 0000000..f942111 --- /dev/null +++ b/vllm/attention/ops/triton_flash_attention.py @@ -0,0 +1,820 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao +(https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
+ +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype: tl.constexpr = torch.float16 + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, + stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, + stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + + +@triton.jit +def _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr, +): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn( + K_block_ptr, + PADDED_HEAD, + MASK_STEPS and (n_extra_tokens != 0), + "zero", + ) + if PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: # noqa: SIM102 + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps + # if not is_modulo_mn. last step might get wasted but that is okay. + # check if this masking works for that case. 
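+            # Example: with actual_seqlen_k = 100 and BLOCK_N = 64 we get
+            # n_extra_tokens = 36 and block_max = 128, so only the final
+            # iteration (start_n = 64) masks key columns 100..127 to -inf.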
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], + actual_seqlen_k, + dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS + and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, our + # optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += bias * 1.44269504089 + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = (batch_philox_offset + + start_m * BLOCK_M * actual_seqlen_k + start_n - + BLOCK_N) + keep = dropout_mask( + philox_seed, + philox_offset, + dropout_p, + BLOCK_M, + BLOCK_N, + actual_seqlen_k, + ) + if RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + tl.where(keep, p, + -p).to(encoded_softmax_block_ptr.type.element_ty), + ) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store( + encoded_softmax_block_ptr, + p.to(encoded_softmax_block_ptr.type.element_ty), + ) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn( + V_block_ptr, + MASK_STEPS and (n_extra_tokens != 0), + PADDED_HEAD, + "zero", + ) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, BLOCK_N)) + return acc, l_i, m_i + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 64, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 256, + "BLOCK_N": 128, + "waves_per_eu": 2, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": True, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 128, + "BLOCK_N": 64, + "waves_per_eu": 3, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + triton.Config( + { + "BLOCK_M": 64, + "BLOCK_N": 64, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + triton.Config( + { + "BLOCK_M": 32, + "BLOCK_N": 32, + "waves_per_eu": 4, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=8, + ), + # TODO: This config fails with head_size not pow2 with data mismatches. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, + # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config( + { + "BLOCK_M": 16, + "BLOCK_N": 16, + "waves_per_eu": 1, + "PRE_LOAD_V": False, + }, + num_stages=1, + num_warps=4, + ), + ], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +) +@triton.jit +def attn_fwd( + Q, + K, + V, + bias, + sm_scale, + L, + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_bz, + stride_bh, + stride_bm, + stride_bn, + cu_seqlens_q, + cu_seqlens_k, + dropout_p, + philox_seed, + philox_offset_base, + encoded_softmax, + HQ: tl.constexpr, + HK: tl.constexpr, + ACTUAL_BLOCK_DMODEL: tl.constexpr, + MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if IS_CAUSAL: + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn + # matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is + # part of the blocks that are all 0. We exit early. 
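+        # Example: seqlen_q = 128, seqlen_k = 64, BLOCK_M = BLOCK_N = 64 and
+        # start_m = 0 give n_blocks_seqlen = cdiv(64 + 64 - 128, 64) = 0, so
+        # this query block attends to nothing and takes the early return below.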
+ if n_blocks <= 0: + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + # tl.store(O_block_ptr, + # acc.to(Out.type.element_ty), boundary_check=(0,1)) + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # + offs_m + # We store inf to LSE, not -inf because in the bwd pass, + # we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 + # for these masked blocks. + # l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + # tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here? + return + + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL + + # Compute pointers for all the tensors used in this kernel. + q_offset = (off_z * stride_qz + off_h_q * stride_qh + + cu_seqlens_q_start * stride_qm) + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_offset = (off_z * stride_kz + off_h_k * stride_kh + + cu_seqlens_k_start * stride_kn) + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1), + ) + v_offset = (off_z * stride_vz + off_h_k * stride_vh + + cu_seqlens_k_start * stride_vk) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base \ + + (off_z * HQ + off_h_q) \ + * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. + # In this case, we return an invalid pointer so indicate the mask is not i + # valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. 
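+    # The integer 0 assigned below is a placeholder that is never
+    # dereferenced: every use of encoded_softmax_block_ptr is guarded by
+    # RETURN_ENCODED_SOFTMAX.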
+ if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional + # block. In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, + block_max, + 0, + 0, + 0, + bias_ptr, + # IS_CAUSAL, .... + False, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, + False, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if masked_blocks > 0: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, + l_i, + m_i, + q, + K_block_ptr, + V_block_ptr, + start_m, + seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, + block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL, + BLOCK_M, + BLOCK_DMODEL, + BLOCK_N, + offs_m, + offs_n, + # _, MASK_STEPS, ... 
+ PRE_LOAD_V, + True, + ENABLE_DROPOUT, + RETURN_ENCODED_SOFTMAX, + padded_head, + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: # noqa: SIM102 + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), + causal_start_idx, + dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = (mask_m_offsets[:, None] >= + out_mask_boundary[None, :]) + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last + # few rows. This is only true for the last M block. For others, + # overflow_size will be -ve + # overflow_size = end_m_idx - seqlen_q + # if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + # else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om + + off_h_q * stride_oh) + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. 
+ tl.store(O_block_ptr, acc, boundary_check=(0, 1)) + + +def check_args( + q, + k, + v, + o, + varlen=True, + max_seqlens=None, + cu_seqlens_q=None, + cu_seqlens_k=None, +): + assert q.dim() == k.dim() and q.dim() == v.dim() + if varlen: + assert q.dim() == 3 + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + assert cu_seqlens_q is not None + assert cu_seqlens_k is not None + assert len(cu_seqlens_q) == len(cu_seqlens_k) + else: + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + o, + cu_seqlens_q, + cu_seqlens_k, + max_seqlens_q, + max_seqlens_k, + causal=False, + sm_scale=1.0, + bias=None, + ): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + + check_args( + q, + k, + v, + o, + varlen=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + ) + if True: # varlen + total_q, nheads_q, head_size = q.shape + total_k, nheads_k, _ = k.shape + batch = len(cu_seqlens_q) - 1 + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + else: + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128, 256} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + grid = lambda META: ( + triton.cdiv(max_seqlens_q, META["BLOCK_M"]), + nheads_q, + batch, + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. 
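+        # Dropout is hard-wired off in this wrapper (dropout_p=0.0,
+        # ENABLE_DROPOUT=False below), so the seed and offset only matter if
+        # dropout support is enabled later.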
+ philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = ( + bias.stride(0), + bias.stride(1), + bias.stride(2), + bias.stride(3), + ) + else: + bias_strides = (0, 0, 0, 0) + + attn_fwd[grid]( + q, + k, + v, + bias, + sm_scale, + None, + o, + *q_strides, + *k_strides, + *v_strides, + *o_strides, + *bias_strides, + cu_seqlens_q, + cu_seqlens_k, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + HQ=nheads_q, + HK=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=max_seqlens_q, + MAX_SEQLENS_K=max_seqlens_k, + IS_CAUSAL=causal, + VARLEN=True, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False, + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + + +triton_attention = _attention.apply diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py new file mode 100644 index 0000000..d1a0bd2 --- /dev/null +++ b/vllm/attention/selector.py @@ -0,0 +1,308 @@ +import enum +import os +from contextlib import contextmanager +from functools import lru_cache +from typing import Generator, Optional, Type + +import torch + +import vllm.envs as envs +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu + +logger = init_logger(__name__) + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + NO_ATTENTION = enum.auto() + + +def backend_name_to_enum(backend_name: str) -> _Backend: + assert backend_name is not None + + backend_members = _Backend.__members__ + if backend_name not in backend_members: + raise ValueError(f"Invalid attention backend '{backend_name}'. " + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + + return _Backend[backend_name] + + +def get_env_variable_attn_backend() -> Optional[_Backend]: + ''' + Get the backend override specified by the vLLM attention + backend environment variable, if one is specified. + + Returns: + + * _Backend enum value if an override is specified + * None otherwise + ''' + backend_name = os.environ.get(STR_BACKEND_ENV_VAR) + return (None + if backend_name is None else backend_name_to_enum(backend_name)) + + +# Global state allows a particular choice of backend +# to be forced, overriding the logic which auto-selects +# a backend based on system & workload configuration +# (default behavior if this variable is None) +# +# THIS SELECTION TAKES PRECEDENCE OVER THE +# VLLM ATTENTION BACKEND ENVIRONMENT VARIABLE +forced_attn_backend: Optional[_Backend] = None + + +def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None: + ''' + Force all attention operations to use a specified backend. 
+ + Passing `None` for the argument re-enables automatic + backend selection., + + Arguments: + + * attn_backend: backend selection (None to revert to auto) + ''' + global forced_attn_backend + forced_attn_backend = attn_backend + + +def get_global_forced_attn_backend() -> Optional[_Backend]: + ''' + Get the currently-forced choice of attention backend, + or None if auto-selection is currently enabled. + ''' + return forced_attn_backend + + +@lru_cache(maxsize=None) +def get_attn_backend( + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, + is_blocksparse: bool = False, +) -> Type[AttentionBackend]: + """Selects which attention backend to use and lazily imports it.""" + + # if is_blocksparse: + # logger.info("Using BlocksparseFlashAttention backend.") + # from vllm.attention.backends.blocksparse_attn import ( + # BlocksparseFlashAttentionBackend) + # return BlocksparseFlashAttentionBackend + + backend = which_attn_to_use(head_size, sliding_window, dtype, + kv_cache_dtype, block_size, is_attention_free) + if backend == _Backend.FLASH_ATTN: + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend) + return FlashAttentionBackend + if backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + from vllm.attention.backends.xformers import ( # noqa: F401 + XFormersBackend) + return XFormersBackend + elif backend == _Backend.ROCM_FLASH: + logger.info("Using ROCmFlashAttention backend.") + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + elif backend == _Backend.TORCH_SDPA: + assert is_cpu(), RuntimeError( + "Torch SDPA backend is only used for the CPU device.") + logger.info("Using Torch SDPA backend.") + from vllm.attention.backends.torch_sdpa import TorchSDPABackend + return TorchSDPABackend + elif backend == _Backend.OPENVINO: + logger.info("Using OpenVINO Attention backend.") + from vllm.attention.backends.openvino import OpenVINOAttentionBackend + return OpenVINOAttentionBackend + elif backend == _Backend.IPEX: + assert is_xpu(), RuntimeError( + "IPEX attention backend is only used for the XPU device.") + logger.info("Using IPEX attention backend.") + from vllm.attention.backends.ipex_attn import IpexAttnBackend + return IpexAttnBackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + from vllm.attention.backends.flashinfer import FlashInferBackend + return FlashInferBackend + elif backend == _Backend.PALLAS: + logger.info("Using Pallas backend.") + from vllm.attention.backends.pallas import PallasAttentionBackend + return PallasAttentionBackend + elif backend == _Backend.NO_ATTENTION: + from vllm.attention.backends.placeholder_attn import ( + PlaceholderAttentionBackend) + return PlaceholderAttentionBackend + else: + raise ValueError("Invalid attention backend.") + + +def which_attn_to_use( + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + kv_cache_dtype: Optional[str], + block_size: int, + is_attention_free: bool, +) -> _Backend: + """Returns which flash attention backend to use.""" + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # If there are no attention layers (e.g. we are running Mamba), + # use the placeholder NO_ATTENTION + if is_attention_free: + return _Backend.NO_ATTENTION + + # Check whether a particular choice of backend was + # previously forced. 
+ # + # THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND + # ENVIRONMENT VARIABLE. + backend_by_global_setting: Optional[_Backend] = ( + get_global_forced_attn_backend()) + if backend_by_global_setting is not None: + selected_backend = backend_by_global_setting + else: + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + + if is_cpu(): + if selected_backend != _Backend.TORCH_SDPA: + logger.info("Cannot use %s backend on CPU.", selected_backend) + return _Backend.TORCH_SDPA + + if is_openvino(): + if selected_backend != _Backend.OPENVINO: + logger.info("Cannot use %s backend on OpenVINO.", selected_backend) + return _Backend.OPENVINO + + if is_xpu(): + if selected_backend != _Backend.IPEX: + logger.info("Cannot use %s backend on XPU.", selected_backend) + return _Backend.IPEX + + if current_platform.is_tpu(): + if selected_backend != _Backend.PALLAS: + logger.info("Cannot use %s backend on TPU.", selected_backend) + return _Backend.PALLAS + + if selected_backend == _Backend.FLASH_ATTN: + print("selected_backend == _Backend.FLASH_ATTN") + + if is_hip(): + # AMD GPUs. + selected_backend = (_Backend.ROCM_FLASH if selected_backend + == _Backend.FLASH_ATTN else selected_backend) + if selected_backend == _Backend.ROCM_FLASH: + if not current_platform.has_device_capability(90): + # not Instinct series GPUs. + logger.info("flash_attn is not supported on NAVI GPUs.") + else: + logger.info("%s is not supported in AMD GPUs.", selected_backend) + return _Backend.ROCM_FLASH + + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if not current_platform.has_device_capability(80): + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): + logger.info( + "Cannot use FlashAttention-2 backend for FP8 KV cache.") + logger.warning( + "Please use FlashInfer backend with FP8 KV Cache for " + "better performance by setting environment variable " + "VLLM_ATTENTION_BACKEND=FLASHINFER") + selected_backend = _Backend.XFORMERS + elif block_size % 16 != 0: + logger.info( + "Cannot use FlashAttention-2 backend for block size not " + "divisible by 16.") + selected_backend = _Backend.XFORMERS + # elif sliding_window is not None: + # logger.info( + # "Cannot use FlashAttention-2 backend due to sliding window.") + # selected_backend = _Backend.XFORMERS + + # FlashAttn is valid for the model, checking if the package is installed. + if selected_backend == _Backend.FLASH_ATTN: + try: + # import vllm_flash_attn # noqa: F401 + + from vllm.attention.backends.flash_attn import ( # noqa: F401 + FlashAttentionBackend) + + supported_sizes = FlashAttentionBackend.get_supported_head_sizes() + if head_size not in supported_sizes: + logger.info( + "Cannot use FlashAttention-2 backend for head size %d.", + head_size) + selected_backend = _Backend.XFORMERS + except ImportError: + logger.info( + "Cannot use FlashAttention-2 backend because the " + "vllm.vllm_flash_attn package is not found. 
" + "Make sure that vllm_flash_attn was built and installed " + "(on by default).") + selected_backend = _Backend.XFORMERS + + return selected_backend + + +@contextmanager +def global_force_attn_backend_context_manager( + attn_backend: _Backend) -> Generator[None, None, None]: + ''' + Globally force a vLLM attention backend override within a + context manager, reverting the global attention backend + override to its prior state upon exiting the context + manager. + + Arguments: + + * attn_backend: attention backend to force + + Returns: + + * Generator + ''' + + # Save the current state of the global backend override (if any) + original_value = get_global_forced_attn_backend() + + # Globally force the new backend override + global_force_attn_backend(attn_backend) + + # Yield control back to the enclosed code block + try: + yield + finally: + # Revert the original global backend override, if any + global_force_attn_backend(original_value) diff --git a/vllm/beam_search.py b/vllm/beam_search.py new file mode 100644 index 0000000..04624b8 --- /dev/null +++ b/vllm/beam_search.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class BeamSearchSequence: + """A sequence for beam search. + It keeps track of the tokens and the log probability of the sequence. + The text field is optional and will only be filled when the sequence is + about to be returned to the user. + """ + # The tokens includes the prompt. + tokens: List[int] + cum_logprob: float = 0.0 + text: Optional[str] = None + + +@dataclass +class BeamSearchOutput: + """The output of beam search. + It contains the list of the best beam search sequences. + The length of the list is equal to the beam width. + """ + sequences: List[BeamSearchSequence] + + +class BeamSearchInstance: + + def __init__(self, prompt_tokens: List[int]): + self.beams: List[BeamSearchSequence] = [ + BeamSearchSequence(tokens=prompt_tokens) + ] + self.completed: List[BeamSearchSequence] = [] + + +def get_beam_search_score( + tokens: List[int], + cumulative_logprob: float, + eos_token_id: int, + length_penalty: float = 1.0, +) -> float: + """Calculate the beam search score with length penalty. 
+ + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + seq_len = len(tokens) + if tokens[-1] == eos_token_id: + seq_len -= 1 + + return cumulative_logprob / (seq_len**length_penalty) + + +def create_sort_beams_key_function(eos_token_id: int, length_penalty: float): + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, eos_token_id, + length_penalty) + + return sort_beams_key diff --git a/vllm/block.py b/vllm/block.py new file mode 100644 index 0000000..47c381c --- /dev/null +++ b/vllm/block.py @@ -0,0 +1,88 @@ +"""Token blocks.""" +from typing import TYPE_CHECKING, Iterator, List, Optional + +from vllm.utils import Device + +DEFAULT_LAST_ACCESSED_TIME: float = -1 + + +class PhysicalTokenBlock: + """Represents the state of a block in the KV cache.""" + + def __init__( + self, + device: Device, + block_number: int, + block_size: int, + block_hash: int, + num_hashed_tokens: int, + ) -> None: + self.device = device + self.block_number = block_number + self.block_size = block_size + self.block_hash = block_hash + self.num_hashed_tokens = num_hashed_tokens + + self.ref_count = 0 + self.last_accessed = DEFAULT_LAST_ACCESSED_TIME + + self.computed = False + + def __repr__(self) -> str: + return (f'PhysicalTokenBlock(device={self.device}, ' + f'block_number={self.block_number}, ' + f'num_hashed_tokens={self.num_hashed_tokens}, ' + f'ref_count={self.ref_count}, ' + f'last_accessed={self.last_accessed}, ' + f'computed={self.computed})') + + +class BlockTable: + """Holds a list of blocks with caching of their associated block_ids + """ + + def __init__(self, blocks: Optional[List[PhysicalTokenBlock]] = None): + self._blocks: List[PhysicalTokenBlock] = [] + self._block_ids: List[int] = [] + + if blocks is not None: + for block in blocks: + self.append(block) + + def append(self, block: PhysicalTokenBlock): + self._blocks.append(block) + self._block_ids.append(block.block_number) + + def __len__(self) -> int: + return len(self._blocks) + + def __getitem__(self, key): + return self._blocks[key] + + if TYPE_CHECKING: + + def __iter__(self) -> Iterator[PhysicalTokenBlock]: + raise RuntimeError("Method should be automatically generated") + + def __setitem__(self, key, value): + if isinstance(key, slice): + blocks = value + self._blocks[key] = blocks + self._block_ids[key] = [b.block_number for b in blocks] + else: + block = value + self._blocks[key] = block + self._block_ids[key] = block.block_number + + def reset(self): + self._blocks = [] + self._block_ids = [] + + def copy(self) -> "BlockTable": + return BlockTable(self._blocks) + + def list(self) -> List[PhysicalTokenBlock]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/compilation/__pycache__/__init__.cpython-310.pyc b/vllm/compilation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..004a088 Binary files /dev/null and b/vllm/compilation/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/compilation/__pycache__/backends.cpython-310.pyc b/vllm/compilation/__pycache__/backends.cpython-310.pyc new file mode 100644 index 0000000..1b80e58 Binary files /dev/null and b/vllm/compilation/__pycache__/backends.cpython-310.pyc differ diff --git 
a/vllm/compilation/__pycache__/compile_context.cpython-310.pyc b/vllm/compilation/__pycache__/compile_context.cpython-310.pyc new file mode 100644 index 0000000..dc22671 Binary files /dev/null and b/vllm/compilation/__pycache__/compile_context.cpython-310.pyc differ diff --git a/vllm/compilation/__pycache__/decorators.cpython-310.pyc b/vllm/compilation/__pycache__/decorators.cpython-310.pyc new file mode 100644 index 0000000..8a478e6 Binary files /dev/null and b/vllm/compilation/__pycache__/decorators.cpython-310.pyc differ diff --git a/vllm/compilation/__pycache__/levels.cpython-310.pyc b/vllm/compilation/__pycache__/levels.cpython-310.pyc new file mode 100644 index 0000000..1a93d5b Binary files /dev/null and b/vllm/compilation/__pycache__/levels.cpython-310.pyc differ diff --git a/vllm/compilation/__pycache__/wrapper.cpython-310.pyc b/vllm/compilation/__pycache__/wrapper.cpython-310.pyc new file mode 100644 index 0000000..8eeb143 Binary files /dev/null and b/vllm/compilation/__pycache__/wrapper.cpython-310.pyc differ diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py new file mode 100644 index 0000000..4780358 --- /dev/null +++ b/vllm/compilation/backends.py @@ -0,0 +1,269 @@ +import copy +import operator +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.fx as fx + +from vllm.logger import init_logger + +from .compile_context import get_compile_context +from .levels import CompilationLevel + +logger = init_logger(__name__) + + +def fix_functionalization(graph: fx.Graph): + """ + Rewrite the graph module to replace the pattern involving + torch._higher_order_ops.auto_functionalize.auto_functionalized + with a direct call to the inplace custom op. + + # TODO: check if PyTorch nightly has fixed this issue + """ + + # debug code, if we want to see the graph before the transformation + # with open("before.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + nodes_to_remove = [] + + for node in graph.nodes: + # Identify the auto_functionalized node + if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa + if node.args[0] == torch.ops._C.rotary_embedding.default: + # manual replace for rotary_embedding + + # Now, collect the arguments + kwargs = node.kwargs + + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function(torch.ops._C.rotary_embedding.default, + kwargs=kwargs) + + # Remove the auto_functionalized node + # Since the node may have outputs, we need to handle its users + # Replace uses of the outputs (getitem nodes) with mm_node + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + for getitem_user in list(user.users): + if (getitem_user.op == 'call_function' + and getitem_user.target + == torch.ops.aten.slice_scatter.default): + # Replace the uses of slice_scatter node + # with mm_node + getitem_user.replace_all_uses_with(mm_node) + nodes_to_remove.append(getitem_user) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: + # manual replace for fused_add_rms_norm + # this is the most effective optimization for llama 
+ # failing to do this will result in many unnecessary copies + + kwargs = node.kwargs + + input = kwargs['input'] + residual = kwargs['residual'] + + # Create a new call to torch.ops._C.rotary_embedding.default + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + # Remove the getitem node + if user.args[1] == 1: + replace_node = input + elif user.args[1] == 2: + replace_node = residual + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.rms_norm.default: + # manual replace for rms_norm + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + weight = kwargs['weight'] + epsilon = kwargs['epsilon'] + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.rms_norm.default, + args=(out, input, weight, epsilon), + ) + + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + elif node.args[0] == torch.ops._C.silu_and_mul.default: + # manual replace for silu_and_mul + + kwargs = node.kwargs + + input = kwargs['input'] + out = kwargs['out'] + + # Create a new call to torch.ops._C.rotary_embedding.default + # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa + with graph.inserting_before(node): + # just insert the call to the custom op + # NOTE: don't run dead code elimination, + # otherwise this op will be removed + graph.call_function( + torch.ops._C.silu_and_mul.default, + args=(out, input), + ) + replace_node = out + + for user in list(node.users): + if user.op == 'call_function' and user.target == operator.getitem: # noqa + user.replace_all_uses_with(replace_node) + nodes_to_remove.append(user) + nodes_to_remove.append(node) + + # Remove the nodes all at once + for node in nodes_to_remove: + graph.erase_node(node) + + # debug code, if we want to see the graph after the transformation + # with open("after.py", "w") as f: + # print(graph.python_code(root_module="self", verbose=True).src, file=f) + + +def wrap_inductor(graph, example_inputs, additional_inductor_config): + from torch._inductor import config + current_config = config.shallow_copy_dict() + from torch._inductor.compile_fx import compile_fx + + if additional_inductor_config is not None: + current_config.update(additional_inductor_config) + if current_config['post_grad_custom_post_pass'] is not None: + logger.warning( + "post_grad_custom_post_pass is already set in the config. 
" + "Overwriting it with the fix_functionalization") + current_config['post_grad_custom_post_pass'] = fix_functionalization + return compile_fx(graph, example_inputs, config_patches=current_config) + + +def vllm_backend( + graph, + example_inputs, + additional_inductor_config: Optional[Dict] = None) -> Callable: + + context = get_compile_context() + context = copy.deepcopy(context) if context is not None else [] + sizes_to_specialize: List[int] = context + + # flags for all the seen shapes, whether we need to specialize + runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {} + + # if we need to specialize, the compiled graph for that shape + runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {} + + # this is the first compilation, we will compile a graph with + # dynamic shape, as the caller will mark first dimension as dynamic + logger.info("Compiling a graph for general shapes") + graph_for_symbolic_shape = wrap_inductor(graph, example_inputs, + additional_inductor_config) + + # TODO: Dynamo does not pass all dynamic shapes. + # Need to investigate why. It works now because all the dynamic + # shapes have the same value, and either of them can be used. + sym_shape_indices = [ + i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt) + ] + + first_run = True + + # this is the function we return to Dynamo to run finally + def compiled_graph_wrapper(*args): + + runtime_shapes: Tuple[int, + ...] = tuple(args[i] for i in sym_shape_indices) + + nonlocal first_run + nonlocal runtime_shapes_to_compile_flags + nonlocal runtime_shapes_to_compiled_graph + + if first_run: + # the first compilation is for profiling, we directly run it + first_run = False + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compile_flags: + # we haven't seen this shape before + # query if we need to specialize for this shape + # we only specialize for the first dimension. 
+ # TODO: investigate if any model needs to specialize + # beyond the first dimension + runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[ + 0] in sizes_to_specialize + + if not runtime_shapes_to_compile_flags[runtime_shapes]: + # we don't need to specialize for this shape + return graph_for_symbolic_shape(*args) + + if runtime_shapes not in runtime_shapes_to_compiled_graph: + # we need to specialize for this shape, and we haven't compiled + # compile the graph for this shape + logger.info("Compiling a graph for shapes %s", runtime_shapes) + runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor( + graph, args, additional_inductor_config) + + return runtime_shapes_to_compiled_graph[runtime_shapes](*args) + + return compiled_graph_wrapper + + +def select_default_backend(level: int) -> Union[str, Callable]: + if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]: + backend = "eager" + return backend + assert level in [ + CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE + ], f"Invalid level {level}" + + from vllm.compilation.backends import vllm_backend + from vllm.plugins import get_inductor_additional_configs + additional_configs = get_inductor_additional_configs() + + if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE: + if "max_autotune" in additional_configs and not additional_configs[ + "max_autotune"]: + logger.warning( + "max_autotune is disabled, but is overridden by level %s", + CompilationLevel.INDUCTOR_MAX_AUTOTUNE) + additional_configs['max_autotune'] = True + + from functools import partial + backend = partial(vllm_backend, + additional_inductor_config=additional_configs) + + return backend diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py new file mode 100644 index 0000000..29db3d4 --- /dev/null +++ b/vllm/compilation/compile_context.py @@ -0,0 +1,23 @@ +from contextlib import contextmanager +from typing import Any + +_compile_context: Any = None + + +def get_compile_context() -> Any: + """Get the current compile context.""" + return _compile_context + + +@contextmanager +def set_compile_context(context: Any): + """A context manager that stores the current compile context, + usually it is a list of sizes to specialize. + """ + global _compile_context + prev_context = _compile_context + _compile_context = context + try: + yield + finally: + _compile_context = prev_context diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py new file mode 100644 index 0000000..655c4c4 --- /dev/null +++ b/vllm/compilation/decorators.py @@ -0,0 +1,113 @@ +import inspect +from typing import Dict, List, Union + +import torch + +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.sequence import IntermediateTensors +from vllm.utils import supports_dynamo + + +def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]): + """ + A decorator to add support for compiling the forward method of a class. + + `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic + dimensions of the argument. The dynamic dimensions can be either a single + integer or a list of integers. + + Depending on the value of arguments: + + - if it is a single integer, the corresponding dimension of the argument + will be marked as dynamic. + - if it is `None`, ignored. 
+ - if it is `IntermediateTensors`, all the tensors in the intermediate + tensors will be marked as dynamic. + - otherwise, it will raise an error. + + NOTE: if an argument is `None`, it should always be passed as `None` during + the lifetime of the model, otherwise, it cannot be captured as a single + computation graph. + """ + + def cls_decorator_helper(cls: type): + # helper to pass `dynamic_arg_dims`` to `_support_torch_compile`` + # to avoid too much indentation for `_support_torch_compile`` + sig = inspect.signature(cls.forward) + for k in dynamic_arg_dims: + if k not in sig.parameters: + raise ValueError( + f"Argument {k} not found in the forward method of {cls}") + return _support_torch_compile(cls, dynamic_arg_dims) + + return cls_decorator_helper + + +def _support_torch_compile(cls: type, + dynamic_arg_dims: Dict[str, Union[int, List[int]]]): + """ + A decorator to add support for compiling the forward method of a class. + """ + + # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner + # will handle the compilation, so we don't need to do anything here. + if envs.VLLM_TORCH_COMPILE_LEVEL in [ + CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS + ] or not supports_dynamo(): + return cls + + # take care of method resolution order + # make sure super().__init__ is called on the base class + # other than TorchCompileWrapperWithCustomDispatcher + cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, ) + + old_init = cls.__init__ + + def __init__(self, *args, **kwargs): + old_init(self, *args, **kwargs) + TorchCompileWrapperWithCustomDispatcher.__init__(self) + + cls.__init__ = __init__ + + def __call__(self, *args, **kwargs): + # torch.compiler.is_compiling() means we are inside the compilation + # e.g. TPU has the compilation logic in model runner, so we don't + # need to compile the model inside. + if torch.compiler.is_compiling(): + return self.forward(*args, **kwargs) + + # the first compilation needs to have dynamic shapes marked + if len(self.compiled_codes) < 1: + sig = inspect.signature(self.__class__.forward) + bound_args = sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for k, dims in dynamic_arg_dims.items(): + arg = bound_args.arguments.get(k) + if arg is not None: + if isinstance(arg, torch.Tensor): + torch._dynamo.mark_dynamic(arg, dims) + elif isinstance(arg, IntermediateTensors): + for tensor in arg.tensors.values(): + torch._dynamo.mark_dynamic(tensor, dims) + else: + raise ValueError( + "Unsupported dynamic dimensions" + f" {dims} for argument {k} with type {type(arg)}.") + + # if we don't use custom dispatcher, we can directly call the + # compiled function and let torch.compile handle the dispatching, + # with the overhead of guard evaluation and recompilation. + if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher: + return self.compiled_callable(*args, **kwargs) + + # usually, capturing the model once is enough, and then we can + # dispatch to the compiled code directly, without going through + # the Dynamo guard mechanism. 
+ with self.dispatch_to_code(0): + model_output = self.forward(*args, **kwargs) + return model_output + + cls.__call__ = __call__ + return cls diff --git a/vllm/compilation/levels.py b/vllm/compilation/levels.py new file mode 100644 index 0000000..162bf5a --- /dev/null +++ b/vllm/compilation/levels.py @@ -0,0 +1,9 @@ +# constants for the levels of the compilation process + + +class CompilationLevel: + NO_COMPILATION = 0 + DYNAMO_AS_IS = 1 + DYNAMO_ONCE = 2 + INDUCTOR = 3 + INDUCTOR_MAX_AUTOTUNE = 4 diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py new file mode 100644 index 0000000..1594b64 --- /dev/null +++ b/vllm/compilation/wrapper.py @@ -0,0 +1,102 @@ +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, List, Optional + +import torch + +import vllm.envs as envs + +from .levels import CompilationLevel + + +class TorchCompileWrapperWithCustomDispatcher: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, compiled_callable: Optional[Callable] = None): + + if compiled_callable is None: + # default compilation settings + # compiling the forward method + + # choose the compile backend + + # if the user has set the backend, use it + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() + if backend is None: + from vllm.compilation.backends import select_default_backend + backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL) + + compiled_callable = torch.compile( + self.forward, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: List[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... 
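# --- Editor's illustrative sketch (not part of the patch) ---------------------
# How a subclass of this wrapper is expected to look, per the class docstring
# above: implement forward(), and in __call__() decide between running the
# torch.compile'd callable (early calls, so Dynamo can capture bytecode) and
# dispatching straight to the captured code object afterwards. The toy module
# and sizes below are hypothetical:
import torch


class _ToyCompiledModel(TorchCompileWrapperWithCustomDispatcher):

    def __init__(self):
        self.linear = torch.nn.Linear(16, 16)
        # Base-class __init__ wraps self.forward with torch.compile, using the
        # backend selected from VLLM_TORCH_COMPILE_LEVEL.
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if not self.compiled_codes or not self.use_custom_dispatcher:
            # Nothing captured yet (or custom dispatch disabled): let
            # torch.compile / Dynamo handle the call.
            return self.compiled_callable(x)
        # Already captured once: reuse the first compiled code object directly,
        # skipping Dynamo guard evaluation.
        with self.dispatch_to_code(0):
            return self.forward(x)
# -------------------------------------------------------------------------------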
+ + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while True: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config.py b/vllm/config.py new file mode 100644 index 0000000..3997bf1 --- /dev/null +++ b/vllm/config.py @@ -0,0 +1,1909 @@ +import enum +import json +from dataclasses import dataclass, field, fields +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, + Optional, Tuple, Type, Union) + +import torch +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.models import ModelRegistry +from vllm.platforms import current_platform +from vllm.tracing import is_otel_available, otel_import_error_traceback +from vllm.transformers_utils.config import (ConfigFormat, get_config, + get_hf_image_processor_config, + get_hf_text_config) +from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, + is_hip, is_neuron, is_openvino, is_xpu, + print_warning_once) + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + + from vllm.executor.executor_base import ExecutorBase + from vllm.model_executor.model_loader.loader import BaseModelLoader + from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import ( + BaseTokenizerGroup) + +logger = init_logger(__name__) + +_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, "slow" will always use the slow tokenizer, and + "mistral" will always use the tokenizer from `mistral_common`. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + dtype: Data type for model weights and activations. 
The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + If None, the user did not specify, so default to False. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. + limit_mm_per_prompt: Maximum number of data instances per modality + per prompt. Only applicable for multimodal models. + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. + config_format: The config format which shall be loaded. + Defaults to 'auto' which defaults to 'hf'. + mm_processor_kwargs: Arguments to be forwarded to the model's processor + for multi-modal data, e.g., image processor. 
+ """ + + def __init__(self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + override_neuron_config: Optional[Dict[str, Any]] = None, + config_format: ConfigFormat = ConfigFormat.AUTO, + mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.enforce_eager = enforce_eager + if max_context_len_to_capture is not None: + raise ValueError("`max_context_len_to_capture` is deprecated. " + "Use `max_seq_len_to_capture` instead.") + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.skip_tokenizer_init = skip_tokenizer_init + + self.hf_config = get_config(self.model, trust_remote_code, revision, + code_revision, rope_scaling, rope_theta, + config_format) + self.hf_text_config = get_hf_text_config(self.hf_config) + self.hf_image_processor_config = get_hf_image_processor_config( + self.model, revision) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.use_async_output_proc = use_async_output_proc + self.mm_processor_kwargs = mm_processor_kwargs + + # Set enforce_eager to False if the value is unset. + if self.enforce_eager is None: + self.enforce_eager = False + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. 
Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + spec_target_max_model_len=spec_target_max_model_len) + self.served_model_name = get_served_model_name(model, + served_model_name) + self.multimodal_config = self._init_multimodal_config( + limit_mm_per_prompt) + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + + self.is_attention_free = self._init_attention_free() + self.has_inner_state = self._init_has_inner_state() + + self.override_neuron_config = override_neuron_config if is_neuron( + ) else None + self._verify_embedding_mode() + self._verify_quantization() + self._verify_cuda_graph() + self._verify_bnb_config() + + def _init_multimodal_config( + self, limit_mm_per_prompt: Optional[Mapping[str, int]] + ) -> Optional["MultiModalConfig"]: + architectures = getattr(self.hf_config, "architectures", []) + if ModelRegistry.is_multimodal_model(architectures): + return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) + + if limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + + return None + + def _init_attention_free(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_attention_free_model(architectures) + + def _init_has_inner_state(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.model_has_inner_state(architectures) + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow", "mistral"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto', 'slow' or 'mistral'.") + self.tokenizer_mode = tokenizer_mode + + def _verify_embedding_mode(self) -> None: + architectures = getattr(self.hf_config, "architectures", []) + self.embedding_mode = ModelRegistry.is_embedding_model(architectures) + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = [ + "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", + "fbgemm_fp8", "w8a16" + ] + optimized_quantization_methods = [ + "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", + "awq_marlin", "fbgemm_fp8", "compressed_tensors", + "compressed-tensors", "experts_int8" + ] + tpu_supported_quantization = ["tpu_int8"] + neuron_supported_quantization = ["neuron_quant"] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. 
+ quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization not in rocm_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in ROCm.") + if current_platform.is_tpu( + ) and self.quantization not in tpu_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in TPU Backend.") + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", self.quantization) + if (self.quantization == "awq" and is_hip() + and not envs.VLLM_USE_TRITON_AWQ): + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True + if is_neuron( + ) and self.quantization not in neuron_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in Neuron Backend.") + + def _verify_cuda_graph(self) -> None: + if self.max_seq_len_to_capture is None: + self.max_seq_len_to_capture = self.max_model_len + self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, + self.max_model_len) + + def _verify_bnb_config(self) -> None: + """ + The current version of bitsandbytes (0.44.0) with 8-bit models does not + yet support CUDA graph. + """ + is_bitsandbytes = self.quantization == "bitsandbytes" + has_quantization_config = (getattr(self.hf_config, + "quantization_config", None) + is not None) + is_8bit = (self.hf_config.quantization_config.get( + "load_in_8bit", False) if has_quantization_config else False) + if all([ + is_bitsandbytes, + has_quantization_config, + is_8bit, + not self.enforce_eager, + ]): + logger.warning( + "CUDA graph is not supported on BitAndBytes 8bit yet, " + "fallback to the eager mode.") + self.enforce_eager = True + + def verify_async_output_proc(self, parallel_config, speculative_config, + device_config) -> None: + if not self.use_async_output_proc: + # Nothing to check + return + + if parallel_config.pipeline_parallel_size > 1: + logger.warning("Async output processing can not be enabled " + "with pipeline parallel") + self.use_async_output_proc = False + return + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if device_config.device_type not in ("cuda", "tpu", "xpu"): + logger.warning( + "Async output processing is only supported for CUDA, TPU, XPU. 
" + "Disabling it for other platforms.") + self.use_async_output_proc = False + return + + if envs.VLLM_USE_RAY_SPMD_WORKER: + logger.warning( + "Async output processing can not be enabled with ray spmd") + self.use_async_output_proc = False + return + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if device_config.device_type == "cuda" and self.enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. Since, enforce-eager is enabled, async output " + "processor cannot be used") + self.use_async_output_proc = not self.enforce_eager + return + + # Async postprocessor is not necessary with embedding mode + # since there is no token generation + if self.embedding_mode: + self.use_async_output_proc = False + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if speculative_config: + logger.warning("Async output processing is not supported with" + " speculative decoding currently.") + self.use_async_output_proc = False + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_num_attention_heads = getattr(self.hf_text_config, + "num_attention_heads", 0) + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size}).") + + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if pipeline_parallel_size > 1: + architectures = getattr(self.hf_config, "architectures", []) + if not ModelRegistry.is_pp_supported_model(architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") + + if self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. Disabling it.") + self.use_async_output_proc = False + + def get_hf_config_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + def get_head_size(self) -> int: + # TODO remove hard code + if hasattr(self.hf_text_config, "model_type" + ) and self.hf_text_config.model_type == 'deepseek_v2': + # FlashAttention supports only head_size 32, 64, 128, 256, + # we need to pad head_size 192 to 256 + return 256 + + if self.is_attention_free: + return 0 + + if hasattr(self.hf_text_config, "head_dim"): + return self.hf_text_config.head_dim + # FIXME(woosuk): This may not be true for all models. 
+ return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + if self.is_attention_free: + return 0 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, + total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_attention_heads(self, + parallel_config: "ParallelConfig") -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads // parallel_config.tensor_parallel_size + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + from vllm.distributed.utils import get_pp_indices + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) + pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size + pp_size = parallel_config.pipeline_parallel_size + start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size) + return end - start + + def get_num_attention_layers(self, + parallel_config: "ParallelConfig") -> int: + if self.is_attention_free: + return 0 + + num_layers = self.get_num_layers(parallel_config) + + # Transformers supports layers_block_type @property + layers = getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + return len([t for t in layers if t == "attention"]) + + def get_multimodal_config(self) -> "MultiModalConfig": + """ + Get the multimodal configuration of the model. + + Raises: + ValueError: If the model is not multimodal. 
+ """ + if self.multimodal_config is None: + raise ValueError("The model is not multimodal.") + + return self.multimodal_config + + @property + def is_encoder_decoder_model(self) -> bool: + """Extract the HF encoder/decoder model flag.""" + return getattr(self.hf_config, "is_encoder_decoder", False) or ( + (hasattr(self.hf_config, "text_config") and getattr( + self.hf_config.text_config, "is_encoder_decoder", False))) + + @property + def is_embedding_model(self) -> bool: + """Extract the embedding model flag.""" + return self.embedding_mode + + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the + profiled num_gpu_blocks if specified. Does nothing if None. + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: float, + cache_dtype: str, + is_attention_free: bool = False, + num_gpu_blocks_override: Optional[int] = None, + sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, + cpu_offload_gb: float = 0, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * GiB_bytes + self.num_gpu_blocks_override = num_gpu_blocks_override + self.cache_dtype = cache_dtype + self.is_attention_free = is_attention_free + self.sliding_window = sliding_window + self.enable_prefix_caching = enable_prefix_caching + self.cpu_offload_gb = cpu_offload_gb + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. " + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. 
+ num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / GiB_bytes:.2f} GiB out of the " + f"{total_cpu_memory / GiB_bytes:.2f} GiB total CPU memory " + "is allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. %s", msg) + + +@dataclass +class TokenizerPoolConfig: + """Configuration for the tokenizer pool. + + Args: + pool_size: Number of tokenizer workers in the pool. + pool_type: Type of the pool. + extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. + """ + pool_size: int + pool_type: Union[str, Type["BaseTokenizerGroup"]] + extra_config: dict + + def __post_init__(self): + if self.pool_type not in ("ray", ) and not isinstance( + self.pool_type, type): + raise ValueError(f"Unknown pool type: {self.pool_type}") + if not isinstance(self.extra_config, dict): + raise ValueError("extra_config must be a dictionary.") + + @classmethod + def create_config( + cls, tokenizer_pool_size: int, tokenizer_pool_type: str, + tokenizer_pool_extra_config: Optional[Union[str, dict]] + ) -> Optional["TokenizerPoolConfig"]: + """Create a TokenizerPoolConfig from the given parameters. + + If tokenizer_pool_size is 0, return None. + + Args: + tokenizer_pool_size: Number of tokenizer workers in the pool. + tokenizer_pool_type: Type of the pool. + tokenizer_pool_extra_config: Additional config for the pool. + The way the config will be used depends on the + pool type. This can be a JSON string (will be parsed). + """ + if tokenizer_pool_size: + if isinstance(tokenizer_pool_extra_config, str): + tokenizer_pool_extra_config_parsed = json.loads( + tokenizer_pool_extra_config) + else: + tokenizer_pool_extra_config_parsed = ( + tokenizer_pool_extra_config or {}) + tokenizer_pool_config = cls(tokenizer_pool_size, + tokenizer_pool_type, + tokenizer_pool_extra_config_parsed) + else: + tokenizer_pool_config = None + return tokenizer_pool_config + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + GGUF = "gguf" + BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. 
+ + """ + + load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + +class ParallelConfig: + """Configuration for the distributed execution. + + Args: + pipeline_parallel_size: Number of pipeline parallel groups. + tensor_parallel_size: Number of tensor parallel groups. + worker_use_ray: Deprecated, use distributed_executor_backend instead. + max_parallel_loading_workers: Maximum number of multiple batches + when load model sequentially. To avoid RAM OOM when using tensor + parallel and large models. + disable_custom_all_reduce: Disable the custom all-reduce kernel and + fall back to NCCL. + tokenizer_pool_config: Config for the tokenizer pool. + If None, will use synchronous tokenization. + ray_workers_use_nsight: Whether to profile Ray workers with nsight, see + https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + placement_group: ray distributed model workers placement group. + distributed_executor_backend: Backend to use for distributed model + workers, either "ray" or "mp" (multiprocessing). If either + pipeline_parallel_size or tensor_parallel_size is greater than 1, + will default to "ray" if Ray is installed or "mp" otherwise. 
+ """ + + def __init__( + self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: Optional[bool] = None, + max_parallel_loading_workers: Optional[int] = None, + disable_custom_all_reduce: bool = False, + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, + ray_workers_use_nsight: bool = False, + placement_group: Optional["PlacementGroup"] = None, + distributed_executor_backend: Optional[Union[ + str, Type["ExecutorBase"]]] = None, + ) -> None: + self.pipeline_parallel_size = pipeline_parallel_size + self.tensor_parallel_size = tensor_parallel_size + self.distributed_executor_backend = distributed_executor_backend + self.max_parallel_loading_workers = max_parallel_loading_workers + self.disable_custom_all_reduce = disable_custom_all_reduce + self.tokenizer_pool_config = tokenizer_pool_config + self.ray_workers_use_nsight = ray_workers_use_nsight + self.placement_group = placement_group + self.world_size = pipeline_parallel_size * self.tensor_parallel_size + + if worker_use_ray: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + elif not self.use_ray: + raise ValueError(f"worker-use-ray can't be used with " + f"distributed executor backend " + f"'{self.distributed_executor_backend}'.") + + if current_platform.is_tpu() and self.world_size > 1: + if self.distributed_executor_backend is None: + self.distributed_executor_backend = "ray" + if self.distributed_executor_backend != "ray": + raise ValueError( + "TPU backend only supports Ray for distributed inference.") + + if self.distributed_executor_backend is None and self.world_size > 1: + # We use multiprocessing by default if world_size fits on the + # current node and we aren't in a ray placement group. + + from vllm.executor import ray_utils + backend = "mp" + ray_found = ray_utils.ray_is_available() + if (current_platform.is_cuda() + and cuda_device_count_stateless() < self.world_size): + if not ray_found: + raise ValueError("Unable to load Ray which is " + "required for multi-node inference, " + "please install Ray with `pip install " + "ray`.") from ray_utils.ray_import_err + backend = "ray" + elif ray_found: + if self.placement_group: + backend = "ray" + else: + from ray import is_initialized as ray_is_initialized + if ray_is_initialized(): + from ray.util import get_current_placement_group + if get_current_placement_group(): + backend = "ray" + self.distributed_executor_backend = backend + logger.info("Defaulting to use %s for distributed inference", + backend) + + self._verify_args() + self.rank: int = 0 + + @property + def use_ray(self) -> bool: + return self.distributed_executor_backend == "ray" or ( + isinstance(self.distributed_executor_backend, type) + and self.distributed_executor_backend.uses_ray) + + def _verify_args(self) -> None: + # Lazy import to avoid circular import + from vllm.executor.executor_base import ExecutorBase + + if self.distributed_executor_backend not in ( + "ray", "mp", None) and not (isinstance( + self.distributed_executor_backend, type) and issubclass( + self.distributed_executor_backend, ExecutorBase)): + raise ValueError( + "Unrecognized distributed executor backend " + f"{self.distributed_executor_backend}. 
Supported " + "values are 'ray', 'mp' or custom ExecutorBase subclass.") + if self.use_ray: + from vllm.executor import ray_utils + ray_utils.assert_ray_available() + if is_hip(): + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported on AMD GPUs.") + if self.ray_workers_use_nsight and not self.use_ray: + raise ValueError("Unable to use nsight profiling unless workers " + "run with Ray.") + + +class SchedulerConfig: + """Scheduler configuration. + + Args: + max_num_batched_tokens: Maximum number of tokens to be processed in + a single iteration. + max_num_seqs: Maximum number of sequences to be processed in a single + iteration. + max_model_len: Maximum length of a sequence (including prompt + and generated text). + use_v2_block_manager: Whether to use the BlockSpaceManagerV2 or not. + num_lookahead_slots: The number of slots to allocate per sequence per + step, beyond the known token ids. This is used in speculative + decoding to store KV activations of tokens which may or may not be + accepted. + delay_factor: Apply a delay (of delay factor multiplied by previous + prompt latency) before scheduling next prompt. + enable_chunked_prefill: If True, prefill requests can be chunked based + on the remaining max_num_batched_tokens. + embedding_mode: Whether the running model is for embedding. + preemption_mode: Whether to perform preemption by swapping or + recomputation. If not specified, we determine the mode as follows: + We use recomputation by default since it incurs lower overhead than + swapping. However, when the sequence group has multiple sequences + (e.g., beam search), recomputation is not currently supported. In + such a case, we use swapping instead. + send_delta_data: Private API. If used, scheduler sends delta data to + workers instead of an entire data. It should be enabled only + when SPMD worker architecture is enabled. I.e., + VLLM_USE_RAY_SPMD_WORKER=1 + policy: The scheduling policy to use. "fcfs" (default) or "priority". + """ + + def __init__(self, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + use_v2_block_manager: bool = True, + num_lookahead_slots: int = 0, + delay_factor: float = 0.0, + enable_chunked_prefill: bool = False, + embedding_mode: bool = False, + is_multimodal_model: bool = False, + preemption_mode: Optional[str] = None, + num_scheduler_steps: int = 1, + multi_step_stream_outputs: bool = False, + send_delta_data: bool = False, + policy: str = "fcfs") -> None: + if max_num_batched_tokens is None: + if enable_chunked_prefill: + if num_scheduler_steps > 1: + # Multi-step Chunked-Prefill doesn't allow prompt-chunking + # for now. Have max_num_batched_tokens set to max_model_len + # so we don't reject sequences on account of a short + # max_num_batched_tokens. + max_num_batched_tokens = max(max_model_len, 2048) + else: + # It is the values that have the best balance between ITL + # and TTFT on A100. Note it is not optimized for throughput. + max_num_batched_tokens = 512 + else: + # If max_model_len is too short, use 2048 as the default value + # for higher throughput. 
+ max_num_batched_tokens = max(max_model_len, 2048) + + if embedding_mode: + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + self.max_num_batched_tokens = max_num_batched_tokens + + if enable_chunked_prefill: + logger.info( + "Chunked prefill is enabled with max_num_batched_tokens=%d.", + self.max_num_batched_tokens) + + self.max_num_seqs = max_num_seqs + self.max_model_len = max_model_len + self.use_v2_block_manager = use_v2_block_manager + self.num_lookahead_slots = num_lookahead_slots + self.delay_factor = delay_factor + self.chunked_prefill_enabled = enable_chunked_prefill + self.embedding_mode = embedding_mode + self.preemption_mode = preemption_mode + self.num_scheduler_steps = num_scheduler_steps + self.multi_step_stream_outputs = multi_step_stream_outputs + self.send_delta_data = send_delta_data + self.policy = policy + self._verify_args() + + def _verify_args(self) -> None: + if (self.max_num_batched_tokens < self.max_model_len + and not self.chunked_prefill_enabled): + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. Please increase max_num_batched_tokens or " + "decrease max_model_len.") + + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + if self.num_lookahead_slots < 0: + raise ValueError( + "num_lookahead_slots " + f"({self.num_lookahead_slots}) must be greater than or " + "equal to 0.") + + if self.num_scheduler_steps < 1: + raise ValueError( + "num_scheduler_steps " + f"({self.num_scheduler_steps}) must be greater than or " + "equal to 1.") + + if (not self.use_v2_block_manager \ + and not envs.VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1): + raise ValueError( + "The use of BlockSpaceManagerV1 is deprecated and will " + "be removed in a future release. Please switch to " + "BlockSpaceManagerV2 by setting --use-v2-block-manager to " + "True. If you wish to suppress this error temporarily, " + "you can set the environment variable " + "`VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1. 
If your use " + "case is not supported in BlockSpaceManagerV2, please " + "file an issue with detailed information.") + + @property + def is_multi_step(self) -> bool: + return self.num_scheduler_steps > 1 + + +class DeviceConfig: + device: Optional[torch.device] + + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + if current_platform.is_cuda_alike(): + self.device_type = "cuda" + elif is_neuron(): + self.device_type = "neuron" + elif is_openvino(): + self.device_type = "openvino" + elif current_platform.is_tpu(): + self.device_type = "tpu" + elif current_platform.is_cpu(): + self.device_type = "cpu" + elif is_xpu(): + self.device_type = "xpu" + else: + raise RuntimeError("Failed to infer device type") + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron", "openvino"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +class SpeculativeConfig: + """Configuration for speculative decoding. + + The configuration is currently specialized to draft-model speculative + decoding with top-1 proposals. + """ + + @staticmethod + def maybe_create_spec_config( + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + target_dtype: str, + speculative_model: Optional[str], + speculative_model_quantization: Optional[str], + speculative_draft_tensor_parallel_size: Optional[int], + num_speculative_tokens: Optional[int], + speculative_disable_mqa_scorer: Optional[bool], + speculative_max_model_len: Optional[int], + enable_chunked_prefill: bool, + use_v2_block_manager: bool, + disable_log_stats: bool, + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: Optional[float], + typical_acceptance_sampler_posterior_alpha: Optional[float], + disable_logprobs: Optional[bool], + ) -> Optional["SpeculativeConfig"]: + """Create a SpeculativeConfig if possible, else return None. + + This function attempts to create a SpeculativeConfig object based on the + provided parameters. If the necessary conditions are met, it returns an + instance of SpeculativeConfig. Otherwise, it returns None. + + Args: + target_model_config (ModelConfig): The configuration of the target + model. + target_parallel_config (ParallelConfig): The parallel configuration + for the target model. + target_dtype (str): The data type used for the target model. + speculative_model (Optional[str]): The name of the speculative + model, if provided. + speculative_model_quantization (Optional[str]): Quantization method + that was used to quantize the speculative model weights. If + None, we assume the model weights are not quantized. + speculative_draft_tensor_parallel_size (Optional[int]): The degree + of the tensor parallelism for the draft model. + num_speculative_tokens (Optional[int]): The number of speculative + tokens, if provided. Will default to the number in the draft + model config if present, otherwise is required. + speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA + scorer for the speculative model and fall back to batch + expansion for scoring. + speculative_max_model_len (Optional[int]): The maximum model len of + the speculative model. 
Used when testing the ability to skip + speculation for some sequences. + enable_chunked_prefill (bool): Whether vLLM is configured to use + chunked prefill or not. Used for raising an error since its not + yet compatible with spec decode. + use_v2_block_manager (bool): Whether vLLM is configured to use the + v2 block manager or not. Used for raising an error since the v2 + block manager is required with spec decode. + speculative_disable_by_batch_size (Optional[int]): Disable + speculative decoding for new incoming requests when the number + of enqueue requests is larger than this value, if provided. + ngram_prompt_lookup_max (Optional[int]): Max size of ngram token + window, if provided. + ngram_prompt_lookup_min (Optional[int]): Min size of ngram token + window, if provided. + draft_token_acceptance_method (str): The method to use for + accepting draft tokens. This can take two possible + values 'rejection_sampler' and 'typical_acceptance_sampler' + for RejectionSampler and TypicalAcceptanceSampler + respectively. + typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in the target model for it to be + accepted. This threshold is used only when we use the + TypicalAcceptanceSampler for token acceptance. + typical_acceptance_sampler_posterior_alpha (Optional[float]): + A scaling factor for the entropy-based threshold in the + TypicalAcceptanceSampler. + disable_logprobs (Optional[bool]): If set to True, token log + probabilities are not returned during speculative decoding. + If set to False, token log probabilities are returned + according to the log probability settings in SamplingParams. + If not specified, it defaults to True. + + Returns: + Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if + the necessary conditions are met, else None. + """ + + if speculative_model is None: + if num_speculative_tokens is not None: + raise ValueError("num_speculative_tokens was provided without " + "speculative_model.") + return None + + if (speculative_disable_by_batch_size is not None + and speculative_disable_by_batch_size < 2): + raise ValueError("Expect the batch size threshold of disabling " + "speculative decoding is > 1, but got " + f"{speculative_disable_by_batch_size=}") + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if enable_chunked_prefill: + raise ValueError( + "Speculative decoding and chunked prefill are " + f"currently mutually exclusive ({enable_chunked_prefill=}).") + + if not use_v2_block_manager: + raise ValueError( + "Speculative decoding requires usage of the V2 " + "block manager. Enable it with --use-v2-block-manager.") + + # TODO: The user should be able to specify revision/max model len + # for the draft model. It is not currently supported. 
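+        # NOTE: the special value "[ngram]" selects prompt-lookup decoding
+        # below; in that case the draft reuses the target model config and no
+        # separate draft weights are loaded.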
+ draft_revision = None + draft_code_revision = None + draft_quantization = speculative_model_quantization + + if speculative_model == "[ngram]": + if ngram_prompt_lookup_min is None: + ngram_prompt_lookup_min = 1 + if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1: + raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0") + if ngram_prompt_lookup_min < 1: + raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0") + if ngram_prompt_lookup_min > ngram_prompt_lookup_max: + raise ValueError(f"{ngram_prompt_lookup_min=} cannot be " + f"larger than {ngram_prompt_lookup_max=}") + + # TODO: current we still need extract vocab_size from target model + # config, in future, we may try refactor it out, and set + # draft related config as None here. + draft_model_config = target_model_config + draft_parallel_config = target_parallel_config + else: + ngram_prompt_lookup_max = 0 + ngram_prompt_lookup_min = 0 + draft_model_config = ModelConfig( + model=speculative_model, + tokenizer=target_model_config.tokenizer, + tokenizer_mode=target_model_config.tokenizer_mode, + trust_remote_code=target_model_config.trust_remote_code, + dtype=target_model_config.dtype, + seed=target_model_config.seed, + revision=draft_revision, + code_revision=draft_code_revision, + tokenizer_revision=target_model_config.tokenizer_revision, + max_model_len=None, + spec_target_max_model_len=target_model_config.max_model_len, + quantization=draft_quantization, + enforce_eager=target_model_config.enforce_eager, + max_seq_len_to_capture=target_model_config. + max_seq_len_to_capture, + max_logprobs=target_model_config.max_logprobs, + ) + + draft_hf_config = draft_model_config.hf_config + + if (num_speculative_tokens is not None + and hasattr(draft_hf_config, "num_lookahead_tokens")): + draft_hf_config.num_lookahead_tokens = num_speculative_tokens + + n_predict = getattr(draft_hf_config, "n_predict", None) + if n_predict is not None: + if num_speculative_tokens is None: + # Default to max value defined in draft model config. + num_speculative_tokens = n_predict + elif num_speculative_tokens > n_predict: + # Verify provided value doesn't exceed the maximum + # supported by the draft model. 
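+                # e.g. an MLP-speculator draft configured with n_predict=3
+                # rejects a user-provided num_speculative_tokens of 5.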
+ raise ValueError( + "This speculative model supports a maximum of " + f"num_speculative_tokens={n_predict}, but " + f"{num_speculative_tokens=} was provided.") + + draft_model_config.max_model_len = ( + SpeculativeConfig._maybe_override_draft_max_model_len( + speculative_max_model_len, + draft_model_config.max_model_len, + target_model_config.max_model_len, + )) + + draft_parallel_config = ( + SpeculativeConfig.create_draft_parallel_config( + target_parallel_config, + speculative_draft_tensor_parallel_size, draft_hf_config)) + + if num_speculative_tokens is None: + raise ValueError( + "num_speculative_tokens must be provided with " + "speculative_model unless the draft model config contains an " + "n_predict parameter.") + + if typical_acceptance_sampler_posterior_threshold is None: + typical_acceptance_sampler_posterior_threshold = 0.09 + if typical_acceptance_sampler_posterior_alpha is None: + typical_acceptance_sampler_posterior_alpha = 0.3 + if disable_logprobs is None: + disable_logprobs = True + + return SpeculativeConfig( + draft_model_config, + draft_parallel_config, + num_speculative_tokens, + speculative_disable_mqa_scorer, + speculative_disable_by_batch_size, + ngram_prompt_lookup_max, + ngram_prompt_lookup_min, + draft_token_acceptance_method=draft_token_acceptance_method, + typical_acceptance_sampler_posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=\ + typical_acceptance_sampler_posterior_alpha, + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, + ) + + @staticmethod + def _maybe_override_draft_max_model_len( + speculative_max_model_len: Optional[int], + draft_max_model_len: int, + target_max_model_len: int, + ) -> int: + """Determine the max sequence len for the draft model. This is usually + the draft_max_model_len, but may be the target_max_model_len if it is + less than the draft_max_model_len, or may be speculative_max_model_len + if it is specified. + + This is necessary so that sequences do not exceed the capacity of the + draft model or the target model. + + speculative_max_model_len is mainly used for testing that sequences can + skip speculation. + """ + + if speculative_max_model_len is not None: + + if speculative_max_model_len > draft_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {draft_max_model_len=}") + + if speculative_max_model_len > target_max_model_len: + raise ValueError(f"{speculative_max_model_len=} cannot be " + f"larger than {target_max_model_len=}") + + return speculative_max_model_len + + return min( + draft_max_model_len, + target_max_model_len, + ) + + @staticmethod + def create_draft_parallel_config( + target_parallel_config: ParallelConfig, + speculative_draft_tensor_parallel_size: Optional[int], + draft_hf_config: PretrainedConfig, + ) -> ParallelConfig: + """Create a parallel config for use by the draft worker. + + This is mostly a copy of the target parallel config, except the tp_size. 
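+
+        For example, with a target tensor_parallel_size of 4 and
+        speculative_draft_tensor_parallel_size left as None, a regular draft
+        model also runs with tp_size 4, while an "mlp_speculator" draft is
+        forced to tp_size 1 (with a warning when the target uses tp > 1).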
+ """ + if speculative_draft_tensor_parallel_size is None: + if draft_hf_config.model_type == "mlp_speculator": + speculative_draft_tensor_parallel_size = 1 + if target_parallel_config.tensor_parallel_size > 1: + logger.warning( + "MLPSpeculator cannot currently be run with tp>1; " + "setting speculative_draft_tensor_parallel_size=1") + else: + speculative_draft_tensor_parallel_size = \ + target_parallel_config.tensor_parallel_size + elif speculative_draft_tensor_parallel_size != 1: + # TODO(wooyeon): allow tp values larger than 1 + raise ValueError( + f"{speculative_draft_tensor_parallel_size=} cannot be " + f"other value than 1") + + draft_parallel_config = ParallelConfig( + pipeline_parallel_size=target_parallel_config. + pipeline_parallel_size, + tensor_parallel_size=speculative_draft_tensor_parallel_size, + distributed_executor_backend=target_parallel_config. + distributed_executor_backend, + max_parallel_loading_workers=target_parallel_config. + max_parallel_loading_workers, + disable_custom_all_reduce=target_parallel_config. + disable_custom_all_reduce, + tokenizer_pool_config=target_parallel_config.tokenizer_pool_config, + ray_workers_use_nsight=target_parallel_config. + ray_workers_use_nsight, + placement_group=target_parallel_config.placement_group, + ) + + return draft_parallel_config + + def __init__( + self, + draft_model_config: ModelConfig, + draft_parallel_config: ParallelConfig, + num_speculative_tokens: int, + speculative_disable_mqa_scorer: Optional[bool], + speculative_disable_by_batch_size: Optional[int], + ngram_prompt_lookup_max: Optional[int], + ngram_prompt_lookup_min: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, + disable_logprobs: bool, + disable_log_stats: bool, + ): + """Create a SpeculativeConfig object. + + Args: + draft_model_config: ModelConfig for the draft model. + draft_parallel_config: ParallelConfig for the draft model. + num_speculative_tokens: The number of tokens to sample from the + draft model before scoring with the target model. + speculative_disable_by_batch_size: Disable speculative + decoding for new incoming requests when the number of + enqueue requests is larger than this value. + ngram_prompt_lookup_max: Max size of ngram token window. + ngram_prompt_lookup_min: Min size of ngram token window. + draft_token_acceptance_method (str): The method to use for + accepting draft tokens. This can take two possible + values 'rejection_sampler' and 'typical_acceptance_sampler' + for RejectionSampler and TypicalAcceptanceSampler + respectively. + typical_acceptance_sampler_posterior_threshold (Optional[float]): + A threshold value that sets a lower bound on the posterior + probability of a token in the target model for it to be + accepted. This threshold is used only when we use the + TypicalAcceptanceSampler for token acceptance. + typical_acceptance_sampler_posterior_alpha (Optional[float]): + A scaling factor for the entropy-based threshold in the + TypicalAcceptanceSampler. + disable_logprobs: If set to True, token log probabilities will not + be returned even if requested by sampling parameters. This + reduces latency by skipping logprob calculation in proposal + sampling, target sampling, and after accepted tokens are + determined. If set to False, log probabilities will be + returned. + disable_log_stats: Whether to disable periodic printing of stage + times in speculative decoding. 
+ """ + self.draft_model_config = draft_model_config + self.draft_parallel_config = draft_parallel_config + self.num_speculative_tokens = num_speculative_tokens + self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer + self.speculative_disable_by_batch_size = \ + speculative_disable_by_batch_size + self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 + self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0 + self.draft_token_acceptance_method = draft_token_acceptance_method + self.typical_acceptance_sampler_posterior_threshold = \ + typical_acceptance_sampler_posterior_threshold + self.typical_acceptance_sampler_posterior_alpha = \ + typical_acceptance_sampler_posterior_alpha + self.disable_logprobs = disable_logprobs + self.disable_log_stats = disable_log_stats + + self._verify_args() + + def _verify_args(self) -> None: + if self.num_speculative_tokens <= 0: + raise ValueError("Expected num_speculative_tokens to be greater " + f"than zero ({self.num_speculative_tokens}).") + + if self.draft_model_config: + self.draft_model_config.verify_with_parallel_config( + self.draft_parallel_config) + # Validate and set draft token acceptance related settings. + + if (self.draft_token_acceptance_method is None): + raise ValueError("draft_token_acceptance_method is not set. " + "Expected values are rejection_sampler or " + "typical_acceptance_sampler.") + + if (self.draft_token_acceptance_method != 'rejection_sampler' + and self.draft_token_acceptance_method != + 'typical_acceptance_sampler'): + raise ValueError( + "Expected draft_token_acceptance_method to be either " + "rejection_sampler or typical_acceptance_sampler. Instead it " + f"is {self.draft_token_acceptance_method}") + + if (self.typical_acceptance_sampler_posterior_threshold < 0 + or self.typical_acceptance_sampler_posterior_alpha < 0): + raise ValueError( + "Expected typical_acceptance_sampler_posterior_threshold " + "and typical_acceptance_sampler_posterior_alpha to be > 0. " + "Instead found " + f"typical_acceptance_sampler_posterior_threshold = " + f"{self.typical_acceptance_sampler_posterior_threshold} and " + f"typical_acceptance_sampler_posterior_alpha = " + f"{self.typical_acceptance_sampler_posterior_alpha}") + + @property + def num_lookahead_slots(self) -> int: + """The number of additional slots the scheduler should allocate per + step, in addition to the slots allocated for each known token. + + This is equal to the number of speculative tokens, as each speculative + token must be scored. + """ + return self.num_speculative_tokens + + def __repr__(self) -> str: + if self.ngram_prompt_lookup_max > 0: + draft_model = "[ngram]" + else: + draft_model = self.draft_model_config.model + num_spec_tokens = self.num_speculative_tokens + return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" + + +@dataclass +class LoRAConfig: + max_lora_rank: int + max_loras: int + fully_sharded_loras: bool = False + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + # This is a constant. + lora_vocab_padding_size: ClassVar[int] = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None + + def __post_init__(self): + # Setting the maximum rank to 256 should be able to satisfy the vast + # majority of applications. 
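+        # e.g. LoRAConfig(max_lora_rank=16, max_loras=4) passes these checks,
+        # while max_lora_rank=48 or lora_extra_vocab_size=100 raises ValueError.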
+ possible_max_ranks = (8, 16, 32, 64, 128, 256) + possible_lora_extra_vocab_size = (0, 256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + if model_config.quantization and model_config.quantization not in [ + "awq", "gptq" + ]: + # TODO support marlin + logger.warning("%s quantization is not tested with LoRA yet.", + model_config.quantization) + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if scheduler_config.chunked_prefill_enabled: + raise ValueError("LoRA is not supported with chunked prefill yet.") + + +@dataclass +class PromptAdapterConfig: + max_prompt_adapters: int + max_prompt_adapter_token: int + max_cpu_prompt_adapters: Optional[int] = None + prompt_adapter_dtype: Optional[torch.dtype] = None + + def __post_init__(self): + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_prompt_adapter_token == 0: + raise ValueError("max_prompt_adapter_token must be set.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype in (None, "auto"): + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + +@dataclass +class MultiModalConfig: + """Controls the behavior of multimodal models.""" + + limit_per_prompt: Mapping[str, int] = field(default_factory=dict) + """ + The maximum number of multi-modal input instances allowed per prompt + for each :class:`~vllm.multimodal.MultiModalPlugin`. + """ + + # TODO: Add configs to init vision tower or not. + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +_ROCM_NOT_SUPPORTED_DTYPE: List[str] = [] # + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + if config.model_type == "gemma2": + logger.info( + "For Gemma 2, we downcast float32 to bfloat16 instead " + "of float16 by default. 
Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], + spec_target_max_model_len: Optional[int] = None, +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys. + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + if spec_target_max_model_len is not None: + # If this is a speculative draft model, we use the max model len + # from the target model. + return spec_target_max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. 
Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. Please raise an issue so we can " + "investigate.") + + if rope_type == "mrope": + scaling_factor = 1 + else: + if rope_type == "default": + rope_type = "mrope" + scaling_factor = 1 + else: + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + return int(max_model_len) + + +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +@dataclass +class DecodingConfig: + """Dataclass which contains the decoding strategy of the engine""" + + # Which guided decoding algo to use. 
'outlines' / 'lm-format-enforcer' + guided_decoding_backend: str = 'outlines' + + def __post_init__(self): + valid_guided_backends = ['outlines', 'lm-format-enforcer'] + backend = self.guided_decoding_backend + if backend not in valid_guided_backends: + raise ValueError(f"Invalid guided_decoding_backend '{backend}," + f"must be one of {valid_guided_backends}") + + +@dataclass +class ObservabilityConfig: + """Configuration for observability.""" + otlp_traces_endpoint: Optional[str] = None + + # Collecting detailed timing information for each request can be expensive. + + # If set, collects the model forward time for the request. + collect_model_forward_time: bool = False + + # If set, collects the model execute time for the request. + collect_model_execute_time: bool = False + + def __post_init__(self): + if not is_otel_available() and self.otlp_traces_endpoint is not None: + raise ValueError( + "OpenTelemetry is not available. Unable to configure " + "'otlp_traces_endpoint'. Ensure OpenTelemetry packages are " + f"installed. Original error:\n{otel_import_error_traceback}") + + if ((self.collect_model_forward_time + or self.collect_model_execute_time) + and self.otlp_traces_endpoint is None): + raise ValueError( + "collect_model_forward_time or collect_model_execute_time " + "requires --otlp-traces-endpoint to be set.") + + +@dataclass(frozen=True) +class EngineConfig: + """Dataclass which contains all engine-related configuration. This + simplifies passing around the distinct configurations in the codebase. + """ + + model_config: ModelConfig + cache_config: CacheConfig + parallel_config: ParallelConfig + scheduler_config: SchedulerConfig + device_config: DeviceConfig + load_config: LoadConfig + lora_config: Optional[LoRAConfig] + speculative_config: Optional[SpeculativeConfig] + decoding_config: Optional[DecodingConfig] + observability_config: Optional[ObservabilityConfig] + prompt_adapter_config: Optional[PromptAdapterConfig] + + def __post_init__(self): + """Verify configs are valid & consistent with each other. + """ + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. 
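+
+        A caller can expand the result into keyword arguments, e.g.
+        ``build_engine(**engine_config.to_dict())`` for a hypothetical
+        constructor that takes one keyword argument per config.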
+ """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/connections.py b/vllm/connections.py new file mode 100644 index 0000000..e785a0b --- /dev/null +++ b/vllm/connections.py @@ -0,0 +1,167 @@ +from pathlib import Path +from typing import Mapping, MutableMapping, Optional +from urllib.parse import urlparse + +import aiohttp +import requests + +from vllm.version import __version__ as VLLM_VERSION + + +class HTTPConnection: + """Helper class to send HTTP requests.""" + + def __init__(self, *, reuse_client: bool = True) -> None: + super().__init__() + + self.reuse_client = reuse_client + + self._sync_client: Optional[requests.Session] = None + self._async_client: Optional[aiohttp.ClientSession] = None + + def get_sync_client(self) -> requests.Session: + if self._sync_client is None or not self.reuse_client: + self._sync_client = requests.Session() + + return self._sync_client + + # NOTE: We intentionally use an async function even though it is not + # required, so that the client is only accessible inside async event loop + async def get_async_client(self) -> aiohttp.ClientSession: + if self._async_client is None or not self.reuse_client: + self._async_client = aiohttp.ClientSession() + + return self._async_client + + def _validate_http_url(self, url: str): + parsed_url = urlparse(url) + + if parsed_url.scheme not in ("http", "https"): + raise ValueError("Invalid HTTP URL: A valid HTTP URL " + "must have scheme 'http' or 'https'.") + + def _headers(self, **extras: str) -> MutableMapping[str, str]: + return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} + + def get_response( + self, + url: str, + *, + stream: bool = False, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = self.get_sync_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + stream=stream, + timeout=timeout) + + async def get_async_response( + self, + url: str, + *, + timeout: Optional[float] = None, + extra_headers: Optional[Mapping[str, str]] = None, + ): + self._validate_http_url(url) + + client = await self.get_async_client() + extra_headers = extra_headers or {} + + return client.get(url, + headers=self._headers(**extra_headers), + timeout=timeout) + + def get_bytes(self, url: str, *, timeout: Optional[float] = None) -> bytes: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.content + + async def async_get_bytes( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> bytes: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.read() + + def get_text(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.text + + async def async_get_text( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.text() + + def get_json(self, url: str, *, timeout: Optional[float] = None) -> str: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + return r.json() + + async def async_get_json( + self, + url: str, + *, + timeout: Optional[float] = None, + ) -> str: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + return await r.json() + + 
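+    # Illustrative usage sketch (assumes network access and a reachable URL):
+    #   conn = HTTPConnection()
+    #   payload = conn.get_json("https://example.com/manifest.json", timeout=5)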
def download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + with self.get_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + for chunk in r.iter_content(chunk_size): + f.write(chunk) + + return save_path + + async def async_download_file( + self, + url: str, + save_path: Path, + *, + timeout: Optional[float] = None, + chunk_size: int = 128, + ) -> Path: + async with await self.get_async_response(url, timeout=timeout) as r: + r.raise_for_status() + + with save_path.open("wb") as f: + async for chunk in r.content.iter_chunked(chunk_size): + f.write(chunk) + + return save_path + + +global_http_connection = HTTPConnection() +"""The global :class:`HTTPConnection` instance used by vLLM.""" diff --git a/vllm/core/__init__.py b/vllm/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/__pycache__/__init__.cpython-310.pyc b/vllm/core/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..9f6eb9f Binary files /dev/null and b/vllm/core/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/block_manager_v1.cpython-310.pyc b/vllm/core/__pycache__/block_manager_v1.cpython-310.pyc new file mode 100644 index 0000000..8c402a8 Binary files /dev/null and b/vllm/core/__pycache__/block_manager_v1.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/block_manager_v2.cpython-310.pyc b/vllm/core/__pycache__/block_manager_v2.cpython-310.pyc new file mode 100644 index 0000000..dd98dec Binary files /dev/null and b/vllm/core/__pycache__/block_manager_v2.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/evictor_v1.cpython-310.pyc b/vllm/core/__pycache__/evictor_v1.cpython-310.pyc new file mode 100644 index 0000000..c3ea008 Binary files /dev/null and b/vllm/core/__pycache__/evictor_v1.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/evictor_v2.cpython-310.pyc b/vllm/core/__pycache__/evictor_v2.cpython-310.pyc new file mode 100644 index 0000000..6f0fe90 Binary files /dev/null and b/vllm/core/__pycache__/evictor_v2.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/interfaces.cpython-310.pyc b/vllm/core/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000..444f4a8 Binary files /dev/null and b/vllm/core/__pycache__/interfaces.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/placeholder_block_space_manager.cpython-310.pyc b/vllm/core/__pycache__/placeholder_block_space_manager.cpython-310.pyc new file mode 100644 index 0000000..34b1c47 Binary files /dev/null and b/vllm/core/__pycache__/placeholder_block_space_manager.cpython-310.pyc differ diff --git a/vllm/core/__pycache__/scheduler.cpython-310.pyc b/vllm/core/__pycache__/scheduler.cpython-310.pyc new file mode 100644 index 0000000..ac0dc40 Binary files /dev/null and b/vllm/core/__pycache__/scheduler.cpython-310.pyc differ diff --git a/vllm/core/block/__init__.py b/vllm/core/block/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/block/__pycache__/__init__.cpython-310.pyc b/vllm/core/block/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..c456286 Binary files /dev/null and b/vllm/core/block/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/block_table.cpython-310.pyc b/vllm/core/block/__pycache__/block_table.cpython-310.pyc new file mode 100644 index 0000000..5600c63 Binary files /dev/null and 
b/vllm/core/block/__pycache__/block_table.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/common.cpython-310.pyc b/vllm/core/block/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000..5ac6490 Binary files /dev/null and b/vllm/core/block/__pycache__/common.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/cpu_gpu_block_allocator.cpython-310.pyc b/vllm/core/block/__pycache__/cpu_gpu_block_allocator.cpython-310.pyc new file mode 100644 index 0000000..b928f93 Binary files /dev/null and b/vllm/core/block/__pycache__/cpu_gpu_block_allocator.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/interfaces.cpython-310.pyc b/vllm/core/block/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000..5b72ccf Binary files /dev/null and b/vllm/core/block/__pycache__/interfaces.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/naive_block.cpython-310.pyc b/vllm/core/block/__pycache__/naive_block.cpython-310.pyc new file mode 100644 index 0000000..859d45a Binary files /dev/null and b/vllm/core/block/__pycache__/naive_block.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/prefix_caching_block.cpython-310.pyc b/vllm/core/block/__pycache__/prefix_caching_block.cpython-310.pyc new file mode 100644 index 0000000..ccc5f07 Binary files /dev/null and b/vllm/core/block/__pycache__/prefix_caching_block.cpython-310.pyc differ diff --git a/vllm/core/block/__pycache__/utils.cpython-310.pyc b/vllm/core/block/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..c5f9386 Binary files /dev/null and b/vllm/core/block/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/core/block/block_table.py b/vllm/core/block/block_table.py new file mode 100644 index 0000000..d10cb29 --- /dev/null +++ b/vllm/core/block/block_table.py @@ -0,0 +1,374 @@ +import math +from typing import List, Optional + +from vllm.core.block.common import BlockList +from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator +from vllm.utils import Device, cdiv, chunk_list + + +class BlockTable: + """A class to manage blocks for a specific sequence. + + The BlockTable maps a sequence of tokens to a list of blocks, where each + block represents a contiguous memory allocation for a portion of the + sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is + responsible for allocating and freeing memory for the blocks. + + Args: + block_size (int): The maximum number of tokens that can be stored in a + single block. + block_allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]], optional): An optional list of existing + blocks to initialize the BlockTable with. If not provided, an empty + BlockTable is created. + max_block_sliding_window (Optional[int], optional): The number of + blocks to keep around for each sequance. If None, all blocks + are kept (eg., when sliding window is not used). + It should at least fit the sliding window size of the model. + + Attributes: + _block_size (int): The maximum number of tokens that can be stored in a + single block. + _allocator (DeviceAwareBlockAllocator): The block allocator used to + manage memory for the blocks. + _blocks (Optional[List[Block]]): The list of blocks managed by this + BlockTable. + _num_full_slots (int): The number of tokens currently stored in the + blocks. 
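+
+    For example, with block_size=16, allocating a 40-token prompt consumes
+    three blocks: two full blocks plus a third holding the remaining 8 tokens.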
+ """ + + def __init__( + self, + block_size: int, + block_allocator: DeviceAwareBlockAllocator, + _blocks: Optional[List[Block]] = None, + max_block_sliding_window: Optional[int] = None, + ): + self._block_size = block_size + self._allocator = block_allocator + if _blocks is None: + _blocks = [] + self._blocks: BlockList = BlockList(_blocks) + + self._max_block_sliding_window = max_block_sliding_window + self._num_full_slots = self._get_num_token_ids() + + @staticmethod + def get_num_required_blocks(token_ids: List[int], + block_size: int, + num_lookahead_slots: int = 0) -> int: + """Calculates the minimum number of blocks required to store a given + sequence of token IDs along with any look-ahead slots that may be + required (like in multi-step + chunked-prefill). + + This assumes worst-case scenario, where every block requires a new + allocation (e.g. ignoring prefix caching). + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + block_size (int): The maximum number of tokens that can be stored in + a single block. + num_lookahead_slots (int): look-ahead slots that the sequence may + require. + + Returns: + int: The minimum number of blocks required to store the given + sequence of token IDs along with any required look-ahead slots. + """ + return cdiv(len(token_ids) + num_lookahead_slots, block_size) + + def allocate(self, + token_ids: List[int], + device: Device = Device.GPU) -> None: + """Allocates memory blocks for storing the given sequence of token IDs. + + This method allocates the required number of blocks to store the given + sequence of token IDs. + + Args: + token_ids (List[int]): The sequence of token IDs to be stored. + device (Device, optional): The device on which the blocks should be + allocated. Defaults to Device.GPU. + """ + assert not self._is_allocated + assert token_ids + blocks = self._allocate_blocks_for_token_ids(prev_block=None, + token_ids=token_ids, + device=device) + self.update(blocks) + self._num_full_slots = len(token_ids) + + def update(self, blocks: List[Block]) -> None: + """Resets the table to the newly provided blocks + (with their corresponding block ids) + """ + self._blocks.update(blocks) + + def append_token_ids(self, + token_ids: List[int], + num_lookahead_slots: int = 0, + num_computed_slots: Optional[int] = None) -> None: + """Appends a sequence of token IDs to the existing blocks in the + BlockTable. + + This method appends the given sequence of token IDs to the existing + blocks in the BlockTable. If there is not enough space in the existing + blocks, new blocks are allocated using the `ensure_num_empty_slots` + method to accommodate the additional tokens. + + The token IDs are divided into chunks of size `block_size` (except for + the first chunk, which may be smaller), and each chunk is appended to a + separate block. + + Args: + token_ids (List[int]): The sequence of token IDs to be appended. + num_computed_slots (Optional[int]): The number of KV cache slots + that are already filled (computed). + When sliding window is enabled, this is used to compute how many + blocks to drop at the front of the sequence. + Without sliding window, None can be passed. + Without chunked prefill, it should be the same as + _num_full_slots. 
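+
+        For example, with block_size=16 and 10 tokens already stored, appending
+        20 token ids writes 6 of them into the partially filled block and the
+        remaining 14 into a newly allocated block.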
+ """ + assert self._is_allocated, "no blocks have been allocated" + assert len(self._blocks) > 0 + + # Drop blocks that are no longer needed due to sliding window + if self._max_block_sliding_window is not None: + null_block = self._allocator.allocate_or_get_null_block() + assert num_computed_slots is not None + end_block_idx = (num_computed_slots // + self._block_size) - self._max_block_sliding_window + for idx in range(0, end_block_idx): + b = self._blocks[idx] + if b is not null_block: + self._allocator.free(b) + self._blocks[idx] = null_block + + # Ensure there are enough empty slots for the new tokens plus + # lookahead slots + self.ensure_num_empty_slots(num_empty_slots=len(token_ids) + + num_lookahead_slots) + + # Update the blocks with the new tokens + first_block_idx = self._num_full_slots // self._block_size + token_blocks = self._chunk_token_blocks_for_append(token_ids) + + for i, token_block in enumerate(token_blocks): + self._blocks.append_token_ids(first_block_idx + i, token_block) + + self._num_full_slots += len(token_ids) + + def ensure_num_empty_slots(self, num_empty_slots: int) -> None: + """Ensures that the BlockTable has at least the specified number of + empty slots available. + + This method checks if the BlockTable has enough empty slots (i.e., + available space) to accommodate the requested number of tokens. If not, + it allocates additional blocks on the GPU to ensure that the required + number of empty slots is available. + + Args: + num_empty_slots (int): The minimum number of empty slots required. + """ + # Currently the block table only supports + # appending tokens to GPU blocks. + device = Device.GPU + assert self._is_allocated + + if self._num_empty_slots >= num_empty_slots: + return + + slots_to_allocate = num_empty_slots - self._num_empty_slots + blocks_to_allocate = cdiv(slots_to_allocate, self._block_size) + + for _ in range(blocks_to_allocate): + assert len(self._blocks) > 0 + self._blocks.append( + self._allocator.allocate_mutable_block( + prev_block=self._blocks[-1], device=device)) + + def fork(self) -> "BlockTable": + """Creates a new BlockTable instance with a copy of the blocks from the + current instance. + + This method creates a new BlockTable instance with the same block size, + block allocator, and a copy of the blocks from the current instance. The + new BlockTable has its own independent set of blocks, but shares the + same underlying memory allocation with the original BlockTable. + + Returns: + BlockTable: A new BlockTable instance with a copy of the blocks from + the current instance. + """ + assert self._is_allocated + assert len(self._blocks) > 0 + forked_blocks = self._allocator.fork(self._blocks[-1]) + return BlockTable( + block_size=self._block_size, + block_allocator=self._allocator, + _blocks=forked_blocks, + max_block_sliding_window=self._max_block_sliding_window, + ) + + def free(self) -> None: + """Frees the memory occupied by the blocks in the BlockTable. + + This method iterates over all the blocks in the `_blocks` list and calls + the `free` method of the `_allocator` object to release the memory + occupied by each block. After freeing all the blocks, the `_blocks` list + is set to `None`. + """ + for block in self.blocks: + self._allocator.free(block) + self._blocks.reset() + + @property + def physical_block_ids(self) -> List[int]: + """Returns a list of physical block indices for the blocks in the + BlockTable. 
+ + This property returns a list of integers, where each integer represents + the physical block index of a corresponding block in the `_blocks` list. + The physical block index is a unique identifier for the memory location + occupied by the block. + + Returns: + List[int]: A list of physical block indices for the blocks in the + BlockTable. + """ + return self._blocks.ids() + + def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]: + """Get the number of "unseen" tokens in the sequence. + + Unseen tokens are tokens in the sequence corresponding to this block + table, but are not yet appended to this block table. + + Args: + sequence_token_ids (List[int]): The list of token ids in the + sequence. + + Returns: + List[int]: The postfix of sequence_token_ids that has not yet been + appended to the block table. + """ + + # Since the block table is append-only, the unseen token ids are the + # ones after the appended ones. + return sequence_token_ids[self.num_full_slots:] + + def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> List[Block]: + blocks: List[Block] = [] + + block_token_ids = [] + tail_token_ids = [] + for cur_token_ids in chunk_list(token_ids, self._block_size): + if len(cur_token_ids) == self._block_size: + block_token_ids.append(cur_token_ids) + else: + tail_token_ids.append(cur_token_ids) + + if block_token_ids: + blocks.extend( + self._allocator.allocate_immutable_blocks( + prev_block, block_token_ids=block_token_ids, + device=device)) + prev_block = blocks[-1] + + if tail_token_ids: + assert len(tail_token_ids) == 1 + cur_token_ids = tail_token_ids[0] + + block = self._allocator.allocate_mutable_block( + prev_block=prev_block, device=device) + block.append_token_ids(cur_token_ids) + + blocks.append(block) + + return blocks + + def _get_all_token_ids(self) -> List[int]: + # NOTE: This function is O(seq_len); use sparingly. + token_ids: List[int] = [] + + if not self._is_allocated: + return token_ids + + for block in self.blocks: + token_ids.extend(block.token_ids) + + return token_ids + + def _get_num_token_ids(self) -> int: + res = 0 + for block in self.blocks: + res += len(block.token_ids) + + return res + + @property + def _is_allocated(self) -> bool: + return len(self._blocks) > 0 + + @property + def blocks(self) -> List[Block]: + return self._blocks.list() + + @property + def _num_empty_slots(self) -> int: + assert self._is_allocated + return len(self._blocks) * self._block_size - self._num_full_slots + + @property + def num_full_slots(self) -> int: + """Returns the total number of tokens currently stored in the + BlockTable. + + Returns: + int: The total number of tokens currently stored in the BlockTable. + """ + return self._num_full_slots + + def get_num_blocks_touched_by_append_slots( + self, token_ids: List[int], num_lookahead_slots: int) -> int: + """Determine how many blocks will be "touched" by appending the token + ids. + + This is required for the scheduler to determine whether a sequence can + continue generation, or if it must be preempted. 
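+
+        For example, with block_size=16, 30 tokens already stored, 10 new token
+        ids and 5 lookahead slots, 2 of the 15 slots fit in the partially
+        filled block and the remaining 13 spill into one extra block, so 2
+        blocks are touched.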
+ """ + # Math below is equivalent to: + # all_token_ids = token_ids + [-1] * num_lookahead_slots + # token_blocks = self._chunk_token_blocks_for_append(all_token_ids) + # return len(token_blocks) + + num_token_ids = len(token_ids) + num_lookahead_slots + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + num_token_blocks = (1 + math.ceil( + (num_token_ids - first_chunk_size) / self._block_size)) + return num_token_blocks + + def _chunk_token_blocks_for_append( + self, token_ids: List[int]) -> List[List[int]]: + """Split the token ids into block-sized chunks so they can be easily + appended to blocks. The first such "token block" may have less token ids + than the block size, since the last allocated block may be partially + full. + + If no token ids are provided, then no chunks are returned. + """ + + if not token_ids: + return [] + + first_chunk_size = self._block_size - (self._num_full_slots % + self._block_size) + token_blocks = [token_ids[:first_chunk_size]] + token_blocks.extend( + chunk_list(token_ids[first_chunk_size:], self._block_size)) + return token_blocks diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py new file mode 100644 index 0000000..eb190ad --- /dev/null +++ b/vllm/core/block/common.py @@ -0,0 +1,360 @@ +from collections import deque +from dataclasses import dataclass +from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple + +from vllm.core.block.interfaces import Block, BlockAllocator + +BlockId = int +RefCount = int + + +class RefCounterProtocol(Protocol): + + def incr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def decr(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + def get(self, block_id: BlockId) -> RefCount: + raise NotImplementedError + + +class RefCounter(RefCounterProtocol): + """A class for managing reference counts for a set of block indices. + + The RefCounter class maintains a dictionary that maps block indices to their + corresponding reference counts. It provides methods to increment, decrement, + and retrieve the reference count for a given block index. + + Args: + all_block_indices (Iterable[BlockId]): An iterable of block indices + to initialize the reference counter with. + """ + + def __init__(self, all_block_indices: Iterable[BlockId]): + deduped = set(all_block_indices) + self._refcounts: Dict[BlockId, + RefCount] = {index: 0 + for index in deduped} + + def incr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + pre_incr_refcount = self._refcounts[block_id] + + assert pre_incr_refcount >= 0 + + post_incr_refcount = pre_incr_refcount + 1 + self._refcounts[block_id] = post_incr_refcount + return post_incr_refcount + + def decr(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + refcount = self._refcounts[block_id] + + assert refcount > 0 + refcount -= 1 + + self._refcounts[block_id] = refcount + + return refcount + + def get(self, block_id: BlockId) -> RefCount: + assert block_id in self._refcounts + return self._refcounts[block_id] + + def as_readonly(self) -> "ReadOnlyRefCounter": + return ReadOnlyRefCounter(self) + + +class ReadOnlyRefCounter(RefCounterProtocol): + """A read-only view of the RefCounter class. + + The ReadOnlyRefCounter class provides a read-only interface to access the + reference counts maintained by a RefCounter instance. It does not allow + modifications to the reference counts. 
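+    Calling :meth:`incr` or :meth:`decr` on the read-only view raises
+    ``ValueError``.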
+ + Args: + refcounter (RefCounter): The RefCounter instance to create a read-only + view for. + """ + + def __init__(self, refcounter: RefCounter): + self._refcounter = refcounter + + def incr(self, block_id: BlockId) -> RefCount: + raise ValueError("Incr not allowed") + + def decr(self, block_id: BlockId) -> RefCount: + raise ValueError("Decr not allowed") + + def get(self, block_id: BlockId) -> RefCount: + return self._refcounter.get(block_id) + + +class CopyOnWriteTracker: + """A class for tracking and managing copy-on-write operations for blocks. + + The CopyOnWriteTracker class maintains a mapping of source block indices to + their corresponding copy-on-write destination block indices. It works in + conjunction with a RefCounter. + + Args: + refcounter (RefCounter): The reference counter used to track block + reference counts. + """ + + def __init__(self, refcounter: RefCounterProtocol): + self._copy_on_writes: List[Tuple[BlockId, BlockId]] = [] + self._refcounter = refcounter + + def is_appendable(self, block: Block) -> bool: + """Checks if the block is shared or not. If shared, then it cannot + be appended and needs to be duplicated via copy-on-write + """ + block_id = block.block_id + if block_id is None: + return True + + refcount = self._refcounter.get(block_id) + return refcount <= 1 + + def record_cow(self, src_block_id: Optional[BlockId], + trg_block_id: Optional[BlockId]) -> None: + """Records a copy-on-write operation from source to target block id + Args: + src_block_id (BlockId): The source block id from which to copy + the data + trg_block_id (BlockId): The target block id to which the data + is copied + """ + assert src_block_id is not None + assert trg_block_id is not None + self._copy_on_writes.append((src_block_id, trg_block_id)) + + def clear_cows(self) -> List[Tuple[BlockId, BlockId]]: + """Clears the copy-on-write tracking information and returns the current + state. + + This method returns a list mapping source block indices to + destination block indices for the current copy-on-write operations. + It then clears the internal tracking information. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices for the + current copy-on-write operations. + """ + cows = self._copy_on_writes + self._copy_on_writes = [] + return cows + + +class BlockPool: + """Used to pre-allocate block objects, in order to avoid excessive python + object allocations/deallocations. + The pool starts from "pool_size" objects and will increase to more objects + if necessary + + Note that multiple block objects may point to the same physical block id, + which is why this pool is needed, so that it will be easier to support + prefix caching and more complicated sharing of physical blocks. 
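+
+    Blocks handed out by :meth:`init_block` are recycled through
+    :meth:`free_block` rather than being constructed on every allocation, and
+    the pool doubles in size via :meth:`increase_pool` when no free entries
+    remain.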
+ """ + + def __init__(self, block_size: int, create_block: Block.Factory, + allocator: BlockAllocator, pool_size: int): + self._block_size = block_size + self._create_block = create_block + self._allocator = allocator + self._pool_size = pool_size + assert self._pool_size >= 0 + + self._free_ids: Deque[int] = deque(range(self._pool_size)) + self._pool = [] + for i in range(self._pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None)) + + def increase_pool(self): + """Doubles the internal pool size + """ + cur_pool_size = self._pool_size + new_pool_size = cur_pool_size * 2 + self._pool_size = new_pool_size + + self._free_ids += deque(range(cur_pool_size, new_pool_size)) + + for i in range(cur_pool_size, new_pool_size): + self._pool.append( + self._create_block(prev_block=None, + token_ids=[], + block_size=self._block_size, + allocator=self._allocator, + block_id=None)) + + def init_block(self, prev_block: Optional[Block], token_ids: List[int], + block_size: int, physical_block_id: Optional[int]) -> Block: + if len(self._free_ids) == 0: + self.increase_pool() + assert len(self._free_ids) > 0 + + pool_id = self._free_ids.popleft() + + block = self._pool[pool_id] + block.__init__( # type: ignore[misc] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + allocator=block._allocator, # type: ignore[attr-defined] + block_id=physical_block_id) + block.pool_id = pool_id # type: ignore[attr-defined] + return block + + def free_block(self, block: Block) -> None: + self._free_ids.appendleft(block.pool_id) # type: ignore[attr-defined] + + +class BlockList: + """This class is an optimization to allow fast-access to physical + block ids. 
It maintains a block id list that is updated with the + block list and this avoids the need to reconstruct the block id + list on every iteration of the block manager + """ + + def __init__(self, blocks: List[Block]): + self._blocks: List[Block] = [] + self._block_ids: List[int] = [] + + self.update(blocks) + + def _add_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_ids.append(block_id) + + def _update_block_id(self, block_index: int, + new_block_id: Optional[BlockId]) -> None: + assert new_block_id is not None + self._block_ids[block_index] = new_block_id + + def update(self, blocks: List[Block]): + self._blocks = blocks + + # Cache block ids for fast query + self._block_ids = [] + for block in self._blocks: + self._add_block_id(block.block_id) + + def append_token_ids(self, block_index: int, token_ids: List[int]) -> None: + block = self._blocks[block_index] + prev_block_id = block.block_id + + block.append_token_ids(token_ids) + + # CoW or promotion may update the internal block_id + if prev_block_id != block.block_id: + self._update_block_id(block_index, block.block_id) + + def append(self, new_block: Block): + self._blocks.append(new_block) + self._add_block_id(new_block.block_id) + + def __len__(self) -> int: + return len(self._blocks) + + def __getitem__(self, block_index: int) -> Block: + return self._blocks[block_index] + + def __setitem__(self, block_index: int, new_block: Block) -> None: + self._blocks[block_index] = new_block + self._update_block_id(block_index, new_block.block_id) + + def reset(self): + self._blocks = [] + self._block_ids = [] + + def list(self) -> List[Block]: + return self._blocks + + def ids(self) -> List[int]: + return self._block_ids + + +@dataclass +class CacheMetricData: + """A utility dataclass to maintain cache metric. + To avoid overflow, we maintain the hit rate in block granularity, so that + we can maintain a single hit rate for n_completed_block x block_size, + and calculate the real time hit rate by the following: + BS = The number of queries per block. + nB = The number of completed blocks. + HR = hit rate of (nB x BS) queries. + Q = current number of queries (< BS). + H = current number of hits (< BS). + hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS) + """ + num_completed_blocks: int = 0 + completed_block_cache_hit_rate: float = 0.0 + num_incompleted_block_queries: int = 0 + num_incompleted_block_hit: int = 0 + block_size: int = 1000 + + def query(self, hit: bool): + self.num_incompleted_block_queries += 1 + self.num_incompleted_block_hit += 1 if hit else 0 + + # When a block is completed, update the cache hit rate + # and reset the incomplete numbers. 
+ if self.num_incompleted_block_queries == self.block_size: + hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + self.completed_block_cache_hit_rate = ( + self.completed_block_cache_hit_rate * self.num_completed_blocks + + hit_rate) / (self.num_completed_blocks + 1) + self.num_incompleted_block_queries = 0 + self.num_incompleted_block_hit = 0 + self.num_completed_blocks += 1 + + def get_hit_rate(self): + incomplete_ratio = self.num_incompleted_block_queries / self.block_size + total_blocks = self.num_completed_blocks + incomplete_ratio + if total_blocks == 0: + return 0.0 + + completed_block_hit, incompleted_block_hit = 0.0, 0.0 + if self.num_completed_blocks > 0: + completed_block_hit = (self.completed_block_cache_hit_rate * + self.num_completed_blocks) + if self.num_incompleted_block_queries > 0: + incompleted_hit_rate = (self.num_incompleted_block_hit / + self.num_incompleted_block_queries) + incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio) + return (completed_block_hit + incompleted_block_hit) / total_blocks + + +def get_all_blocks_recursively(last_block: Block) -> List[Block]: + """Retrieves all the blocks in a sequence starting from the last block. + + This function recursively traverses the sequence of blocks in reverse order, + starting from the given last block, and returns a list of all the blocks in + the sequence. + + Args: + last_block (Block): The last block in the sequence. + + Returns: + List[Block]: A list of all the blocks in the sequence, in the order they + appear. + """ + + def recurse(block: Block, lst: List[Block]) -> None: + if block.prev_block is not None: + recurse(block.prev_block, lst) + lst.append(block) + + all_blocks: List[Block] = [] + recurse(last_block, all_blocks) + return all_blocks diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py new file mode 100644 index 0000000..6eda5f9 --- /dev/null +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -0,0 +1,404 @@ +from typing import Dict, FrozenSet, List, Optional, Tuple + +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, + DeviceAwareBlockAllocator) +from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator +from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator +from vllm.utils import Device + + +class CpuGpuBlockAllocator(DeviceAwareBlockAllocator): + """A block allocator that can allocate blocks on both CPU and GPU memory. + + This class implements the `DeviceAwareBlockAllocator` interface and provides + functionality for allocating and managing blocks of memory on both CPU and + GPU devices. + + The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and GPU + blocks, and allows for allocation, deallocation, forking, and swapping of + blocks across these memory pools. + """ + + @staticmethod + def create( + allocator_type: str, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + ) -> DeviceAwareBlockAllocator: + """Creates a CpuGpuBlockAllocator instance with the specified + configuration. + + This static method creates and returns a CpuGpuBlockAllocator instance + based on the provided parameters. It initializes the CPU and GPU block + allocators with the specified number of blocks, block size, and + allocator type. + + Args: + allocator_type (str): The type of block allocator to use for CPU + and GPU blocks. Currently supported values are "naive" and + "prefix_caching". 
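Aside (not part of the patch): a minimal, standalone sketch of the block-granularity hit-rate folding that `CacheMetricData` above performs, using a small hypothetical `block_size` so the running average can be checked by hand. Function and variable names here are illustrative only.

```python
# Standalone sketch: fold per-block hit rates the way CacheMetricData does.
def folded_hit_rate(queries, block_size=4):
    num_completed_blocks = 0
    completed_hit_rate = 0.0
    q = h = 0  # queries / hits in the current (incomplete) block
    for hit in queries:
        q += 1
        h += 1 if hit else 0
        if q == block_size:
            # A block of queries completed: fold its hit rate into the mean.
            completed_hit_rate = (completed_hit_rate * num_completed_blocks
                                  + h / q) / (num_completed_blocks + 1)
            num_completed_blocks += 1
            q = h = 0
    incomplete_ratio = q / block_size
    total = num_completed_blocks + incomplete_ratio
    if total == 0:
        return 0.0
    completed_part = completed_hit_rate * num_completed_blocks
    incomplete_part = (h / q) * incomplete_ratio if q else 0.0
    return (completed_part + incomplete_part) / total

queries = [True, True, False, True, True, False]  # 4 hits out of 6 queries
assert abs(folded_hit_rate(queries) - 4 / 6) < 1e-9
```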
+ num_gpu_blocks (int): The number of blocks to allocate for GPU + memory. + num_cpu_blocks (int): The number of blocks to allocate for CPU + memory. + block_size (int): The size of each block in number of tokens. + + Returns: + DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with the + specified configuration. + + Notes: + - The block IDs are assigned contiguously, with GPU block IDs coming + before CPU block IDs. + """ + block_ids = list(range(num_gpu_blocks + num_cpu_blocks)) + gpu_block_ids = block_ids[:num_gpu_blocks] + cpu_block_ids = block_ids[num_gpu_blocks:] + + if allocator_type == "naive": + gpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator: BlockAllocator = NaiveBlockAllocator( + create_block=NaiveBlock, # type: ignore + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + elif allocator_type == "prefix_caching": + gpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_gpu_blocks, + block_size=block_size, + block_ids=gpu_block_ids, + ) + + cpu_allocator = PrefixCachingBlockAllocator( + num_blocks=num_cpu_blocks, + block_size=block_size, + block_ids=cpu_block_ids, + ) + else: + raise ValueError(f"Unknown allocator type {allocator_type=}") + + return CpuGpuBlockAllocator( + cpu_block_allocator=cpu_allocator, + gpu_block_allocator=gpu_allocator, + ) + + def __init__(self, cpu_block_allocator: BlockAllocator, + gpu_block_allocator: BlockAllocator): + assert not ( + cpu_block_allocator.all_block_ids + & gpu_block_allocator.all_block_ids + ), "cpu and gpu block allocators can't have intersection of block ids" + + self._allocators = { + Device.CPU: cpu_block_allocator, + Device.GPU: gpu_block_allocator, + } + + self._swap_mapping: Dict[int, int] = {} + self._null_block: Optional[Block] = None + + self._block_ids_to_allocator: Dict[int, BlockAllocator] = {} + for _, allocator in self._allocators.items(): + for block_id in allocator.all_block_ids: + self._block_ids_to_allocator[block_id] = allocator + + def allocate_or_get_null_block(self) -> Block: + if self._null_block is None: + self._null_block = NullBlock( + self.allocate_mutable_block(None, Device.GPU)) + return self._null_block + + def allocate_mutable_block(self, prev_block: Optional[Block], + device: Device) -> Block: + """Allocates a new mutable block on the specified device. + + Args: + prev_block (Optional[Block]): The previous block to in the sequence. + Used for prefix hashing. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated mutable block. + """ + return self._allocators[device].allocate_mutable_block(prev_block) + + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device) -> List[Block]: + """Allocates a new group of immutable blocks with the provided block + token IDs on the specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + block_token_ids (List[int]): The list of block token IDs to be + stored in the new blocks. + device (Device): The device on which to allocate the new block. + + Returns: + List[Block]: The newly allocated list of immutable blocks + containing the provided block token IDs. 
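Aside (not part of the patch): a short usage sketch of the `CpuGpuBlockAllocator.create(...)` factory above, assuming the module paths introduced by this diff are importable; it only uses calls whose signatures appear in the patch.

```python
# Usage sketch (assumes the modules added in this patch are importable).
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.utils import Device

allocator = CpuGpuBlockAllocator.create(
    allocator_type="naive",   # or "prefix_caching"
    num_gpu_blocks=16,
    num_cpu_blocks=4,
    block_size=16,
)

# GPU block ids come first (0..15), CPU block ids follow (16..19).
block = allocator.allocate_mutable_block(prev_block=None, device=Device.GPU)
block.append_token_ids([1, 2, 3])
allocator.free(block)
```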
+ """ + return self._allocators[device].allocate_immutable_blocks( + prev_block, block_token_ids) + + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> Block: + """Allocates a new immutable block with the provided token IDs on the + specified device. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + Used for prefix hashing. + token_ids (List[int]): The list of token IDs to be stored in the new + block. + device (Device): The device on which to allocate the new block. + + Returns: + Block: The newly allocated immutable block containing the provided + token IDs. + """ + return self._allocators[device].allocate_immutable_block( + prev_block, token_ids) + + def free(self, block: Block) -> None: + """Frees the memory occupied by the given block. + + Args: + block (Block): The block to be freed. + """ + # Null block should never be freed + if isinstance(block, NullBlock): + return + block_id = block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + allocator.free(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: A new list of blocks that shares the same memory as the + original sequence. + """ + # do not attempt to fork the null block + assert not isinstance(last_block, NullBlock) + block_id = last_block.block_id + assert block_id is not None + allocator = self._block_ids_to_allocator[block_id] + return allocator.fork(last_block) + + def get_num_free_blocks(self, device: Device) -> int: + """Returns the number of free blocks available on the specified device. + + Args: + device (Device): The device for which to query the number of free + blocks. AssertionError is raised if None is passed. + + Returns: + int: The number of free blocks available on the specified device. + """ + return self._allocators[device].get_num_free_blocks() + + def get_num_total_blocks(self, device: Device) -> int: + return self._allocators[device].get_num_total_blocks() + + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + """Returns the zero-offset block id on certain device given the + absolute block id. + + Args: + device (Device): The device for which to query relative block id. + absolute_id (int): The absolute block id for the block in + whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return self._allocators[device].get_physical_block_id(absolute_id) + + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + """Execute the swap for the given blocks from source_device + on to dest_device, save the current swap mapping and append + them to the accumulated `self._swap_mapping` for each + scheduling move. + + Args: + blocks: List of blocks to be swapped. + src_device (Device): Device to swap the 'blocks' from. + dst_device (Device): Device to swap the 'blocks' to. + + Returns: + Dict[int, int]: Swap mapping from source_device + on to dest_device. 
+ """ + src_block_ids = [block.block_id for block in blocks] + self._allocators[src_device].swap_out(blocks) + self._allocators[dst_device].swap_in(blocks) + dst_block_ids = [block.block_id for block in blocks] + + current_swap_mapping: Dict[int, int] = {} + for src_block_id, dst_block_id in zip(src_block_ids, dst_block_ids): + if src_block_id is not None and dst_block_id is not None: + self._swap_mapping[src_block_id] = dst_block_id + current_swap_mapping[src_block_id] = dst_block_id + return current_swap_mapping + + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + + Args: + blocks: List of blocks to be swapped. + device (Device): Device to swap the 'blocks' on. + + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks on to the 'device'. + Non full blocks are ignored when deciding the number + of blocks to touch. + """ + return self._allocators[device].get_num_full_blocks_touched(blocks) + + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + """Clears the copy-on-write (CoW) state and returns the mapping of + source to destination block IDs. + + Returns: + List[Tuple[int, int]]: A list mapping source block IDs to + destination block IDs. + """ + # CoW only supported on GPU + device = Device.GPU + return self._allocators[device].clear_copy_on_writes() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_accessed(block_ids, now) + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as accessed, only use for prefix caching.""" + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].mark_blocks_as_computed(block_ids) + + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_computed_block_ids( + prev_computed_block_ids, block_ids, skip_last_block_id) + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + # Prefix caching only supported on GPU. + device = Device.GPU + return self._allocators[device].get_common_computed_block_ids( + computed_seq_block_ids) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return frozenset(self._block_ids_to_allocator.keys()) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + assert device in self._allocators + return self._allocators[device].get_prefix_cache_hit_rate() + + def get_and_reset_swaps(self) -> List[Tuple[int, int]]: + """Returns and clears the mapping of source to destination block IDs. + Will be called after every swapping operations for now, and after every + schedule when BlockManagerV2 become default. Currently not useful. + + Returns: + List[Tuple[int, int]]: A mapping of source to destination block IDs. 
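Aside (not part of the patch): a sketch of the swap bookkeeping described above. The token ids are hypothetical; the calls (`swap`, `get_and_reset_swaps`) follow the signatures shown in this diff, with the per-call mapping also accumulated internally until the scheduler drains it.

```python
# Sketch (hypothetical token ids): swap a sequence's blocks GPU -> CPU and
# then drain the accumulated swap mapping, as the scheduler would.
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
from vllm.utils import Device

allocator = CpuGpuBlockAllocator.create("naive", num_gpu_blocks=8,
                                        num_cpu_blocks=8, block_size=4)
blocks = allocator.allocate_immutable_blocks(
    prev_block=None,
    block_token_ids=[[1, 2, 3, 4], [5, 6, 7, 8]],
    device=Device.GPU)

per_call = allocator.swap(blocks, src_device=Device.GPU,
                          dst_device=Device.CPU)

# Everything accumulated since the last drain is returned, then cleared.
assert dict(allocator.get_and_reset_swaps()) == per_call
assert allocator.get_and_reset_swaps() == []
```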
+ """ + mapping = self._swap_mapping.copy() + self._swap_mapping.clear() + return list(mapping.items()) + + +class NullBlock(Block): + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + This implementation just wraps an ordinary block and prevents it from + being modified. It also allows for testing if a block is NullBlock + via isinstance(). + """ + + def __init__(self, proxy: Block): + super().__init__() + self._proxy = proxy + + def append_token_ids(self, token_ids: List[BlockId]): + raise ValueError("null block should not be modified") + + @property + def block_id(self): + return self._proxy.block_id + + @block_id.setter + def block_id(self, value: Optional[BlockId]): + raise ValueError("null block should not be modified") + + @property + def token_ids(self) -> List[BlockId]: + return self._proxy.token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for null block") + + @property + def num_empty_slots(self) -> BlockId: + return self._proxy.num_empty_slots + + @property + def is_full(self): + return self._proxy.is_full + + @property + def prev_block(self): + return self._proxy.prev_block + + @property + def computed(self): + return self._proxy.computed + + @computed.setter + def computed(self, value): + self._proxy.computed = value + + @property + def last_accessed(self) -> float: + return self._proxy.last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._proxy.last_accessed = last_accessed_ts + + @property + def content_hash(self): + return self._proxy.content_hash diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py new file mode 100644 index 0000000..72bbab1 --- /dev/null +++ b/vllm/core/block/interfaces.py @@ -0,0 +1,286 @@ +from abc import ABC, abstractmethod +from typing import Dict, FrozenSet, List, Optional, Protocol, Tuple + +from vllm.utils import Device + +BlockId = int + + +class Block(ABC): + + @abstractmethod + def append_token_ids(self, token_ids: List[int]) -> None: + pass + + @property + @abstractmethod + def block_id(self) -> Optional[int]: + pass + + @block_id.setter + @abstractmethod + def block_id(self, value: Optional[int]) -> None: + """NOTE: Do not use this API outside Block.""" + self._block_id = value + + @property + @abstractmethod + def token_ids(self) -> List[int]: + pass + + @property + @abstractmethod + def num_tokens_total(self) -> int: + """The number of tokens till the current block (inclusive) + """ + pass + + @property + @abstractmethod + def num_empty_slots(self) -> int: + pass + + @property + @abstractmethod + def is_full(self) -> bool: + pass + + @property + @abstractmethod + def prev_block(self) -> Optional["Block"]: + pass + + @property + @abstractmethod + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + @abstractmethod + def computed(self, value) -> bool: + """Should be only used by PrefixCacingAllocator""" + raise NotImplementedError + + @property + @abstractmethod + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + @abstractmethod + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + class Factory(Protocol): + + @abstractmethod + def __call__( + self, + prev_block: Optional["Block"], + token_ids: List[int], + block_size: int, + allocator: "BlockAllocator", + block_id: Optional[int] = None, + ) -> "Block": + pass + + @property + @abstractmethod + def 
content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined or not supported. + + For the content-based hash to be defined, the current block must be + full. + """ + return None + + +class BlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, prev_block: Optional[Block]) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int]) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks( + self, prev_block: Optional[Block], + block_token_ids: List[List[int]]) -> List[Block]: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractmethod + def get_physical_block_id(self, absolute_id: int) -> int: + pass + + @abstractmethod + def swap_out(self, blocks: List[Block]) -> None: + pass + + @abstractmethod + def swap_in(self, blocks: List[Block]) -> None: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def promote_to_immutable_block(self, block: Block) -> BlockId: + """NOTE: This should not be used besides Block""" + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. 
-1 means not supported or disabled.""" + pass + + class NoFreeBlocksError(ValueError): + pass + + +class DeviceAwareBlockAllocator(ABC): + + @abstractmethod + def allocate_mutable_block(self, prev_block: Optional[Block], + device: Device) -> Block: + pass + + @abstractmethod + def allocate_immutable_block(self, prev_block: Optional[Block], + token_ids: List[int], + device: Device) -> Block: + pass + + @abstractmethod + def allocate_immutable_blocks(self, prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device) -> List[Block]: + pass + + @abstractmethod + def get_num_free_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def get_num_total_blocks(self, device: Device) -> int: + pass + + @abstractmethod + def free(self, block: Block) -> None: + pass + + @abstractmethod + def fork(self, last_block: Block) -> List[Block]: + pass + + @property + @abstractmethod + def all_block_ids(self) -> FrozenSet[int]: + pass + + @abstractmethod + def clear_copy_on_writes(self) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + pass + + @abstractmethod + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + pass + + @abstractmethod + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + pass + + @abstractmethod + def get_num_full_blocks_touched(self, blocks: List[Block], + device: Device) -> int: + pass + + @abstractmethod + def swap(self, blocks: List[Block], src_device: Device, + dst_device: Device) -> Dict[int, int]: + pass + + @abstractmethod + def get_physical_block_id(self, device: Device, absolute_id: int) -> int: + pass + + @abstractmethod + def allocate_or_get_null_block(self) -> Block: + """ + Null blocks are used as a placeholders for KV cache blocks that have + been dropped due to sliding window. + There is at most one null block per allocator. + """ + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py new file mode 100644 index 0000000..9341a51 --- /dev/null +++ b/vllm/core/block/naive_block.py @@ -0,0 +1,449 @@ +from collections import deque +from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple + +from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device + +Refcount = int + + +class NaiveBlockAllocator(BlockAllocator): + """A simple block allocator that manages blocks of memory without prefix + caching. + + Args: + create_block (Block.Factory): A factory function for creating new + blocks. This is used when a NaiveBlockAllocator is composed within + a prefix caching allocator -- the naive block allocator must + construct prefix caching blocks (but shouldn't know anything else + about them). + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids (Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. 
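Aside (not part of the patch): the `create_block` argument of `NaiveBlockAllocator` only has to satisfy the `Block.Factory` protocol defined above. A sketch of a hypothetical factory that wraps `NaiveBlock` and counts how often it is invoked (for instance, to observe the pool pre-allocation):

```python
# Sketch: any callable matching Block.Factory can be passed as create_block.
from typing import List, Optional
from vllm.core.block.interfaces import Block, BlockAllocator
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator

def counting_block_factory(prev_block: Optional[Block],
                           token_ids: List[int],
                           block_size: int,
                           allocator: BlockAllocator,
                           block_id: Optional[int] = None) -> Block:
    # Hypothetical instrumentation; delegates to the ordinary NaiveBlock.
    counting_block_factory.calls += 1
    return NaiveBlock(prev_block=prev_block, token_ids=token_ids,
                      block_size=block_size, allocator=allocator,
                      block_id=block_id)

counting_block_factory.calls = 0

alloc = NaiveBlockAllocator(create_block=counting_block_factory,
                            num_blocks=4, block_size=16)
# The internal BlockPool pre-allocates block objects through the factory.
assert counting_block_factory.calls > 0
```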
+ """ + + def __init__( + self, + create_block: Block.Factory, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + block_pool: Optional[BlockPool] = None, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._free_block_indices: Deque[BlockId] = deque(block_ids) + self._all_block_indices = frozenset(block_ids) + assert len(self._all_block_indices) == num_blocks + + self._refcounter = RefCounter( + all_block_indices=self._free_block_indices) + self._block_size = block_size + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + if block_pool is None: + extra_factor = 4 + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + self._block_pool = BlockPool(self._block_size, create_block, self, + num_blocks * extra_factor) + else: + # In this case, the block pool is provided by the caller, + # which means that there is most likely a need to share + # a block pool between allocators + self._block_pool = block_pool + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: + """Allocates a new immutable block with the given token IDs, linked to + the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + token_ids (List[int]): The token IDs to be stored in the new block. + + Returns: + Block: The newly allocated immutable block. + """ + assert device is None + block = self.allocate_mutable_block(prev_block=prev_block) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Optional[Device] = None) -> List[Block]: + assert device is None + num_blocks = len(block_token_ids) + + block_ids = [] + for i in range(num_blocks): + block_ids.append(self._allocate_block_id()) + + blocks = [] + for i in range(num_blocks): + prev_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block_token_ids[i], + block_size=self._block_size, + physical_block_id=block_ids[i]) + blocks.append(prev_block) + + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: + """Allocates a new mutable block, linked to the previous block. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. If + None, then the block to be allocated is the first block in the + sequence. + + Returns: + Block: The newly allocated mutable block. 
+ """ + assert device is None + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + return block + + def _allocate_block_id(self) -> BlockId: + if not self._free_block_indices: + raise BlockAllocator.NoFreeBlocksError() + + block_id = self._free_block_indices.popleft() + self._refcounter.incr(block_id) + return block_id + + def _free_block_id(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount == 0: + self._free_block_indices.appendleft(block_id) + + block.block_id = None + + def free(self, block: Block, keep_block_object: bool = False) -> None: + # Release the physical block id + self._free_block_id(block) + + # Release the block object + if not keep_block_object: + self._block_pool.free_block(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. + """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + + # Increment refcount for each block. + assert block.block_id is not None + refcount = self._refcounter.incr(block.block_id) + assert refcount != 1, "can't fork free'd block" + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block.block_id) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self) -> int: + return len(self._free_block_indices) + + def get_num_total_blocks(self) -> int: + return len(self._all_block_indices) + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The zero-offset block id on certain device. + """ + return sorted(self._all_block_indices).index(absolute_id) + + @property + def refcounter(self): + return self._refcounter + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._all_block_indices + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. + + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. 
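Aside (not part of the patch): a sketch of the copy-on-write path in the naive allocator. Forking shares the physical block id; the next append to the shared, non-full block is redirected to a fresh id, which then shows up in `clear_copy_on_writes()`. This assumes the refcount-based `CopyOnWriteTracker` from this patch's `common.py`; token ids are hypothetical.

```python
# Sketch (hypothetical token ids): fork + append triggers copy-on-write.
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator

alloc = NaiveBlockAllocator(create_block=NaiveBlock, num_blocks=4,
                            block_size=4)

original = alloc.allocate_mutable_block(prev_block=None)
original.append_token_ids([1, 2])           # refcount 1 => still appendable

forked = alloc.fork(original)[-1]            # block id now shared (refcount 2)
assert forked.block_id == original.block_id

forked.append_token_ids([3])                 # shared block => CoW to a new id
assert forked.block_id != original.block_id
assert (original.block_id, forked.block_id) in alloc.clear_copy_on_writes()
```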
+ """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + """Mark blocks as computed, used in prefix caching. + + Since the naive allocator does not implement prefix caching, we do + nothing. + """ + pass + + def get_computed_block_ids(self, prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool) -> List[int]: + """No prefix caching here => return empty list + """ + return [] + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Determine blocks that can be skipped in prefill. + + Since the naive allocator does not support prefix caching, always return + an empty list. + """ + return [] + + def promote_to_immutable_block(self, block: Block) -> BlockId: + raise NotImplementedError("There is no promotion for naive blocks") + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. + Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + # NOTE: for naive block, we use set to eliminate common blocks among + # seqs, also we compare the empty slots in the mutable blocks with + # lookahead slots to get the number of unique new block that are + # needed. + old_block_set = set() + for block in blocks: + if block.is_full: + old_block_set.add(block) + return len(old_block_set) + + def swap_out(self, blocks: List[Block]) -> None: + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + tmp_block.block_id = None + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + def get_prefix_cache_hit_rate(self) -> float: + return -1 + + +class NaiveBlock(Block): + """An implementation of the Block class that does not support prefix + caching. + + The NaiveBlock class represents a block of token IDs with a fixed size. It + provides methods for appending token IDs to the block and manages copy-on + -write operations when necessary. + + Args: + prev_block (Block): The previous block in the sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The block allocator associated with this + block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None, which means no allocation has been + made. + _cow_target (Optional[Block], optional): The copy-on-write target block. 
+ If not provided, it defaults to self. + """ + + def __init__(self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + _cow_target: Optional[Block] = None): + self._token_ids: List[int] = [] + self._block_size = block_size + self._prev_block = prev_block + self._block_id = block_id + self._allocator = allocator + self._cow_target = _cow_target if _cow_target is not None else self + + self._append_token_ids_no_cow(token_ids) + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and performs a + copy-on-write if necessary. + + Args: + token_ids (Optional[List[int]]): The token IDs to be appended + to the block. + """ + self._append_token_ids_no_cow(token_ids) + + if self._block_id is not None: + self._block_id = (self._allocator.cow_block_if_not_appendable( + self._cow_target)) + + def _append_token_ids_no_cow(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + if len(token_ids) == 0: + return + + assert len(token_ids) <= self.num_empty_slots + + self._token_ids.extend(token_ids) + + @property + def computed(self) -> bool: + raise NotImplementedError + + @computed.setter + def computed(self, value) -> None: + raise NotImplementedError + + @property + def last_accessed(self) -> float: + raise NotImplementedError + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + raise NotImplementedError + + @property + def block_id(self) -> Optional[int]: + return self._block_id + + @block_id.setter + def block_id(self, value: Optional[int]) -> None: + self._block_id = value + + @property + def is_full(self) -> bool: + return self.num_empty_slots == 0 + + @property + def num_empty_slots(self) -> int: + return self._block_size - len(self.token_ids) + + @property + def token_ids(self) -> List[int]: + return self._token_ids + + @property + def num_tokens_total(self) -> int: + raise NotImplementedError( + "num_tokens_total is not used for naive block") + + @property + def block_size(self) -> int: + return self._block_size + + @property + def prev_block(self) -> Optional["Block"]: + return self._prev_block + + @property + def content_hash(self) -> Optional[int]: + return None diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py new file mode 100644 index 0000000..7c8a2bc --- /dev/null +++ b/vllm/core/block/prefix_caching_block.py @@ -0,0 +1,970 @@ +"""Token blocks.""" +from os.path import commonprefix +from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple + +from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, + get_all_blocks_recursively) +from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.naive_block import (BlockPool, NaiveBlock, + NaiveBlockAllocator) +from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor + +PrefixHash = int + +# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME +# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME, +# then we know this block hasn't been accessed yet. 
+_DEFAULT_LAST_ACCESSED_TIME = -1 + + +class BlockTracker: + """Used to track the status of a block inside the prefix caching allocator + """ + __slots__ = ("active", "last_accessed", "computed") + + def reset(self): + self.last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self.computed: bool = False + + def __init__(self): + self.active: bool = False + self.reset() + + def enable(self): + assert not self.active + self.active = True + self.reset() + + def disable(self): + assert self.active + self.active = False + self.reset() + + +class PrefixCachingBlockAllocator(BlockAllocator): + """A block allocator that implements prefix caching. + + The PrefixCachingBlockAllocator maintains a cache of blocks based on their + content hash. It reuses blocks with the same content hash to avoid redundant + memory allocation. The allocator also supports copy-on-write operations. + + Args: + num_blocks (int): The total number of blocks to manage. + block_size (int): The size of each block in tokens. + block_ids(Optional[Iterable[int]], optional): An optional iterable of + block IDs. If not provided, block IDs will be assigned sequentially + from 0 to num_blocks - 1. + """ + + def __init__( + self, + num_blocks: int, + block_size: int, + block_ids: Optional[Iterable[int]] = None, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU, + ): + if block_ids is None: + block_ids = range(num_blocks) + + self._block_size = block_size + + # A mapping of prefix hash to block index. All blocks which have a + # prefix hash will be in this dict, even if they have refcount 0. + self._cached_blocks: Dict[PrefixHash, BlockId] = {} + + # A list of immutable block IDs that have been touched by scheduler + # and should be marked as computed after an entire batch of sequences + # are scheduled. + self._touched_blocks: Set[BlockId] = set() + + # Used to track status of each physical block id + self._block_tracker: Dict[BlockId, BlockTracker] = {} + for block_id in block_ids: + self._block_tracker[block_id] = BlockTracker() + + # Pre-allocate "num_blocks * extra_factor" block objects. + # The "* extra_factor" is a buffer to allow more block objects + # than physical blocks + extra_factor = 4 + self._block_pool = BlockPool(self._block_size, self._create_block, + self, num_blocks * extra_factor) + + # An allocator for blocks that do not have prefix hashes. + self._hashless_allocator = NaiveBlockAllocator( + create_block=self._create_block, # type: ignore + num_blocks=num_blocks, + block_size=block_size, + block_ids=block_ids, + block_pool=self._block_pool, # Share block pool here + ) + + # Evitor used to maintain how we want to handle those computed blocks + # if we find memory pressure is high. + self.evictor: Evictor = make_evictor(eviction_policy) + + # We share the refcounter between allocators. This allows us to promote + # blocks originally allocated in the hashless allocator to immutable + # blocks. + self._refcounter = self._hashless_allocator.refcounter + + self._cow_tracker = CopyOnWriteTracker( + refcounter=self._refcounter.as_readonly()) + + self.metric_data = CacheMetricData() + + # Implements Block.Factory. + def _create_block( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + ) -> Block: + # Bind block to self. 
+ allocator = self + + return PrefixCachingBlock( + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=allocator, + computed=computed, + ) + + def allocate_immutable_block(self, + prev_block: Optional[Block], + token_ids: List[int], + device: Optional[Device] = None) -> Block: + """Allocates an immutable block with the given token IDs, reusing cached + blocks if possible. + + Args: + prev_block (Optional[Block]): The previous block in the sequence. + token_ids (List[int]): The token IDs to be stored in the block. + + Returns: + Block: The allocated immutable block. + """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + # First, try to create a block that points to cached data + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=token_ids, + block_size=self._block_size, + physical_block_id=None) + assert block.content_hash is not None + + cached_block_id = self._cached_blocks.get(block.content_hash, None) + if cached_block_id is not None: + self.metric_data.query(hit=True) + block.block_id = cached_block_id + self._incr_refcount_cached_block(block) + return block + self.metric_data.query(hit=False) + self._block_pool.free_block(block) + + # No cached block => Allocate a new block + block = self.allocate_mutable_block(prev_block) + block.append_token_ids(token_ids) + return block + + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Optional[Device] = None) -> List[Block]: + blocks = [] + for token_ids in block_token_ids: + prev_block = self.allocate_immutable_block(prev_block=prev_block, + token_ids=token_ids, + device=device) + blocks.append(prev_block) + return blocks + + def allocate_mutable_block(self, + prev_block: Optional[Block], + device: Optional[Device] = None) -> Block: + """Allocates a mutable block. If there are no free blocks, this will + evict unused cached blocks. + + Args: + prev_block (Block): The previous block in the sequence. + None is not allowed unlike it is super class. + + Returns: + Block: The allocated mutable block. 
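Aside (not part of the patch): a sketch of the cache-hit path in `allocate_immutable_block` above. Two allocations with identical content (same prefix, same token ids) end up on the same physical block, and the miss/hit is reflected by `CacheMetricData`. Token ids are hypothetical.

```python
# Sketch (hypothetical token ids): identical full blocks share one block id.
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

alloc = PrefixCachingBlockAllocator(num_blocks=8, block_size=4)

first = alloc.allocate_immutable_block(prev_block=None,
                                       token_ids=[1, 2, 3, 4])
second = alloc.allocate_immutable_block(prev_block=None,
                                        token_ids=[1, 2, 3, 4])

# Same content hash => the second allocation reuses the cached block id.
assert first.content_hash == second.content_hash
assert first.block_id == second.block_id

# One miss (first allocation) followed by one hit (second allocation).
assert abs(alloc.get_prefix_cache_hit_rate() - 0.5) < 1e-9
```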
+ """ + assert device is None + assert_prefix_caching_block_or_none(prev_block) + + block_id = self._allocate_block_id() + block = self._block_pool.init_block(prev_block=prev_block, + token_ids=[], + block_size=self._block_size, + physical_block_id=block_id) + assert not block.computed + assert block.content_hash is None + return block + + def _incr_refcount_cached_block(self, block: Block) -> None: + # Set this block to be "computed" since it is pointing to a + # cached block id (which was already computed) + block.computed = True + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + if refcount == 1: + # In case a cached block was evicted, restore its tracking + if block_id in self.evictor: + self.evictor.remove(block_id) + + self._track_block_id(block_id, computed=True) + + def _decr_refcount_cached_block(self, block: Block) -> None: + # Ensure this is immutable/cached block + assert block.content_hash is not None + + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.decr(block_id) + if refcount > 0: + block.block_id = None + return + else: + assert refcount == 0 + + # No longer used + assert block.content_hash in self._cached_blocks + + # Add the cached block to the evictor + # (This keeps the cached block around so it can be reused) + self.evictor.add(block_id, block.content_hash, block.num_tokens_total, + self._block_tracker[block_id].last_accessed) + + # Stop tracking the block + self._untrack_block_id(block_id) + + block.block_id = None + + def _decr_refcount_hashless_block(self, block: Block) -> None: + block_id = block.block_id + assert block_id is not None + + # We may have a fork case where block is shared, + # in which case, we cannot remove it from tracking + refcount = self._refcounter.get(block_id) + if refcount == 1: + self._untrack_block_id(block_id) + + # Decrement refcount of the block_id, but do not free the block object + # itself (will be handled by the caller) + self._hashless_allocator.free(block, keep_block_object=True) + + def _allocate_block_id(self) -> BlockId: + """First tries to allocate a block id from the hashless allocator, + and if there are no blocks, then tries to evict an unused cached block. + """ + hashless_block_id = self._maybe_allocate_hashless_block_id() + if hashless_block_id is not None: + return hashless_block_id + + evicted_block_id = self._maybe_allocate_evicted_block_id() + if evicted_block_id is not None: + return evicted_block_id + + # No block available in hashless allocator, nor in unused cache blocks. 
+ raise BlockAllocator.NoFreeBlocksError() + + def _maybe_allocate_hashless_block_id(self) -> Optional[BlockId]: + try: + # Allocate mutable block and extract its block_id + block = self._hashless_allocator.allocate_mutable_block( + prev_block=None) + block_id = block.block_id + self._block_pool.free_block(block) + + self._track_block_id(block_id, computed=False) + return block_id + except BlockAllocator.NoFreeBlocksError: + return None + + def _maybe_allocate_evicted_block_id(self) -> Optional[BlockId]: + if self.evictor.num_blocks == 0: + return None + + # Here we get an evicted block, which is only added + # into evictor if its ref counter is 0 + # and since its content would be changed, we need + # to remove it from _cached_blocks's tracking list + block_id, content_hash_to_evict = self.evictor.evict() + + # Sanity checks + assert content_hash_to_evict in self._cached_blocks + _block_id = self._cached_blocks[content_hash_to_evict] + assert self._refcounter.get(_block_id) == 0 + assert _block_id == block_id + + self._cached_blocks.pop(content_hash_to_evict) + + self._refcounter.incr(block_id) + self._track_block_id(block_id, computed=False) + + return block_id + + def _free_block_id(self, block: Block) -> None: + """Decrements the refcount of the block. The block may be in two + possible states: (1) immutable/cached or (2) mutable/hashless. + In the first case, the refcount is decremented directly and the block + may be possibly added to the evictor. In other case, hashless + allocator free(..) with keep_block_object=True is called to only free + the block id (since the block object may be reused by the caller) + """ + block_id = block.block_id + assert block_id is not None, "Freeing unallocated block is undefined" + + if block.content_hash is not None: + # Immutable: This type of block is always cached, and we want to + # keep it in the evictor for future reuse + self._decr_refcount_cached_block(block) + else: + # Mutable: This type of block is not cached, so we release it + # directly to the hashless allocator + self._decr_refcount_hashless_block(block) + + assert block.block_id is None + + def free(self, block: Block, keep_block_object: bool = False) -> None: + """Release the block (look at free_block_id(..) docs) + """ + # Release the physical block index + self._free_block_id(block) + + # Release the block object to the pool + if not keep_block_object: + self._block_pool.free_block(block) + + def fork(self, last_block: Block) -> List[Block]: + """Creates a new sequence of blocks that shares the same underlying + memory as the original sequence. + + Args: + last_block (Block): The last block in the original sequence. + + Returns: + List[Block]: The new sequence of blocks that shares the same memory + as the original sequence. 
+ """ + source_blocks = get_all_blocks_recursively(last_block) + + forked_blocks: List[Block] = [] + prev_block = None + for block in source_blocks: + block_id = block.block_id + assert block_id is not None + + refcount = self._refcounter.incr(block_id) + assert refcount != 1, "can't fork free'd block_id = {}".format( + block_id) + + forked_block = self._block_pool.init_block( + prev_block=prev_block, + token_ids=block.token_ids, + block_size=self._block_size, + physical_block_id=block_id) + + forked_blocks.append(forked_block) + prev_block = forked_blocks[-1] + + return forked_blocks + + def get_num_free_blocks(self, device: Optional[Device] = None) -> int: + assert device is None + # The number of free blocks is the number of hashless free blocks + # plus the number of blocks evictor could free from its list. + return self._hashless_allocator.get_num_free_blocks( + ) + self.evictor.num_blocks + + def get_num_total_blocks(self) -> int: + return self._hashless_allocator.get_num_total_blocks() + + def get_physical_block_id(self, absolute_id: int) -> int: + """Returns the zero-offset block id on certain block allocator + given the absolute block id. + + Args: + absolute_id (int): The absolute block id for the block + in whole allocator. + + Returns: + int: The rzero-offset block id on certain device. + """ + return sorted(self.all_block_ids).index(absolute_id) + + @property + def all_block_ids(self) -> FrozenSet[int]: + return self._hashless_allocator.all_block_ids + + def get_prefix_cache_hit_rate(self) -> float: + return self.metric_data.get_hit_rate() + + def is_block_cached(self, block: Block) -> bool: + assert block.content_hash is not None + return block.content_hash in self._cached_blocks + + def promote_to_immutable_block(self, block: Block) -> BlockId: + """Once a mutable block is full, it can be promoted to an immutable + block. This means that its content can be referenced by future blocks + having the same prefix. + + Note that if we already have a cached block with the same content, we + will replace the newly-promoted block's mapping with the existing cached + block id. + + Args: + block: The mutable block to be promoted. + + Returns: + BlockId: Either the original block index, or the block index of + the previously cached block matching the same content. + """ + # Ensure block can be promoted + assert block.content_hash is not None + assert block.block_id is not None + assert self._refcounter.get(block.block_id) > 0 + + if block.content_hash not in self._cached_blocks: + # No cached content hash => Set this block as cached. + # Note that this block cannot be marked as computed yet + # because other sequences in the same batch cannot reuse + # this block. + self._cached_blocks[block.content_hash] = block.block_id + # Mark this block as touched so that it can be marked as + # computed after the entire batch of sequences are scheduled. + self._touched_blocks.add(block.block_id) + return block.block_id + + # Reuse the cached content hash + self._decr_refcount_hashless_block(block) + block.block_id = self._cached_blocks[block.content_hash] + + # Increment refcount of the cached block and (possibly) restore + # it from the evictor. + # Note that in this case, the block is marked as computed + self._incr_refcount_cached_block(block) + + return block.block_id + + def cow_block_if_not_appendable(self, block: Block) -> BlockId: + """Performs a copy-on-write operation on the given block if it is not + appendable. + + Args: + block (Block): The block to check for copy-on-write. 
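Aside (not part of the patch): a sketch of how freed cached blocks stay reusable. Freeing an immutable block moves its physical id into the evictor, so it still counts towards `get_num_free_blocks()` while remaining reachable by content hash. This assumes the `evictor_v2` behavior used above; token ids are hypothetical.

```python
# Sketch (hypothetical token ids): a freed cached block sits in the evictor,
# is counted as free, and is reused when the same content reappears.
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator

alloc = PrefixCachingBlockAllocator(num_blocks=4, block_size=4)
assert alloc.get_num_free_blocks() == 4

block = alloc.allocate_immutable_block(prev_block=None,
                                       token_ids=[1, 2, 3, 4])
cached_id = block.block_id
assert alloc.get_num_free_blocks() == 3

alloc.free(block)
# Refcount dropped to zero: the id is in the evictor, hence free again.
assert alloc.get_num_free_blocks() == 4

again = alloc.allocate_immutable_block(prev_block=None,
                                       token_ids=[1, 2, 3, 4])
assert again.block_id == cached_id   # restored from the evictor, not evicted
```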
+ + Returns: + BlockId: The block index of the new block if a copy-on-write + operation was performed, or the original block index if + no copy-on-write was necessary. + """ + src_block_id = block.block_id + assert src_block_id is not None + + if self._cow_tracker.is_appendable(block): + return src_block_id + + self._free_block_id(block) + trg_block_id = self._allocate_block_id() + + self._cow_tracker.record_cow(src_block_id, trg_block_id) + + return trg_block_id + + def clear_copy_on_writes(self) -> List[Tuple[BlockId, BlockId]]: + """Returns the copy-on-write source->destination mapping and clears it. + + Returns: + List[Tuple[BlockId, BlockId]]: A list mapping source + block indices to destination block indices. + """ + return self._cow_tracker.clear_cows() + + def mark_blocks_as_accessed(self, block_ids: List[int], + now: float) -> None: + """Mark blocks as accessed, used in prefix caching. + + If the block is added into evictor, we need to update corresponding + info in evictor's metadata. + """ + + for block_id in block_ids: + if self._block_tracker[block_id].active: + self._block_tracker[block_id].last_accessed = now + elif block_id in self.evictor: + self.evictor.update(block_id, now) + else: + raise ValueError( + "Mark block as accessed which is not belonged to GPU") + + def mark_blocks_as_computed(self, block_ids: List[int]) -> None: + # Mark all touched blocks as computed. + for block_id in self._touched_blocks: + self._block_tracker[block_id].computed = True + self._touched_blocks.clear() + + def _track_block_id(self, block_id: Optional[BlockId], + computed: bool) -> None: + assert block_id is not None + self._block_tracker[block_id].enable() + self._block_tracker[block_id].computed = computed + + def _untrack_block_id(self, block_id: Optional[BlockId]) -> None: + assert block_id is not None + self._block_tracker[block_id].disable() + + def block_is_computed(self, block_id: int) -> bool: + if self._block_tracker[block_id].active: + return self._block_tracker[block_id].computed + else: + return block_id in self.evictor + + def get_computed_block_ids(self, + prev_computed_block_ids: List[int], + block_ids: List[int], + skip_last_block_id: bool = True) -> List[int]: + prev_prefix_size = len(prev_computed_block_ids) + cur_size = len(block_ids) + if skip_last_block_id: + cur_size -= 1 + + # Sanity checks + assert cur_size >= 0 + assert prev_prefix_size <= cur_size + + ret = prev_computed_block_ids + for i in range(prev_prefix_size, cur_size): + block_id = block_ids[i] + if self.block_is_computed(block_id): + ret.append(block_id) + return ret + + def get_common_computed_block_ids( + self, computed_seq_block_ids: List[List[int]]) -> List[int]: + """Return the block ids that are common for a given sequence group. + + Only those blocks that are immutable and already be marked + compyted would be taken consideration. + """ + + # NOTE We exclude the last block to avoid the case where the entire + # prompt is cached. This would cause erroneous behavior in model + # runner. + + # It returns a list of int although type annotation says list of string. + if len(computed_seq_block_ids) == 1: + return computed_seq_block_ids[0] + + return commonprefix([ + ids for ids in computed_seq_block_ids # type: ignore + if ids + ]) + + def get_num_full_blocks_touched(self, blocks: List[Block]) -> int: + """Returns the number of full blocks that will be touched by + swapping in/out. + + Args: + blocks: List of blocks to be swapped. 
+ Returns: + int: the number of full blocks that will be touched by + swapping in/out the given blocks. Non full blocks are ignored + when deciding the number of blocks to touch. + """ + num_touched_blocks: int = 0 + for block in blocks: + # If the block has a match in the cache and the cached + # block is not referenced, then we still count it as a + # touched block + if block.is_full and (not self.is_block_cached(block) or \ + (block.content_hash is not None and \ + self._cached_blocks[block.content_hash] in \ + self.evictor)): + num_touched_blocks += 1 + return num_touched_blocks + + def swap_out(self, blocks: List[Block]) -> None: + """Execute the swap out actions. Basically just free the + given blocks. + + Args: + blocks: List of blocks to be swapped out. + """ + for block in blocks: + self._free_block_id(block) + + def swap_in(self, blocks: List[Block]) -> None: + """Execute the swap in actions. Change the block id from + old allocator to current allocator for each block to finish + the block table update. + + Args: + blocks: List of blocks to be swapped in. + """ + for block in blocks: + # Here we allocate either immutable or mutable block and then + # extract its block_id. Note that the block object is released + # and the block_id is assigned to "block" to allow reusing the + # existing "block" object + if block.is_full: + tmp_block = self.allocate_immutable_block( + prev_block=block.prev_block, token_ids=block.token_ids) + else: + tmp_block = self.allocate_mutable_block( + prev_block=block.prev_block) + tmp_block.append_token_ids(block.token_ids) + + block_id = tmp_block.block_id + self._block_pool.free_block(tmp_block) + + block.block_id = block_id # Assign block_id + + +class PrefixCachingBlock(Block): + """A block implementation that supports prefix caching. + + The PrefixCachingBlock class represents a block of token IDs with prefix + caching capabilities. It wraps a NaiveBlock internally and provides + additional functionality for content hashing and promoting immutable blocks + with the prefix caching allocator. + + Args: + prev_block (Optional[PrefixCachingBlock]): The previous block in the + sequence. + token_ids (List[int]): The initial token IDs to be stored in the block. + block_size (int): The maximum number of token IDs that can be stored in + the block. + allocator (BlockAllocator): The prefix + caching block allocator associated with this block. + block_id (Optional[int], optional): The physical block index + of this block. Defaults to None. + """ + + def __init__( + self, + prev_block: Optional[Block], + token_ids: List[int], + block_size: int, + allocator: BlockAllocator, + block_id: Optional[int] = None, + computed: bool = False, + ): + assert isinstance(allocator, PrefixCachingBlockAllocator), ( + "Currently this class is only tested with " + "PrefixCachingBlockAllocator. 
Got instead allocator = {}".format( + allocator)) + assert_prefix_caching_block_or_none(prev_block) + + self._prev_block = prev_block + self._cached_content_hash: Optional[int] = None + self._cached_num_tokens_total: int = 0 + self._allocator = allocator + self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME + self._computed = computed + + # On the first time, we create the block object, and next we only + # reinitialize it + if hasattr(self, "_block"): + self._block.__init__( # type: ignore[has-type] + prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + else: + self._block = NaiveBlock(prev_block=prev_block, + token_ids=token_ids, + block_size=block_size, + block_id=block_id, + allocator=self._allocator) + + self._update_num_tokens_total() + + def _update_num_tokens_total(self): + """Incrementally computes the number of tokens that there is + till the current block (included) + """ + res = 0 + + # Add all previous blocks + if self._prev_block is not None: + res += self._prev_block.num_tokens_total + + # Add current block + res += len(self.token_ids) + + self._cached_num_tokens_total = res + + @property + def computed(self) -> bool: + return self._computed + + @computed.setter + def computed(self, value) -> None: + self._computed = value + + @property + def last_accessed(self) -> float: + return self._last_accessed + + @last_accessed.setter + def last_accessed(self, last_accessed_ts: float): + self._last_accessed = last_accessed_ts + + def append_token_ids(self, token_ids: List[int]) -> None: + """Appends the given token IDs to the block and registers the block as + immutable if the block becomes full. + + Args: + token_ids (List[int]): The token IDs to be appended to the block. + """ + # Ensure this is mutable block (not promoted) + assert self.content_hash is None + assert not self.computed + + if len(token_ids) == 0: + return + + # Ensure there are input tokens + assert token_ids, "Got token_ids = {}".format(token_ids) + + # Naive block handles CoW. + self._block.append_token_ids(token_ids) + self._update_num_tokens_total() + + # If the content hash is present, then the block can be made immutable. + # Register ourselves with the allocator, potentially replacing the + # physical block index. + if self.content_hash is not None: + self.block_id = self._allocator.promote_to_immutable_block(self) + + @property + def block_id(self) -> Optional[int]: + return self._block.block_id + + @block_id.setter + def block_id(self, value) -> None: + self._block.block_id = value + + @property + def is_full(self) -> bool: + return self._block.is_full + + @property + def num_empty_slots(self) -> int: + return self._block.num_empty_slots + + @property + def num_tokens_total(self) -> int: + return self._cached_num_tokens_total + + @property + def block_size(self) -> int: + return self._block.block_size + + @property + def token_ids(self) -> List[int]: + return self._block.token_ids + + @property + def prev_block(self) -> Optional[Block]: + return self._prev_block + + @property + def content_hash(self) -> Optional[int]: + """Return the content-based hash of the current block, or None if it is + not yet defined. + + For the content-based hash to be defined, the current block must be + full. + """ + # If the hash is already computed, return it. + if self._cached_content_hash is not None: + return self._cached_content_hash + + # We cannot compute a hash for the current block because it is not full. 
+ if not self.is_full: + return None + + is_first_block = self._prev_block is None + prev_block_hash = ( + None if is_first_block else + self._prev_block.content_hash # type: ignore + ) + + # Previous block exists but does not yet have a hash. + # Return no hash in this case. + if prev_block_hash is None and not is_first_block: + return None + + self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block, + prev_block_hash, + cur_block_token_ids=self.token_ids) + return self._cached_content_hash + + @staticmethod + def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], + cur_block_token_ids: List[int]) -> int: + """Computes a hash value corresponding to the contents of a block and + the contents of the preceding block(s). The hash value is used for + prefix caching. + + NOTE: Content-based hashing does not yet support LoRA. + + Parameters: + - is_first_block (bool): A flag indicating if the block is the first in + the sequence. + - prev_block_hash (Optional[int]): The hash of the previous block. None + if this is the first block. + - cur_block_token_ids (List[int]): A list of token ids in the current + block. The current block is assumed to be full. + + Returns: + - int: The computed hash value for the block. + """ + assert (prev_block_hash is None) == is_first_block + return hash((is_first_block, prev_block_hash, *cur_block_token_ids)) + + +class ComputedBlocksTracker: + """Handles caching of per-sequence computed block ids. + When a sequence appears for the first time, it traverses all of the + blocks and detects the prefix of blocks that is computed. On the + subsequent times, it only traverses the new blocks that were added + and updates the already recorded prefix of blocks with the newly + computed blocks. + + To avoid redundant traversals, the algorithm also detects when there + is a "gap" in the computed prefix. For example, if we have blocks = + [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then + we won't try to add more computed blocks to [1,2,3] in this sequence + iteration, and will add more computed blocks only after the sequence is + freed and reused again. + + Note that currently, for a given sequence, we also skip the last + block id for caching purposes, to avoid caching of a full sequence + """ + + def __init__(self, allocator): + self._allocator = allocator + self._cached_computed_seq_blocks: Dict[int, Tuple[List[int], + bool]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._cached_computed_seq_blocks + self._cached_computed_seq_blocks[seq_id] = ([], False) + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._cached_computed_seq_blocks + del self._cached_computed_seq_blocks[seq_id] + + def get_cached_computed_blocks_and_update( + self, seq_id: int, block_ids: List[int]) -> List[int]: + """ Look at the class documentation for details + """ + # Ensure seq_id is already tracked + assert seq_id in self._cached_computed_seq_blocks + + # Get cached data (may be empty on the first time) + prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[ + seq_id] + + if has_gap: + # When gap is detected, we do not add more computed blocks at this + # sequence iteration + return prev_computed_block_ids + + # We do not consider the last block id for caching purposes. 
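As a quick aside on the chained hashing scheme implemented by hash_block_tokens above: the hash of a full block folds in the previous block's hash, so two blocks hash equally only when their entire token prefixes match. A minimal, self-contained sketch (toy_hash_block_tokens and the token values are illustrative):

    def toy_hash_block_tokens(is_first_block, prev_block_hash, cur_block_token_ids):
        # Same formula as PrefixCachingBlock.hash_block_tokens above.
        assert (prev_block_hash is None) == is_first_block
        return hash((is_first_block, prev_block_hash, *cur_block_token_ids))

    h0 = toy_hash_block_tokens(True, None, [1, 2, 3, 4])
    h1a = toy_hash_block_tokens(False, h0, [5, 6, 7, 8])
    h1b = toy_hash_block_tokens(False, h0, [5, 6, 7, 9])
    # Same first block, different second block -> different chained hashes,
    # so the allocator never shares blocks across diverging prefixes.
    assert h1a != h1b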
+ num_cur_blocks = len(block_ids) - 1 + assert num_cur_blocks >= 0 + + if len(prev_computed_block_ids) >= num_cur_blocks: + # Cache HIT + assert len(prev_computed_block_ids) == num_cur_blocks + return prev_computed_block_ids + + # If here, then we may possibly add more computed blocks. As a result, + # traverse the additional blocks after prev_computed_block_ids to + # detect more computed blocks and add them. + + # Incremental init for seq_id => Look only at the new blocks + computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501 + prev_computed_block_ids, + block_ids, + skip_last_block_id= + True, # We skip last block id to avoid caching of full seq + ) + + # Detect if there is a "gap" + has_gap = len(computed_block_ids) < num_cur_blocks + + # Record + self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, + has_gap) + + return computed_block_ids + + +class LastAccessBlocksTracker: + """Manages the last access time of the tracked sequences, in order to allow + an efficient update of allocator's block last access times + """ + + def __init__(self, allocator): + self._allocator = allocator + self._seq_last_access: Dict[int, Optional[float]] = {} + + def add_seq(self, seq_id: int) -> None: + """Start tracking seq_id + """ + assert seq_id not in self._seq_last_access + self._seq_last_access[seq_id] = None + + def remove_seq(self, seq_id: int) -> None: + """Stop tracking seq_id + """ + assert seq_id in self._seq_last_access + del self._seq_last_access[seq_id] + + def update_last_access(self, seq_id: int, time: float) -> None: + assert seq_id in self._seq_last_access + self._seq_last_access[seq_id] = time + + def update_seq_blocks_last_access(self, seq_id: int, + block_ids: List[int]) -> None: + assert seq_id in self._seq_last_access + + ts = self._seq_last_access[seq_id] + + if ts is None: + # No last access was recorded, no need to update. + return + + self._allocator.mark_blocks_as_accessed(block_ids, ts) + + +def assert_prefix_caching_block_or_none(block: Optional[Block]): + if block is None: + return + assert isinstance(block, + PrefixCachingBlock), "Got block = {}".format(block) diff --git a/vllm/core/block/utils.py b/vllm/core/block/utils.py new file mode 100644 index 0000000..2883943 --- /dev/null +++ b/vllm/core/block/utils.py @@ -0,0 +1,48 @@ +"""Block manager utils.""" +from vllm.sequence import SequenceGroup +from vllm.utils import (STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + STR_NOT_IMPL_ENC_DEC_SWA) + + +def _get_block_mgr_sliding_window_attr(block_mgr): + ''' + BlockManagerV1 and BlockManagerV2 have slightly different + members related to sliding window attention (SWA). This + function extracts the appropriate member to use for determining + whether SWA is enabled. + + Arguments: + + * block_mgr: BlockManagerV1 or BlockManagerV2 instance + ''' + + if hasattr(block_mgr, 'block_sliding_window'): + return block_mgr.block_sliding_window + if hasattr(block_mgr, 'max_block_sliding_window'): + return block_mgr.max_block_sliding_window + + raise AttributeError("Block manager instance has neither " + \ + "block_sliding_window nor " + \ + "max_block_sliding_window attributes.") + + +def check_no_caching_or_swa_for_blockmgr_encdec( + block_mgr, seq_group: SequenceGroup) -> None: + ''' + Enforce that prefix caching & sliding-window attention (SWA) + are currently unsupported *specifically* for encoder/decoder models. + + Raises NotImplementedError if unsupported scenario is detected. 
+ + Arguments: + + * block_mgr: BlockSpaceManager instance + * seq_group: SequenceGroup passed to block_mgr + ''' + + if seq_group.is_encoder_decoder(): + if _get_block_mgr_sliding_window_attr(block_mgr) is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_SWA) + + if block_mgr.enable_caching: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py new file mode 100644 index 0000000..8bc0ce2 --- /dev/null +++ b/vllm/core/block_manager_v1.py @@ -0,0 +1,743 @@ +"""A block manager that manages token blocks.""" +import math +from abc import ABC, abstractmethod +from itertools import count, takewhile +from os.path import commonprefix +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple + +from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.core.block.common import CacheMetricData +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +logger = init_logger(__name__) + + +class BlockAllocatorBase(ABC): + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + @abstractmethod + def __init__(self, + device: Device, + block_size: int, + num_blocks: int, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU): + pass + + @abstractmethod + def allocate(self, + block_hash: Optional[int] = None, + num_hashed_tokens: int = 0) -> PhysicalTokenBlock: + pass + + @abstractmethod + def free(self, block: PhysicalTokenBlock) -> None: + pass + + @abstractmethod + def get_num_free_blocks(self) -> int: + pass + + @abstractmethod + def get_num_total_blocks(self) -> int: + pass + + @abstractmethod + def contains_block(self, block_hash: int) -> bool: + pass + + @abstractmethod + def update_hash(self, block_hash: int, block: PhysicalTokenBlock): + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass + + +class CachedBlockAllocator(BlockAllocatorBase): + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. 
+ """ + + def __init__(self, + device: Device, + block_size: int, + num_blocks: int, + eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None: + self.device = device + self.block_size = block_size + self.num_blocks = num_blocks + + self.current_num_blocks = 0 + self.cached_blocks: Dict[int, PhysicalTokenBlock] = {} + + self.evictor: Evictor = make_evictor(eviction_policy) + + self.default_hash_ctr = count() + + self.cache_metric_data = CacheMetricData() + + def allocate_block(self, block_hash: int, + num_hashed_tokens: int) -> PhysicalTokenBlock: + if self.current_num_blocks == self.num_blocks: + block = self.evictor.evict() + block.block_hash = block_hash + block.num_hashed_tokens = num_hashed_tokens + return block + block = PhysicalTokenBlock(device=self.device, + block_number=self.current_num_blocks, + block_size=self.block_size, + block_hash=block_hash, + num_hashed_tokens=num_hashed_tokens) + self.current_num_blocks += 1 + return block + + def allocate(self, + block_hash: Optional[int] = None, + num_hashed_tokens: int = 0) -> PhysicalTokenBlock: + if block_hash is None: + block_hash = next(self.default_hash_ctr) + + if block_hash in self.evictor: + assert block_hash not in self.cached_blocks + block = self.evictor.remove(block_hash) + assert block.ref_count == 0 + self.cached_blocks[block_hash] = block + + if block_hash in self.cached_blocks: + self.cache_metric_data.query(hit=True) + else: + self.cache_metric_data.query(hit=False) + self.cached_blocks[block_hash] = self.allocate_block( + block_hash, num_hashed_tokens) + block = self.cached_blocks[block_hash] + assert block.block_hash == block_hash + block.ref_count += 1 + return block + + def free(self, block: PhysicalTokenBlock) -> None: + if block.ref_count == 0: + raise ValueError(f"Double free! {block} is already freed.") + block.ref_count -= 1 + if block.ref_count == 0: + assert block.block_hash not in self.evictor + self.evictor.add(block) + + # Remove the block from the cached_blocks + del self.cached_blocks[block.block_hash] + + def get_num_free_blocks(self) -> int: + return (self.num_blocks - self.current_num_blocks + + self.evictor.num_blocks) + + def get_num_total_blocks(self) -> int: + return self.num_blocks + + def contains_block(self, block_hash: int) -> bool: + return block_hash in self.cached_blocks or block_hash in self.evictor + + def update_hash(self, block_hash: int, block: PhysicalTokenBlock): + # Update the hash of block and the cached_blocks dictionary. + assert not self.contains_block(block_hash) + old_hash = block.block_hash + block.block_hash = block_hash + del self.cached_blocks[old_hash] + self.cached_blocks[block_hash] = block + + def get_prefix_cache_hit_rate(self) -> float: + return self.cache_metric_data.get_hit_rate() + + +class UncachedBlockAllocator(BlockAllocatorBase): + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__( + self, + device: Device, + block_size: int, + num_blocks: int, + ) -> None: + self.device = device + self.block_size = block_size + self.num_blocks = num_blocks + + # Initialize the free blocks. 
+ self.free_blocks: List[PhysicalTokenBlock] = [] + for i in range(num_blocks): + block = PhysicalTokenBlock(device=device, + block_number=i, + block_size=block_size, + block_hash=-1, + num_hashed_tokens=0) + self.free_blocks.append(block) + + def allocate(self, + block_hash: Optional[int] = None, + num_hashed_tokens: int = 0) -> PhysicalTokenBlock: + if not self.free_blocks: + raise ValueError("Out of memory! No free blocks are available.") + block = self.free_blocks.pop() + block.ref_count = 1 + return block + + def free(self, block: PhysicalTokenBlock) -> None: + if block.ref_count == 0: + raise ValueError(f"Double free! {block} is already freed.") + block.ref_count -= 1 + if block.ref_count == 0: + self.free_blocks.append(block) + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + def get_num_total_blocks(self) -> int: + return self.num_blocks + + def contains_block(self, block_hash: int) -> bool: + raise NotImplementedError( + "Invalid codepath for uncached block allocator.") + + def update_hash(self, block_hash: int, block: PhysicalTokenBlock): + raise NotImplementedError( + "Invalid codepath for uncached block allocator.") + + def get_prefix_cache_hit_rate(self) -> float: + return -1 + + +class BlockSpaceManagerV1(BlockSpaceManager): + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + if enable_caching and sliding_window is not None: + raise NotImplementedError( + "Sliding window is not allowed with prefix caching enabled!") + + self.block_sliding_window = None + if sliding_window is not None: + # Round up to nearest block size to regularize sliding window + # allocation sizes. + self.block_sliding_window = math.ceil(sliding_window / block_size) + + self.watermark = watermark + assert watermark >= 0.0 + + self.enable_caching = enable_caching + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + if self.enable_caching: + logger.info("Automatic prefix caching is enabled.") + self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) + else: + self.gpu_allocator = UncachedBlockAllocator( + Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator = UncachedBlockAllocator( + Device.CPU, block_size, num_cpu_blocks) + # Mapping: seq_id -> BlockTable. + self.block_tables: Dict[int, BlockTable] = {} + + # Mapping: req_id -> BlockTable + # Note that each SequenceGroup has a unique + # request ID + self.cross_block_tables: Dict[str, BlockTable] = {} + + def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: + return 0 if seq is None else seq.n_blocks + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. 
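The sliding-window rounding in BlockSpaceManagerV1.__init__ above is just a ceiling division of the window length by the block size; with illustrative numbers:

    import math

    block_size = 16
    sliding_window = 1000                     # attention window in tokens (example value)
    block_sliding_window = math.ceil(sliding_window / block_size)
    assert block_sliding_window == 63         # 62.5 rounded up to whole blocks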
+ + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + self_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_seqs(status=SequenceStatus.WAITING)[0]) + cross_num_required_blocks = self._get_seq_num_required_blocks( + seq_group.get_encoder_seq()) + num_required_blocks = self_num_required_blocks + \ + cross_num_required_blocks + + if self.block_sliding_window is not None: + + num_required_blocks = min(num_required_blocks, + self.block_sliding_window) + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _allocate_sequence(self, \ + seq: Optional[Sequence], \ + ref_count: int, \ + is_encoder_decoder: bool = True) -> BlockTable: + # Allocate new physical token blocks that will store the prompt tokens. + num_prompt_blocks = self._get_seq_num_required_blocks(seq) + + block_table: BlockTable = BlockTable() + assert seq is not None + for logical_idx in range(num_prompt_blocks): + if (self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window): + block = block_table[logical_idx % self.block_sliding_window] + # Set the reference counts of the token blocks. + block.ref_count = ref_count + elif not is_encoder_decoder and self.enable_caching: + block = self.gpu_allocator.allocate( + seq.hash_of_block(logical_idx), + seq.num_hashed_tokens_of_block(logical_idx)) + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + block.ref_count = ref_count + block_table.append(block) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + is_encoder_decoder = seq_group.is_encoder_decoder() + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + # Allocate decoder sequences + # + # NOTE: Here we assume that all sequences in the group have the same + # decoder prompt. + wait_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + seq = wait_seqs[0] + block_table: BlockTable = \ + self._allocate_sequence(seq, + seq_group.num_seqs(), + is_encoder_decoder) + + # Assign the self-attention block tables for each sequence. + if len(wait_seqs) == 1: + self.block_tables[seq.seq_id] = block_table + else: + for seq in wait_seqs: + self.block_tables[seq.seq_id] = block_table.copy() + + # Allocate encoder sequence + if is_encoder_decoder: + # A SequenceGroup has only a single encoder sequence (at most), + # thus allocate with a ref count of 1 + block_table = self._allocate_sequence(seq_group.get_encoder_seq(), + 1, is_encoder_decoder) + # Assign the cross-attention block table for the SequenceGroup. + self.cross_block_tables[seq_group.request_id] = block_table + + def can_append_slots(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> bool: + assert (num_lookahead_slots == 0 + ), "lookahead allocation not supported in BlockSpaceManagerV1" + + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. 
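The watermark test in can_allocate above boils down to a three-way decision; a minimal sketch of that logic with made-up numbers (the real method also folds in encoder/decoder and sliding-window handling):

    def can_allocate_status(num_required, num_free, num_total, watermark_blocks):
        if num_total - num_required < watermark_blocks:
            return "NEVER"   # the prompt can never fit, even with an empty cache
        if num_free - num_required >= watermark_blocks:
            return "OK"      # fits now without eating into the watermark
        return "LATER"       # retry once other sequences finish or are preempted

    assert can_allocate_status(90, 100, 100, watermark_blocks=1) == "OK"
    assert can_allocate_status(90, 50, 100, watermark_blocks=1) == "LATER"
    assert can_allocate_status(100, 100, 100, watermark_blocks=1) == "NEVER"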
+ num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + return num_seqs <= num_free_gpu_blocks + + def _promote_last_block( + self, + seq: Sequence, + last_block: PhysicalTokenBlock, + ) -> PhysicalTokenBlock: + assert self.enable_caching + + # Compute a new hash for the block so that it can be shared by other + # Sequences + new_hash = seq.hash_of_block(seq.n_blocks - 1) + + # if new_hash is already in the cached table, then free last_block + # and return the cached version + if self.gpu_allocator.contains_block(new_hash): + self.gpu_allocator.free(last_block) + return self.gpu_allocator.allocate(new_hash) + else: + self.gpu_allocator.update_hash(new_hash, last_block) + return last_block + + def _is_last_block_full( + self, + seq: Sequence, + ) -> bool: + token_ids_len = seq.data.get_len() + return token_ids_len > 0 and token_ids_len % seq.block_size == 0 + + def _maybe_promote_last_block( + self, + seq: Sequence, + last_block: PhysicalTokenBlock, + ) -> PhysicalTokenBlock: + if self._is_last_block_full(seq): + return self._promote_last_block(seq, last_block) + else: + return last_block + + def _allocate_last_physical_block( + self, + seq: Sequence, + ) -> PhysicalTokenBlock: + # Called before a new block is appended. + # This is in charge of allocating a new physical block (to be appended). + + # None if the last block is not full. Otherwise, we set it to the + # content hash. + if not self.enable_caching: + return self.gpu_allocator.allocate() + block_hash: Optional[int] = None + n_blocks = seq.n_blocks + if (self._is_last_block_full(seq)): + block_hash = seq.hash_of_block(n_blocks - 1) + num_hashed_tokens = seq.num_hashed_tokens_of_block(n_blocks - 1) + + # num_hashed_tokens is used to compute future hashes + # (e.g. in the hashing function, it is used to ask the sequence for + # prefix tokens) + new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens) + + # If the block_hash is None, then the block is not full. + # If the block is not full, then we expect it to have a refcount of 1. + if block_hash is None: + assert new_block.ref_count == 1 + return new_block + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int = 0, + ) -> List[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + n_blocks = seq.n_blocks + block_table = self.block_tables[seq.seq_id] + # If we need to allocate a new physical block + if len(block_table) < n_blocks: + # Currently this code only supports adding one physical block + assert len(block_table) == n_blocks - 1 + + if (self.block_sliding_window + and len(block_table) >= self.block_sliding_window): + # reuse a block + block_table.append(block_table[len(block_table) % + self.block_sliding_window]) + else: + # The sequence hash a new logical block. + # Allocate a new physical block. + new_block = self._allocate_last_physical_block(seq) + block_table.append(new_block) + return [] + + # We want to append the token to the last physical block. + last_block = block_table[-1] + assert last_block.device == Device.GPU + if last_block.ref_count == 1: + # Not shared with other sequences. Appendable. + if self.enable_caching: + # If the last block is now complete, we may reuse an old block + # to save memory. + maybe_new_block = self._maybe_promote_last_block( + seq, last_block) + block_table[-1] = maybe_new_block + return [] + else: + # The last block is shared with other sequences. + # Copy on Write: Allocate a new block and copy the tokens. 
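The copy-on-write branch that follows can be pictured with plain Python objects. A toy sketch (ToyBlock and the lists are invented for illustration): a shared last block is swapped for a fresh one, one reference on the old block is dropped, and the (src, dst) pair is reported so the engine can copy the KV contents on the device.

    from dataclasses import dataclass
    from typing import List, Tuple

    @dataclass
    class ToyBlock:
        block_number: int
        ref_count: int = 1

    def append_with_cow(block_table, free_blocks, cow_pairs):
        last = block_table[-1]
        if last.ref_count == 1:
            return                            # exclusively owned: append in place
        new = free_blocks.pop()               # shared: take a fresh block
        last.ref_count -= 1                   # drop our reference to the old one
        block_table[-1] = new
        cow_pairs.append((last.block_number, new.block_number))

    shared = ToyBlock(block_number=7, ref_count=2)   # e.g. after a fork()
    table = [ToyBlock(0), shared]
    cows: List[Tuple[int, int]] = []
    append_with_cow(table, free_blocks=[ToyBlock(9)], cow_pairs=cows)
    assert cows == [(7, 9)] and shared.ref_count == 1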
+ new_block = self._allocate_last_physical_block(seq) + + block_table[-1] = new_block + self.gpu_allocator.free(last_block) + return [(last_block.block_number, new_block.block_number)] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + # NOTE: fork does not allocate a new physical block. + # Thus, it is always safe from OOM. + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. + return + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.copy() + + # When using a sliding window, blocks will be eventually reused. + # In this case the block tables will contain repeated blocks. + # When forking, we must make sure that each block's `ref_count` + # is only incremented by one, so we deduplicate them by wrapping + # them in a set. + for block in set(src_block_table): + block.ref_count += 1 + + def _get_physical_blocks( + self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + + # NOTE: Here, we assume that the physical blocks are only shared by + # the sequences in the same group. + request_id = seq_group.request_id + blocks: Set[PhysicalTokenBlock] = set() + for seq in seq_group.get_seqs(): + if seq.is_finished(): + continue + blocks.update(self.block_tables[seq.seq_id]) + # Cross-attention blocks + if seq_group.is_encoder_decoder(): + blocks.update(self.cross_block_tables[request_id]) + return list(blocks) + + def can_swap_in(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + assert (num_lookahead_slots == 0 + ), "BlockSpaceManagerV1 does not support lookahead allocation" + + blocks = self._get_physical_blocks(seq_group) + num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + if seq_group.is_encoder_decoder(): + num_swapped_seqs += 1 + num_free_blocks = self.gpu_allocator.get_num_free_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate + # at least one free block right after the swap-in. + # NOTE: This should match the logic in can_append_slot(). + num_required_blocks = len(blocks) + num_swapped_seqs + if self.gpu_allocator.get_num_total_blocks() < num_required_blocks: + return AllocStatus.NEVER + elif num_free_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _swap_block_table( + self, block_table: BlockTable, src_allocator: BlockAllocatorBase, + dest_allocator: BlockAllocatorBase, + mapping: Dict[PhysicalTokenBlock, + PhysicalTokenBlock]) -> BlockTable: + new_block_table: BlockTable = BlockTable() + + for from_block in block_table: + if from_block in mapping: + to_block = mapping[from_block] + to_block.ref_count += 1 + else: + to_block = dest_allocator.allocate( + from_block.block_hash, from_block.num_hashed_tokens) + mapping[from_block] = to_block + new_block_table.append(to_block) + # Free the source block swapped in to destination. + src_allocator.free(from_block) + + return new_block_table + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + + request_id = seq_group.request_id + + # CPU block -> GPU block. 
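fork() above deduplicates the copied block table with set() before bumping reference counts; the point is easiest to see with a toy example (ToyBlock is illustrative), since a sliding-window table can contain the same physical block more than once:

    class ToyBlock:
        def __init__(self, number):
            self.block_number = number
            self.ref_count = 1

    block_a, block_b = ToyBlock(0), ToyBlock(1)
    src_block_table = [block_a, block_b, block_a]   # block_a reused by the window

    for block in set(src_block_table):              # dedupe exactly as fork() does
        block.ref_count += 1

    assert block_a.ref_count == 2                   # incremented once, not twice
    assert block_b.ref_count == 2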
+ # dict is efficient in lookup `if cpu_block in mapping` + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.cpu_allocator, self.gpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.cpu_allocator, + self.gpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + request_id = seq_group.request_id + + # GPU block -> CPU block. + # dict is efficient in lookup `if gpu_block in mapping` + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + self.block_tables[seq.seq_id] = \ + self._swap_block_table(self.block_tables[seq.seq_id], + self.gpu_allocator, self.cpu_allocator, + mapping) + + if seq_group.is_encoder_decoder(): + self.cross_block_tables[request_id] = \ + self._swap_block_table(self.cross_block_tables[request_id], + self.gpu_allocator, + self.cpu_allocator, + mapping) + + return [(cpu_block.block_number, gpu_block.block_number) + for cpu_block, gpu_block in mapping.items()] + + def _free_block_table(self, block_table: BlockTable) -> None: + # when using a sliding window, each seq will only use up + # to `self.block_sliding_window` blocks. When freeing + # the block table, we must make sure to not free blocks more + # than once. If no sliding window is used, there is no block + # reuse in the block table, so we must free all blocks. + blocks_to_free = (block_table[-self.block_sliding_window:] + if self.block_sliding_window is not None else + block_table) + for block in set(blocks_to_free): + if block.device == Device.GPU: + self.gpu_allocator.free(block) + else: + self.cpu_allocator.free(block) + + def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return + block_table = self.block_tables[seq.seq_id] + self._free_block_table(block_table) + del self.block_tables[seq.seq_id] + + def free_cross(self, seq_group: SequenceGroup) -> None: + if seq_group.request_id not in self.cross_block_tables: + # Already freed or hasn't ben scheduled yet. 
+ return + block_table = self.cross_block_tables[seq_group.request_id] + self._free_block_table(block_table) + del self.cross_block_tables[seq_group.request_id] + + def reset(self) -> None: + # Free decoder block tables + for block_table in self.block_tables.values(): + self._free_block_table(block_table) + self.block_tables.clear() + # Free cross-attention block tables + for block_table in self.cross_block_tables.values(): + self._free_block_table(block_table) + self.cross_block_tables.clear() + + def get_block_table(self, seq: Sequence) -> List[int]: + return self.block_tables[seq.seq_id].ids() + + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + block_table = self.cross_block_tables[seq_group.request_id] + return [block.block_number for block in block_table] + + def get_num_free_gpu_blocks(self) -> int: + return self.gpu_allocator.get_num_free_blocks() + + def get_num_free_cpu_blocks(self) -> int: + return self.cpu_allocator.get_num_free_blocks() + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + if self.enable_caching: + # Update the last accessed time of all the blocks accessed + # in this step. + block_table = self.block_tables[seq.seq_id] + for block in block_table: + block.last_accessed = access_time + + def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int): + if seq.seq_id not in self.block_tables: + return + + # When chunked prefill is enabled, the computed full blocks + # should be calculated based on the number of computed tokens. + max_computed_tokens = (seq.data.get_num_computed_tokens() + + token_chunk_size) + computed_full_blocks = max_computed_tokens // self.block_size + + block_table = self.block_tables[seq.seq_id] + if computed_full_blocks == 0: + return + for i in reversed(range(computed_full_blocks)): + if block_table[i].computed: + break + block_table[i].computed = True + + def get_all_computed_blocks(self, seq: Sequence) -> List[int]: + if seq.seq_id not in self.block_tables: + return [] + block_table = self.block_tables[seq.seq_id] + # NOTE We exclude the last block to avoid the case where the entire + # prompt is cached. This would cause erroneous behavior in model + # runner. + return [ + b.block_number + for b in takewhile(lambda b: b.computed, block_table[:-1]) + ] + + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + """Return the block ids that are common for a given sequence group. + + Used in prefill (can skip prefill of some blocks). + """ + # Can return non-empty result only with prefix caching enabled. 
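get_common_computed_block_ids reduces to the longest common prefix of the per-sequence lists of computed block ids; a small illustration with invented ids:

    from os.path import commonprefix

    ids_list = [[0, 1, 2, 5], [0, 1, 2], [0, 1, 7]]
    # Only blocks already computed by *every* sequence can be skipped at prefill.
    assert commonprefix([ids for ids in ids_list if ids != []]) == [0, 1]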
+ if not self.enable_caching: + return [] + + ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] + return commonprefix([ids for ids in ids_list if ids != []]) + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + if self.enable_caching: + for seq in seq_group.get_seqs(): + self.compute_full_blocks_in_seq(seq, token_chunk_size) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + if device == Device.GPU: + return self.gpu_allocator.get_prefix_cache_hit_rate() + if device == Device.CPU: + return self.cpu_allocator.get_prefix_cache_hit_rate() + raise ValueError(f"Invalid device: {device}") diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py new file mode 100644 index 0000000..cb047c8 --- /dev/null +++ b/vllm/core/block_manager_v2.py @@ -0,0 +1,505 @@ +"""A block manager that manages token blocks.""" +from typing import Dict, List, Optional +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.core.block.block_table import BlockTable +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator +from vllm.core.block.interfaces import Block +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + LastAccessBlocksTracker) +from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + +SeqId = int +EncoderSeqId = str + + +class BlockSpaceManagerV2(BlockSpaceManager): + """BlockSpaceManager which manages the allocation of KV cache. + + It owns responsibility for allocation, swapping, allocating memory for + autoregressively-generated tokens, and other advanced features such as + prefix caching, forking/copy-on-write, and sliding-window memory allocation. + + This class implements the design described in + https://github.com/vllm-project/vllm/pull/3492. + + Lookahead slots + The block manager has the notion of a "lookahead slot". These are slots + in the KV cache that are allocated for a sequence. Unlike the other + allocated slots, the content of these slots is undefined -- the worker + may use the memory allocations in any way. + + In practice, a worker could use these lookahead slots to run multiple + forward passes for a single scheduler invocation. Each successive + forward pass would write KV activations to the corresponding lookahead + slot. This allows low inter-token latency use-cases, where the overhead + of continuous batching scheduling is amortized over >1 generated tokens. + + Speculative decoding uses lookahead slots to store KV activations of + proposal tokens. + + See https://github.com/vllm-project/vllm/pull/3250 for more information + on lookahead scheduling. + + Args: + block_size (int): The size of each memory block. + num_gpu_blocks (int): The number of memory blocks allocated on GPU. + num_cpu_blocks (int): The number of memory blocks allocated on CPU. + watermark (float, optional): The threshold used for memory swapping. + Defaults to 0.01. + sliding_window (Optional[int], optional): The size of the sliding + window. Defaults to None. + enable_caching (bool, optional): Flag indicating whether caching is + enabled. Defaults to False. 
+ """ + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + enable_caching: bool = False, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.sliding_window = sliding_window + # max_block_sliding_window is the max number of blocks that need to be + # allocated + self.max_block_sliding_window = None + if sliding_window is not None: + # +1 here because // rounds down + num_blocks = sliding_window // block_size + 1 + # +1 here because the last block may not be full, + # and so the sequence stretches one more block at the beginning + # For example, if sliding_window is 3 and block_size is 4, + # we may need 2 blocks when the second block only holds 1 token. + self.max_block_sliding_window = num_blocks + 1 + + self.watermark = watermark + assert watermark >= 0.0 + + self.enable_caching = enable_caching + + self.watermark_blocks = int(watermark * num_gpu_blocks) + + self.block_allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching" if enable_caching else "naive", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=block_size, + ) + + self.block_tables: Dict[SeqId, BlockTable] = {} + self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} + + self._computed_blocks_tracker = ComputedBlocksTracker( + self.block_allocator) + self._last_access_blocks_tracker = LastAccessBlocksTracker( + self.block_allocator) + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = BlockTable.get_num_required_blocks( + seq.get_token_ids(), + block_size=self.block_size, + num_lookahead_slots=num_lookahead_slots, + ) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + num_required_blocks += BlockTable.get_num_required_blocks( + encoder_seq.get_token_ids(), + block_size=self.block_size, + ) + + if self.max_block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.max_block_sliding_window) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + device=Device.GPU) + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def _allocate_sequence(self, seq: Sequence) -> BlockTable: + block_table = BlockTable( + block_size=self.block_size, + block_allocator=self.block_allocator, + max_block_sliding_window=self.max_block_sliding_window, + ) + if seq.get_token_ids(): + # Add blocks to the block table only if the sequence is non empty. 
+ block_table.allocate(seq.get_token_ids()) + + return block_table + + def allocate(self, seq_group: SequenceGroup) -> None: + + # Allocate self-attention block tables for decoder sequences + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert not (set(seq.seq_id for seq in waiting_seqs) + & self.block_tables.keys()), "block table already exists" + + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = waiting_seqs[0] + block_table: BlockTable = self._allocate_sequence(seq) + self.block_tables[seq.seq_id] = block_table + + # Track seq + self._computed_blocks_tracker.add_seq(seq.seq_id) + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Assign the block table for each sequence. + for seq in waiting_seqs[1:]: + self.block_tables[seq.seq_id] = block_table.fork() + + # Track seq + self._computed_blocks_tracker.add_seq(seq.seq_id) + self._last_access_blocks_tracker.add_seq(seq.seq_id) + + # Allocate cross-attention block table for encoder sequence + # + # NOTE: Here we assume that all sequences in the group have the same + # encoder prompt. + request_id = seq_group.request_id + + assert (request_id + not in self.cross_block_tables), \ + "block table already exists" + + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) + + if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + block_table = self._allocate_sequence(encoder_seq) + self.cross_block_tables[request_id] = block_table + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + """Determine if there is enough space in the GPU KV cache to continue + generation of the specified sequence group. + + We use a worst-case heuristic: assume each touched block will require a + new allocation (either via CoW or new block). We can append slots if the + number of touched blocks is less than the number of free blocks. + + "Lookahead slots" are slots that are allocated in addition to the slots + for known tokens. The contents of the lookahead slots are not defined. + This is used by speculative decoding when speculating future tokens. + """ + + num_touched_blocks = 0 + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + block_table = self.block_tables[seq.seq_id] + + num_touched_blocks += ( + block_table.get_num_blocks_touched_by_append_slots( + token_ids=block_table.get_unseen_token_ids( + seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( + Device.GPU) + return num_touched_blocks <= num_free_gpu_blocks + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + + block_table = self.block_tables[seq.seq_id] + + block_table.append_token_ids( + token_ids=block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + num_computed_slots=seq.data.get_num_computed_tokens(), + ) + # Return any new copy-on-writes. + new_cows = self.block_allocator.clear_copy_on_writes() + return new_cows + + def free(self, seq: Sequence) -> None: + seq_id = seq.seq_id + + if seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. 
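The worst-case admission rule described in can_append_slots above is deliberately simple; a sketch with made-up counts:

    def can_append(num_touched_per_seq, num_free_gpu_blocks):
        # Assume every touched block needs a brand-new allocation (CoW or a
        # genuinely new block); append only if the free pool covers them all.
        return sum(num_touched_per_seq) <= num_free_gpu_blocks

    assert can_append([1, 2, 1], num_free_gpu_blocks=4) is True
    assert can_append([3, 3], num_free_gpu_blocks=5) is False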
+ return + + # Update seq block ids with the latest access time + self._last_access_blocks_tracker.update_seq_blocks_last_access( + seq_id, self.block_tables[seq.seq_id].physical_block_ids) + + # Untrack seq + self._last_access_blocks_tracker.remove_seq(seq_id) + self._computed_blocks_tracker.remove_seq(seq_id) + + # Free table/blocks + self.block_tables[seq_id].free() + del self.block_tables[seq_id] + + def free_cross(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.cross_block_tables: + # Already freed or hasn't been scheduled yet. + return + self.cross_block_tables[request_id].free() + del self.cross_block_tables[request_id] + + def get_block_table(self, seq: Sequence) -> List[int]: + block_ids = self.block_tables[seq.seq_id].physical_block_ids + return block_ids # type: ignore + + def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.cross_block_tables + block_ids = self.cross_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids # type: ignore + + def access_all_blocks_in_seq(self, seq: Sequence, now: float): + if self.enable_caching: + # Record the latest access time for the sequence. The actual update + # of the block ids is deferred to the sequence free(..) call, since + # only during freeing of block ids, the blocks are actually added to + # the evictor (which is when the most updated time is required) + # (This avoids expensive calls to mark_blocks_as_accessed(..)) + self._last_access_blocks_tracker.update_last_access( + seq.seq_id, now) + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + # If prefix caching is enabled, mark immutable blocks as computed + # right after they have been scheduled (for prefill). This assumes + # the scheduler is synchronous so blocks are actually computed when + # scheduling the next batch. + self.block_allocator.mark_blocks_as_computed([]) + + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + """Determine which blocks for which we skip prefill. + + With prefix caching we can skip prefill for previously-generated blocks. + Currently, the attention implementation only supports skipping cached + blocks if they are a contiguous prefix of cached blocks. + + This method determines which blocks can be safely skipped for all + sequences in the sequence group. + """ + computed_seq_block_ids = [] + for seq in seqs: + computed_seq_block_ids.append( + self._computed_blocks_tracker. + get_cached_computed_blocks_and_update( + seq.seq_id, + self.block_tables[seq.seq_id].physical_block_ids)) + + # NOTE(sang): This assumes seq_block_ids doesn't contain any None. + return self.block_allocator.get_common_computed_block_ids( + computed_seq_block_ids) # type: ignore + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + if parent_seq.seq_id not in self.block_tables: + # Parent sequence has either been freed or never existed. + return + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.fork() + + # Track child seq + self._computed_blocks_tracker.add_seq(child_seq.seq_id) + self._last_access_blocks_tracker.add_seq(child_seq.seq_id) + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + """Returns the AllocStatus for the given sequence_group + with num_lookahead_slots. 
+ + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for the given sequence group. + """ + return self._can_swap(seq_group, Device.GPU, SequenceStatus.SWAPPED, + num_lookahead_slots) + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from CPU to GPU) generated by + swapping in the given seq_group with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from CPU + to GPU. + """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.CPU, + dst_device=Device.GPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id): + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id) + for cpu_block_id, gpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + """Returns whether we can swap out the given sequence_group + with num_lookahead_slots. + + Args: + seq_group (SequenceGroup): The sequence group to swap in. + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + bool: Whether it's possible to swap out current sequence group. + """ + alloc_status = self._can_swap(seq_group, Device.CPU, + SequenceStatus.RUNNING) + return alloc_status == AllocStatus.OK + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + """Returns the block id mapping (from GPU to CPU) generated by + swapping out the given sequence_group with num_lookahead_slots. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + + Returns: + List[Tuple[int, int]]: The mapping of swapping block from + GPU to CPU. 
+ """ + physical_block_id_mapping = [] + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + blocks = self.block_tables[seq.seq_id].blocks + if len(blocks) == 0: + continue + + seq_swap_mapping = self.block_allocator.swap(blocks=blocks, + src_device=Device.GPU, + dst_device=Device.CPU) + + # Refresh the block ids of the table (post-swap) + self.block_tables[seq.seq_id].update(blocks) + + seq_physical_block_id_mapping = { + self.block_allocator.get_physical_block_id( + Device.GPU, gpu_block_id): + self.block_allocator.get_physical_block_id( + Device.CPU, cpu_block_id) + for gpu_block_id, cpu_block_id in seq_swap_mapping.items() + } + + physical_block_id_mapping.extend( + list(seq_physical_block_id_mapping.items())) + + return physical_block_id_mapping + + def get_num_free_gpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.GPU) + + def get_num_free_cpu_blocks(self) -> int: + return self.block_allocator.get_num_free_blocks(Device.CPU) + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_allocator.get_prefix_cache_hit_rate(device) + + def _can_swap(self, + seq_group: SequenceGroup, + device: Device, + status: SequenceStatus, + num_lookahead_slots: int = 0) -> AllocStatus: + """Returns the AllocStatus for swapping in/out the given sequence_group + on to the 'device'. + + Args: + sequence_group (SequenceGroup): The sequence group to swap in. + device (Device): device to swap the 'seq_group' on. + status (SequenceStatus): The status of sequence which is needed + for action. RUNNING for swap out and SWAPPED for swap in + num_lookahead_slots (int): Number of lookahead slots used in + speculative decoding, default to 0. + + Returns: + AllocStatus: The AllocStatus for swapping in/out the given + sequence_group on to the 'device'. + """ + # First determine the number of blocks that will be touched by this + # swap. Then verify if there are available blocks in the device + # to perform the swap. + num_blocks_touched = 0 + blocks: List[Block] = [] + for seq in seq_group.get_seqs(status=status): + block_table = self.block_tables[seq.seq_id] + if block_table.blocks is not None: + # Compute the number blocks to touch for the tokens to be + # appended. This does NOT include the full blocks that need + # to be touched for the swap. + num_blocks_touched += \ + block_table.get_num_blocks_touched_by_append_slots( + block_table.get_unseen_token_ids(seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots) + blocks.extend(block_table.blocks) + # Compute the number of full blocks to touch and add it to the + # existing count of blocks to touch. + num_blocks_touched += self.block_allocator.get_num_full_blocks_touched( + blocks, device=device) + + watermark_blocks = 0 + if device == Device.GPU: + watermark_blocks = self.watermark_blocks + + if self.block_allocator.get_num_total_blocks( + device) < num_blocks_touched: + return AllocStatus.NEVER + elif self.block_allocator.get_num_free_blocks( + device) - num_blocks_touched >= watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER diff --git a/vllm/core/evictor_v1.py b/vllm/core/evictor_v1.py new file mode 100644 index 0000000..5db5a08 --- /dev/null +++ b/vllm/core/evictor_v1.py @@ -0,0 +1,106 @@ +import enum +from abc import ABC, abstractmethod +from typing import OrderedDict + +from vllm.block import PhysicalTokenBlock + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. 
+ """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed PhysicalTokenBlocks. + """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_hash: int) -> bool: + pass + + @abstractmethod + def evict(self) -> PhysicalTokenBlock: + """Runs the eviction algorithm and returns the evicted block""" + pass + + @abstractmethod + def add(self, block: PhysicalTokenBlock): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def remove(self, block_hash: int) -> PhysicalTokenBlock: + """Simply removes the block with the hash value block_hash from the + evictor. Caller is responsible for making sure that block_hash is + contained in the evictor before calling remove. Should be used to + "bring back" blocks that have been freed but not evicted yet. + """ + pass + + @property + @abstractmethod + def num_blocks(self) -> int: + pass + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + def __init__(self): + self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict() + + def __contains__(self, block_hash: int) -> bool: + return block_hash in self.free_table + + def evict(self) -> PhysicalTokenBlock: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + evicted_block = next(iter(self.free_table.values())) + # The blocks with the lowest timestamps should be placed consecutively + # at the start of OrderedDict. Loop through all these blocks to + # find the one with maximum number of hashed tokens. + for _, block in self.free_table.items(): + if evicted_block.last_accessed < block.last_accessed: + break + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: + evicted_block = block + + self.free_table.pop(evicted_block.block_hash) + + evicted_block.computed = False + return evicted_block + + def add(self, block: PhysicalTokenBlock): + self.free_table[block.block_hash] = block + + def remove(self, block_hash: int) -> PhysicalTokenBlock: + if block_hash not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + block: PhysicalTokenBlock = self.free_table[block_hash] + self.free_table.pop(block_hash) + return block + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py new file mode 100644 index 0000000..0b943e6 --- /dev/null +++ b/vllm/core/evictor_v2.py @@ -0,0 +1,131 @@ +import enum +from abc import ABC, abstractmethod +from typing import OrderedDict, Tuple + + +class EvictionPolicy(enum.Enum): + """Enum for eviction policy used by make_evictor to instantiate the correct + Evictor subclass. + """ + LRU = enum.auto() + + +class Evictor(ABC): + """The Evictor subclasses should be used by the BlockAllocator class to + handle eviction of freed PhysicalTokenBlocks. 
+ """ + + @abstractmethod + def __init__(self): + pass + + @abstractmethod + def __contains__(self, block_id: int) -> bool: + pass + + @abstractmethod + def evict(self) -> Tuple[int, int]: + """Runs the eviction algorithm and returns the evicted block's + content hash along with physical block id along with physical block id + """ + pass + + @abstractmethod + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + """Adds block to the evictor, making it a candidate for eviction""" + pass + + @abstractmethod + def update(self, block_id: int, last_accessed: float): + """Update corresponding block's access time in metadata""" + pass + + @abstractmethod + def remove(self, block_id: int): + """Remove a given block id from the cache.""" + pass + + @property + @abstractmethod + def num_blocks(self) -> int: + pass + + +class BlockMetaData(): + """Data structure for storing key data describe cached block, so that + evitor could use to make its decision which one to choose for eviction + + Here we use physical block id as the dict key, as there maybe several + blocks with the same content hash, but their physical id is unique. + """ + + def __init__(self, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.content_hash = content_hash + self.num_hashed_tokens = num_hashed_tokens + self.last_accessed = last_accessed + + +class LRUEvictor(Evictor): + """Evicts in a least-recently-used order using the last_accessed timestamp + that's recorded in the PhysicalTokenBlock. If there are multiple blocks with + the same last_accessed time, then the one with the largest num_hashed_tokens + will be evicted. If two blocks each have the lowest last_accessed time and + highest num_hashed_tokens value, then one will be chose arbitrarily + """ + + def __init__(self): + self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict() + + def __contains__(self, block_id: int) -> bool: + return block_id in self.free_table + + def evict(self) -> Tuple[int, int]: + if len(self.free_table) == 0: + raise ValueError("No usable cache memory left") + + evicted_block, evicted_block_id = None, None + # The blocks with the lowest timestamps should be placed consecutively + # at the start of OrderedDict. Loop through all these blocks to + # find the one with maximum number of hashed tokens. 
+ for _id, block in self.free_table.items(): + if evicted_block is None: + evicted_block, evicted_block_id = block, _id + continue + if evicted_block.last_accessed < block.last_accessed: + break + if evicted_block.num_hashed_tokens < block.num_hashed_tokens: + evicted_block, evicted_block_id = block, _id + + assert evicted_block is not None + assert evicted_block_id is not None + self.free_table.pop(evicted_block_id) + + return evicted_block_id, evicted_block.content_hash + + def add(self, block_id: int, content_hash: int, num_hashed_tokens: int, + last_accessed: float): + self.free_table[block_id] = BlockMetaData(content_hash, + num_hashed_tokens, + last_accessed) + + def update(self, block_id: int, last_accessed: float): + self.free_table[block_id].last_accessed = last_accessed + + def remove(self, block_id: int): + if block_id not in self.free_table: + raise ValueError( + "Attempting to remove block that's not in the evictor") + self.free_table.pop(block_id) + + @property + def num_blocks(self) -> int: + return len(self.free_table) + + +def make_evictor(eviction_policy: EvictionPolicy) -> Evictor: + if eviction_policy == EvictionPolicy.LRU: + return LRUEvictor() + else: + raise ValueError(f"Unknown cache eviction policy: {eviction_policy}") diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py new file mode 100644 index 0000000..9e1d1b0 --- /dev/null +++ b/vllm/core/interfaces.py @@ -0,0 +1,127 @@ +import enum +from abc import ABC, abstractmethod +from typing import List +from typing import Sequence as GenericSequence +from typing import Tuple + +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. + The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. 
+ """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager(ABC): + + @staticmethod + def get_block_space_manager_class(version: str): + version = version.lower() + + if version == "v1": + from vllm.core.block_manager_v1 import BlockSpaceManagerV1 + return BlockSpaceManagerV1 + + if version == "v2": + from vllm.core.block_manager_v2 import BlockSpaceManagerV2 + return BlockSpaceManagerV2 + + if version == "placeholder": + from vllm.core.placeholder_block_space_manager import ( + PlaceholderBlockSpaceManager) + return PlaceholderBlockSpaceManager + + raise ValueError(f"Unknown version {version=}") + + @abstractmethod + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + pass + + @abstractmethod + def allocate(self, seq_group: SequenceGroup) -> None: + pass + + @abstractmethod + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + pass + + @abstractmethod + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + @abstractmethod + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + pass + + @abstractmethod + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + pass + + @abstractmethod + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + pass + + @abstractmethod + def free(self, seq: Sequence) -> None: + pass + + @abstractmethod + def get_block_table(self, seq: Sequence) -> List[int]: + pass + + @abstractmethod + def get_num_free_gpu_blocks(self) -> int: + pass + + @abstractmethod + def get_num_free_cpu_blocks(self) -> int: + pass + + @abstractmethod + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + @abstractmethod + def get_common_computed_block_ids( + self, seqs: List[Sequence]) -> GenericSequence[int]: + pass + + @abstractmethod + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + @abstractmethod + def get_prefix_cache_hit_rate(self, device: Device) -> float: + """Prefix cache hit rate. -1 means not supported or disabled.""" + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py new file mode 100644 index 0000000..a337392 --- /dev/null +++ b/vllm/core/placeholder_block_space_manager.py @@ -0,0 +1,91 @@ +from typing import List, Tuple + +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.sequence import Sequence, SequenceGroup +from vllm.utils import Device + + +class PlaceholderBlockSpaceManager(BlockSpaceManager): + """A version of BlockSpaceManager for use in environments + where block management is not required. + For example: embedding models or attention-free models like Mamba. + + This class provides the same interface as BlockSpaceManager, but its + methods perform no actions or return simple values like True in specific + actions. It's designed to be used in scenarios where the overhead of + block management is unnecessary, such as in an embedding environment. 
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + pass + + def can_allocate(self, + seq_group: SequenceGroup, + num_lookahead_slots: int = 0) -> AllocStatus: + # Always return OK for dummy purposes + return AllocStatus.OK + + def allocate(self, seq_group: SequenceGroup) -> None: + # No actual allocation logic needed + pass + + def can_append_slots(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> bool: + return True + + def append_slots( + self, + seq: Sequence, + num_lookahead_slots: int, + ) -> List[Tuple[int, int]]: + return [] + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + pass + + def can_swap_in(self, seq_group: SequenceGroup, + num_lookahead_slots: int) -> AllocStatus: + return AllocStatus.OK + + def swap_in(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + return True + + def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: + return None # type: ignore + + def free(self, seq: Sequence) -> None: + # No operation on free + return + + def get_block_table(self, seq: Sequence) -> List[int]: + return None # type: ignore + + def get_num_free_gpu_blocks(self) -> int: + return 1 + + def get_num_free_cpu_blocks(self) -> int: + return 1 + + def access_all_blocks_in_seq( + self, + seq: Sequence, + access_time: float, + ) -> None: + pass + + def get_common_computed_block_ids(self, + seq_group: List[Sequence]) -> List[int]: + return [] + + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): + pass + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return -1 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py new file mode 100644 index 0000000..1f0a121 --- /dev/null +++ b/vllm/core/scheduler.py @@ -0,0 +1,1650 @@ +import enum +import os +import random +import time +from collections import deque +from dataclasses import dataclass, field +from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.interfaces import AllocStatus, BlockSpaceManager +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupMetadata, SequenceGroupMetadataDelta, + SequenceStatus) +from vllm.utils import Device, PyObjectCache + +logger = init_logger(__name__) + +# Test-only. If configured, decode is preempted with +# ARTIFICIAL_PREEMPTION_PROB% probability. +ENABLE_ARTIFICIAL_PREEMPT = bool( + os.getenv("VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT", False)) # noqa +ARTIFICIAL_PREEMPTION_PROB = 0.5 +ARTIFICIAL_PREEMPTION_MAX_CNT = 500 + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +@dataclass +class SchedulingBudget: + """The available slots for scheduling. + + TODO(sang): Right now, the budget is request_id-aware meaning it can ignore + budget update from the same request_id. 
It is because in normal scheduling + path, we update RUNNING num_seqs ahead of time, meaning it could be + updated more than once when scheduling RUNNING requests. Since this won't + happen if we only have chunked prefill scheduling, we can remove this + feature from the API when chunked prefill is enabled by default. + """ + token_budget: int + max_num_seqs: int + _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) + _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + _num_batched_tokens: int = 0 + _num_curr_seqs: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): + assert num_new_tokens != 0 + assert num_new_seqs != 0 + return (self.num_batched_tokens + num_new_tokens <= self.token_budget + and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) + + def remaining_token_budget(self): + return self.token_budget - self.num_batched_tokens + + def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + return + + self._request_ids_num_batched_tokens.add(req_id) + self._num_batched_tokens += num_batched_tokens + + def subtract_num_batched_tokens(self, req_id: str, + num_batched_tokens: int): + if req_id in self._request_ids_num_batched_tokens: + self._request_ids_num_batched_tokens.remove(req_id) + self._num_batched_tokens -= num_batched_tokens + + def add_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + return + + self._request_ids_num_curr_seqs.add(req_id) + self._num_curr_seqs += num_curr_seqs + + def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): + if req_id in self._request_ids_num_curr_seqs: + self._request_ids_num_curr_seqs.remove(req_id) + self._num_curr_seqs -= num_curr_seqs + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_seqs(self): + return self._num_curr_seqs + + +@dataclass +class ScheduledSequenceGroup: + # A sequence group that's scheduled. + seq_group: SequenceGroup + # The total chunk size (number of tokens) to process for next iteration. + # 1 for decoding. Same as prompt tokens for prefill, but if prefill is + # chunked, it can be smaller than that. + token_chunk_size: int + + +@dataclass +class SchedulerOutputs: + """The scheduling decision made from a scheduler.""" + # Scheduled sequence groups. + scheduled_seq_groups: Iterable[ScheduledSequenceGroup] + # Number of prefill groups scheduled. + num_prefill_groups: int + # Total number of batched tokens. + num_batched_tokens: int + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] + # Sequence groups that are going to be ignored. + ignored_seq_groups: List[SequenceGroup] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # The number of requests in the running queue + running_queue_size: int + preempted: int + + def __post_init__(self): + # Swap in and swap out should never happen at the same time. + assert not (self.blocks_to_swap_in and self.blocks_to_swap_out) + + self.num_loras: int = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. 
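# A minimal usage sketch of the SchedulingBudget above, illustrating the
# request_id-aware accounting described in its docstring: repeated updates for
# the same request id are counted only once.
from vllm.core.scheduler import SchedulingBudget

budget = SchedulingBudget(token_budget=512, max_num_seqs=8)

assert budget.can_schedule(num_new_tokens=100, num_new_seqs=1)
budget.add_num_batched_tokens("req-0", 100)
budget.add_num_seqs("req-0", 1)

# A second update for the same request id is ignored.
budget.add_num_batched_tokens("req-0", 100)
assert budget.num_batched_tokens == 100
assert budget.remaining_token_budget() == 412

# Preempting the request returns its share of the budget.
budget.subtract_num_batched_tokens("req-0", 100)
budget.subtract_num_seqs("req-0", 1)
assert budget.num_batched_tokens == 0 and budget.num_curr_seqs == 0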
+ return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + def _sort_by_lora_ids(self): + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return { + g.seq_group.lora_request + for g in self.scheduled_seq_groups + if g.seq_group.lora_request is not None + } + + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + + +@dataclass +class SchedulerRunningOutputs: + """The requests that are scheduled from a running queue. + + Could contain prefill (prefill that's chunked) or decodes. If there's not + enough memory, it can be preempted (for recompute) or swapped out. + """ + # Selected sequences that are running and in a decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are running and in a prefill phase. + # I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The preempted sequences. + preempted: List[SequenceGroup] + # Sequences that are swapped out. + swapped_out: List[SequenceGroup] + # The blocks to swap out. + blocks_to_swap_out: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + + # Optimization for fast-access to seq_group lists + decode_seq_groups_list: List[SequenceGroup] + prefill_seq_groups_list: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerRunningOutputs": + return SchedulerRunningOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + decode_seq_groups_list=[], + prefill_seq_groups_list=[], + ) + + +@dataclass +class SchedulerSwappedInOutputs: + """The requests that are scheduled from a swap queue. + + Could contain prefill (prefill that's chunked) or decodes. + """ + # Selected sequences that are going to be swapped in and is in a + # decoding phase. + decode_seq_groups: List[ScheduledSequenceGroup] + # Selected sequences that are going to be swapped in and in a prefill + # phase. I.e., it means the prefill has been chunked. + prefill_seq_groups: List[ScheduledSequenceGroup] + # The blocks to swap in. + blocks_to_swap_in: List[Tuple[int, int]] + # The blocks to copy. + blocks_to_copy: List[Tuple[int, int]] + # The number of slots for lookahead decoding. + num_lookahead_slots: int + # Infeasible sequence groups. + infeasible_seq_groups: List[SequenceGroup] + + @classmethod + def create_empty(cls) -> "SchedulerSwappedInOutputs": + return SchedulerSwappedInOutputs( + decode_seq_groups=[], + prefill_seq_groups=[], + blocks_to_swap_in=[], + blocks_to_copy=[], + num_lookahead_slots=0, + infeasible_seq_groups=[], + ) + + +@dataclass +class SchedulerPrefillOutputs: + """The requests that are scheduled from a waiting queue. + + Could contain a fresh prefill requests or preempted requests that need + to be recomputed from scratch. + """ + # Selected sequences for prefill. + seq_groups: List[ScheduledSequenceGroup] + # Ignored sequence groups. 
+ ignored_seq_groups: List[SequenceGroup] + num_lookahead_slots: int + + @classmethod + def create_empty(cls) -> "SchedulerPrefillOutputs": + return SchedulerPrefillOutputs( + seq_groups=[], + ignored_seq_groups=[], + num_lookahead_slots=0, + ) + + +def seq_group_metadata_builder(): + return SequenceGroupMetadata(request_id="", + is_prompt=False, + seq_data={}, + sampling_params=None, + block_tables={}) + + +def scheduler_running_outputs_builder(): + return SchedulerRunningOutputs(decode_seq_groups=[], + prefill_seq_groups=[], + preempted=[], + swapped_out=[], + blocks_to_swap_out=[], + blocks_to_copy=[], + num_lookahead_slots=0, + prefill_seq_groups_list=[], + decode_seq_groups_list=[]) + + +def scheduled_seq_group_builder(): + return ScheduledSequenceGroup(SequenceGroup("", [], -1), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + pipeline_parallel_size: int = 1, + output_proc_callback: Optional[Callable] = None, + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config + + version = "v1" + if self.scheduler_config.use_v2_block_manager: + version = "v2" + if (self.scheduler_config.embedding_mode + or self.cache_config.is_attention_free): + version = "placeholder" + + BlockSpaceManagerImpl = BlockSpaceManager.get_block_space_manager_class( + version) + + num_gpu_blocks = cache_config.num_gpu_blocks + if num_gpu_blocks: + num_gpu_blocks //= pipeline_parallel_size + + num_cpu_blocks = cache_config.num_cpu_blocks + if num_cpu_blocks: + num_cpu_blocks //= pipeline_parallel_size + + # Create the block space manager. + self.block_manager = BlockSpaceManagerImpl( + block_size=self.cache_config.block_size, + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + sliding_window=self.cache_config.sliding_window, + enable_caching=self.cache_config.enable_prefix_caching) + + # Sequence groups in the WAITING state. + # Contain new prefill or preempted requests. + self.waiting: Deque[SequenceGroup] = deque() + # Sequence groups in the RUNNING state. + # Contain decode requests. + self.running: Deque[SequenceGroup] = deque() + # Sequence groups in the SWAPPED state. + # Contain decode requests that are swapped out. + self.swapped: Deque[SequenceGroup] = deque() + # Sequence groups finished requests ids since last step iteration. + # It lets the model know that any state associated with these requests + # can and must be released after the current step. + # This is used to evict the finished requests from the Mamba cache. + self._finished_requests_ids: List[str] = list() + # Time at previous scheduling step + self.prev_time = 0.0 + # Did we schedule a prompt at previous step? + self.prev_prompt = False + # Latency of the last prompt step + self.last_prompt_latency = 0.0 + # preemption mode, RECOMPUTE or SWAP + self.user_specified_preemption_mode = scheduler_config.preemption_mode + + # The following field is test-only. It is used to inject artificial + # preemption. 
+ self.enable_artificial_preemption = ENABLE_ARTIFICIAL_PREEMPT + self.artificial_preempt_cnt = (ARTIFICIAL_PREEMPTION_MAX_CNT + if self.enable_artificial_preemption + else 0) + self.num_cumulative_preemption: int = 0 + + # Used to cache python objects + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. + self.output_proc_callback = output_proc_callback + self.use_async_output_proc = self.output_proc_callback is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters + + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + + @property + def num_decoding_tokens_per_seq(self) -> int: + """The number of new tokens.""" + return 1 + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def _add_seq_group_to_running(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the running queue. + # Only for testing purposes. + self.running.append(seq_group) + + def _add_seq_group_to_swapped(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the swapped queue. + # Only for testing purposes. + self.swapped.append(seq_group) + + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a sequence group with the given ID. + + Check if the sequence group with the given ID + is present in any of the state queue. + If present, remove the sequence group from the state queue. + Also, if any of the sequences in the sequence group is not finished, + free the sequence with status `FINISHED_ABORTED`. + Otherwise, do nothing. + + Args: + request_id: The ID(s) of the sequence group to abort. + """ + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + if not request_ids: + # Using 'break' here may add two extra iterations, + # but is acceptable to reduce complexity. + break + if seq_group.request_id in request_ids: + # Appending aborted group into pending list. + aborted_groups.append(seq_group) + request_ids.remove(seq_group.request_id) + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. 
+ state_queue.remove(aborted_group) + # Remove the aborted request from the Mamba cache. + self._finished_requests_ids.append(aborted_group.request_id) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + + self._free_seq_group_cross_attn_blocks(aborted_group) + + def _free_seq_group_cross_attn_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + """ + Free a sequence group from a cross-attention block table. + Has no effect on decoder-only models. + """ + if seq_group.is_encoder_decoder(): + self.block_manager.free_cross(seq_group) + + def has_unfinished_seqs(self) -> bool: + return len(self.waiting) != 0 or len(self.running) != 0 or len( + self.swapped) != 0 + + def get_prefix_cache_hit_rate(self, device: Device) -> float: + return self.block_manager.get_prefix_cache_hit_rate(device) + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def get_and_reset_finished_requests_ids(self) -> List[str]: + """Flushes the list of request ids of previously finished seq_groups.""" + finished_requests_ids = self._finished_requests_ids + self._finished_requests_ids = list() + return finished_requests_ids + + def _schedule_running( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerRunningOutputs: + """Schedule sequence groups that are running. + + Running queue should include decode and chunked prefill requests. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any decodes are preempted. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any decodes are preempted. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerRunningOutputs. + """ + ret: SchedulerRunningOutputs = \ + self._scheduler_running_outputs_cache[self.cache_id].get_object() + ret.blocks_to_swap_out.clear() + ret.blocks_to_copy.clear() + ret.decode_seq_groups.clear() + ret.prefill_seq_groups.clear() + ret.preempted.clear() + ret.swapped_out.clear() + + ret.num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking) + + ret.decode_seq_groups_list.clear() + ret.prefill_seq_groups_list.clear() + + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_out: List[Tuple[int, int]] = ret.blocks_to_swap_out + blocks_to_copy: List[Tuple[int, int]] = ret.blocks_to_copy + + decode_seq_groups: List[ScheduledSequenceGroup] = ret.decode_seq_groups + prefill_seq_groups: List[ + ScheduledSequenceGroup] = ret.prefill_seq_groups + preempted: List[SequenceGroup] = ret.preempted + swapped_out: List[SequenceGroup] = ret.swapped_out + + running_queue = self.running + assert len(self._async_stopped) == 0 + while running_queue: + seq_group = running_queue[0] + num_running_tokens = self._get_num_new_tokens( + seq_group, SequenceStatus.RUNNING, enable_chunking, budget) + + if num_running_tokens == 0: + # No budget => Stop + break + + running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. 
+ if self.use_async_output_proc and seq_group.seqs[0].get_len( + ) > self.scheduler_config.max_model_len: + self._async_stopped.append(seq_group) + continue + + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. + while not self._can_append_slots(seq_group, enable_chunking): + budget.subtract_num_batched_tokens(seq_group.request_id, + num_running_tokens) + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(seq_group.request_id, + num_running_seqs) + + if (curr_loras is not None and seq_group.lora_int_id > 0 + and seq_group.lora_int_id in curr_loras): + curr_loras.remove(seq_group.lora_int_id) + + # Determine victim sequence + cont_loop = True + if running_queue: + # Preempt the lowest-priority sequence group. + victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: + preempted_mode = self._preempt(victim_seq_group, + blocks_to_swap_out) + if preempted_mode == PreemptionMode.RECOMPUTE: + preempted.append(victim_seq_group) + else: + swapped_out.append(victim_seq_group) + + if not cont_loop: + break + else: + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + + scheduled_seq_group: ScheduledSequenceGroup = \ + self._scheduled_seq_group_cache[self.cache_id].get_object() + scheduled_seq_group.seq_group = seq_group + if is_prefill: + scheduled_seq_group.token_chunk_size = num_running_tokens + prefill_seq_groups.append(scheduled_seq_group) + ret.prefill_seq_groups_list.append(seq_group) + else: + scheduled_seq_group.token_chunk_size = 1 + decode_seq_groups.append(scheduled_seq_group) + ret.decode_seq_groups_list.append(seq_group) + + budget.add_num_batched_tokens(seq_group.request_id, + num_running_tokens) + # OPTIMIZATION: Note that get_max_num_running_seqs is + # expensive. For the default scheduling chase where + # enable_chunking is False, num_seqs are updated before running + # this method, so we don't have to update it again here. + if enable_chunking: + num_running_seqs = seq_group.get_max_num_running_seqs() + budget.add_num_seqs(seq_group.request_id, num_running_seqs) + if curr_loras is not None and seq_group.lora_int_id > 0: + curr_loras.add(seq_group.lora_int_id) + + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() + + return ret + + def _schedule_swapped( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerSwappedInOutputs: + """Schedule sequence groups that are swapped out. + + It schedules swapped requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. 
The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are swapped in. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are swapped in. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerSwappedInOutputs. + """ + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_in: List[Tuple[int, int]] = [] + blocks_to_copy: List[Tuple[int, int]] = [] + decode_seq_groups: List[ScheduledSequenceGroup] = [] + prefill_seq_groups: List[ScheduledSequenceGroup] = [] + infeasible_seq_groups: List[SequenceGroup] = [] + + swapped_queue = self.swapped + + leftover_swapped: Deque[SequenceGroup] = deque() + while swapped_queue: + seq_group = swapped_queue[0] + + # If the sequence group cannot be swapped in, stop. + is_prefill = seq_group.is_prefill() + alloc_status = self.block_manager.can_swap_in( + seq_group, + self._get_num_lookahead_slots(is_prefill, enable_chunking)) + if alloc_status == AllocStatus.LATER: + break + elif alloc_status == AllocStatus.NEVER: + logger.warning( + "Failing the request %s because there's not enough kv " + "cache blocks to run the entire sequence.", + seq_group.request_id) + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_IGNORED + infeasible_seq_groups.append(seq_group) + swapped_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (lora_int_id > 0 and (lora_int_id not in curr_loras) + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + swapped_queue.popleft() + continue + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. 
+ num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.SWAPPED, + enable_chunking, budget) + + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + if lora_int_id > 0 and curr_loras is not None: + curr_loras.add(lora_int_id) + swapped_queue.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + is_prefill = seq_group.is_prefill() + if is_prefill: + prefill_seq_groups.append( + ScheduledSequenceGroup(seq_group, + token_chunk_size=num_new_tokens)) + else: + decode_seq_groups.append( + ScheduledSequenceGroup(seq_group, token_chunk_size=1)) + budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + swapped_queue.extendleft(leftover_swapped) + + return SchedulerSwappedInOutputs( + decode_seq_groups=decode_seq_groups, + prefill_seq_groups=prefill_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_copy=blocks_to_copy, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=False, enable_chunking=enable_chunking), + infeasible_seq_groups=infeasible_seq_groups, + ) + + def _get_prompt_limit(self, seq_group: SequenceGroup) -> int: + if self.scheduler_config.chunked_prefill_enabled and \ + not self.scheduler_config.is_multi_step: + prompt_limit = self.scheduler_config.max_model_len + else: + prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Model is fine tuned with long context. Return the fine tuned max_len. + if (seq_group.lora_request + and seq_group.lora_request.long_lora_max_len): + assert prompt_limit <= seq_group.lora_request.long_lora_max_len + return seq_group.lora_request.long_lora_max_len + else: + return prompt_limit + + def _get_priority(self, + seq_group: SequenceGroup) -> Tuple[Optional[int], float]: + """ Get the priority of the sequence group. + Highest preference to user-defined priority, followed by arrival time. + Args: + seq_group: The sequence group input. + Returns: + The priority of the sequence group. + """ + return seq_group.priority, seq_group.arrival_time + + def _schedule_priority_preemption( + self, + budget: SchedulingBudget, + ) -> int: + """Sorts waiting and running queue. Also, force preempt requests + from the running queue if their priority is lower. + Priority-based preemption is used with the priority policy. + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + Returns: + A count of priority-based preemptions. 
+ """ + + waiting_queue = self.waiting + + running_queue = deque(sorted(self.running, key=self._get_priority)) + + blocks_to_swap_out: List[Tuple[int, int]] = [] + force_preemption_count = 0 + + if waiting_queue: + seq_group = waiting_queue.popleft() + num_new_seqs = seq_group.get_max_num_running_seqs() + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + False, budget) + + #Only preempt if priority inversion exists + while running_queue and self._get_priority( + running_queue[-1]) > self._get_priority(seq_group): + #Only preempt if waiting sequence cannot be allocated + can_allocate = self.block_manager.can_allocate(seq_group) + if (num_new_tokens and can_allocate == AllocStatus.OK + and budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + #Adjust budget to remove the victim sequence group + vseq_group = running_queue.pop() + num_running_tokens = self._get_num_new_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget) + budget.subtract_num_batched_tokens(vseq_group.request_id, + num_running_tokens) + num_running_seqs = vseq_group.get_max_num_running_seqs() + budget.subtract_num_seqs(vseq_group.request_id, + num_running_seqs) + + #Preempt out the victim sequence group + self._preempt(vseq_group, blocks_to_swap_out, + PreemptionMode.RECOMPUTE) + waiting_queue.appendleft(vseq_group) + force_preemption_count += 1 + #Put the sequence back into the waiting queue + waiting_queue.appendleft(seq_group) + + waiting_queue = deque(sorted(waiting_queue, key=self._get_priority)) + + self.waiting = waiting_queue + self.running = running_queue + return force_preemption_count + + def _schedule_prefills( + self, + budget: SchedulingBudget, + curr_loras: Optional[Set[int]], + enable_chunking: bool = False, + ) -> SchedulerPrefillOutputs: + """Schedule sequence groups that are in prefill stage. + + Note that the current scheduler treats PREEMPTED_FOR_RECOMPUTE + as a new prefill (that starts from beginning -> most recently generated + tokens). + + It schedules waiting requests as long as it fits `budget` and + curr_loras <= max_lora from the scheduling config. The input arguments + `budget` and `curr_loras` are updated based on scheduled seq_groups. + + Args: + budget: The scheduling budget. The argument is in-place updated + when any requests are scheduled. + curr_loras: Currently batched lora request ids. The argument is + in-place updated when any requests are scheduled. + enable_chunking: If True, seq group can be chunked and only a + chunked number of tokens are scheduled if + `budget.num_batched_tokens` has not enough capacity to schedule + all tokens. + + Returns: + SchedulerPrefillOutputs. 
+ """ + ignored_seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] + + waiting_queue = self.waiting + + leftover_waiting_sequences: Deque[SequenceGroup] = deque() + while self._passed_delay(time.time()) and waiting_queue: + seq_group = waiting_queue[0] + + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + num_new_tokens = self._get_num_new_tokens(seq_group, + SequenceStatus.WAITING, + enable_chunking, budget) + if not enable_chunking: + num_prompt_tokens = waiting_seqs[0].get_len() + assert num_new_tokens == num_prompt_tokens + + prompt_limit = self._get_prompt_limit(seq_group) + if num_new_tokens > prompt_limit: + logger.warning( + "Input prompt (%d tokens) is too long" + " and exceeds limit of %d", num_new_tokens, prompt_limit) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + num_lookahead_slots: int = 0 + if self.scheduler_config.is_multi_step and enable_chunking: + num_lookahead_slots = self._get_num_lookahead_slots( + True, enable_chunking) + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate( + seq_group, num_lookahead_slots=num_lookahead_slots) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + "Input prompt (%d tokens) + lookahead slots (%d) is " + "too long and exceeds the capacity of block_manager", + num_new_tokens, num_lookahead_slots) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + assert curr_loras is not None + assert self.lora_config is not None + if (self.lora_enabled and lora_int_id > 0 + and lora_int_id not in curr_loras + and len(curr_loras) >= self.lora_config.max_loras): + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + + num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_new_tokens == 0 + or not budget.can_schedule(num_new_tokens=num_new_tokens, + num_new_seqs=num_new_seqs)): + break + + # Can schedule this request. + if curr_loras is not None and lora_int_id > 0: + curr_loras.add(lora_int_id) + waiting_queue.popleft() + self._allocate_and_set_running(seq_group) + + if enable_chunking and self.scheduler_config.is_multi_step: + blocks_to_copy: List[Tuple[int, int]] = [] + # init_multi_step_from_lookahead_slots happens in append_slots + self._append_slots(seq_group, blocks_to_copy, enable_chunking) + # This assert will trip when a copy-on-write happens. This is + # not a concern as the very first sequence-group block + # allocation happens above. Still, we have the assert to + # catch any edge-cases. + assert not blocks_to_copy + else: + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config. 
+ num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_groups.append( + ScheduledSequenceGroup(seq_group=seq_group, + token_chunk_size=num_new_tokens)) + budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_seqs(seq_group.request_id, num_new_seqs) + + # Queue requests that couldn't be scheduled. + waiting_queue.extendleft(leftover_waiting_sequences) + if len(seq_groups) > 0: + self.prev_prompt = True + + return SchedulerPrefillOutputs( + seq_groups=seq_groups, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True, enable_chunking=enable_chunking)) + + def _schedule_default(self) -> SchedulerOutputs: + """Schedule queued requests. + + The current policy is designed to optimize the throughput. First, + it batches as many prefill requests as possible. And it schedules + decodes. If there's a pressure on GPU memory, decode requests can + be swapped or preempted. + """ + # Include running requests to the budget. + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + # Make sure we include num running seqs before scheduling prefill, + # so that we don't schedule beyond max_num_seqs for prefill. + for seq_group in self.running: + budget.add_num_seqs(seq_group.request_id, + seq_group.get_max_num_running_seqs()) + curr_loras = set( + seq_group.lora_int_id for seq_group in self.running + if seq_group.lora_int_id > 0) if self.lora_enabled else None + + prefills = SchedulerPrefillOutputs.create_empty() + running_scheduled = SchedulerRunningOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # If any requests are swapped, prioritized swapped requests. + if not self.swapped: + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=False) + + if len(prefills.seq_groups + ) == 0 and self.scheduler_config.policy == "priority": + self._schedule_priority_preemption(budget) + + # Don't schedule decodes if prefills are scheduled. + # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running + # only contains decode requests, not chunked prefills. + if len(prefills.seq_groups) == 0: + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=False) + + # If any sequence group is preempted, do not swap in any sequence + # group. because it means there's no slot for new running requests. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. + if len(prefills.seq_groups) > 0: + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + self.running.extend(running_scheduled.decode_seq_groups_list) + + if len(swapped_in.decode_seq_groups) > 0: + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + + # Update swapped requests. + self.swapped.extend(running_scheduled.swapped_out) + preempted = (len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)) + + # There should be no prefill from running queue because this policy + # doesn't allow chunked prefills. 
+ assert len(running_scheduled.prefill_seq_groups) == 0 + assert len(swapped_in.prefill_seq_groups) == 0 + + # Merge lists + num_prefill_groups = len(prefills.seq_groups) + if num_prefill_groups > 0: + scheduled_seq_groups = prefills.seq_groups + scheduled_seq_groups.extend(running_scheduled.decode_seq_groups) + else: + scheduled_seq_groups = running_scheduled.decode_seq_groups + scheduled_seq_groups.extend(swapped_in.decode_seq_groups) + + blocks_to_copy = running_scheduled.blocks_to_copy + blocks_to_copy.extend(swapped_in.blocks_to_copy) + + ignored_seq_groups = prefills.ignored_seq_groups + ignored_seq_groups.extend(swapped_in.infeasible_seq_groups) + + return SchedulerOutputs( + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=preempted, + ) + + def _schedule_chunked_prefill(self) -> SchedulerOutputs: + """Schedule queued requests. + + Chunked prefill allows to chunk prefill requests, batch them together + with decode requests. This policy 1. schedule as many decoding requests + as possible. 2. schedule chunked prefill requests that are not + finished. 3. schedule swapped request. 4. schedule new prefill + requests. + + The policy can sustain the high GPU utilization because it can put + prefill and decodes requests to the same batch, while it improves + inter token latency because decodes requests don't need to be blocked + by prefill requests. + """ + budget = SchedulingBudget( + token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_seqs=self.scheduler_config.max_num_seqs, + ) + curr_loras: Set[int] = set() + + prefills = SchedulerPrefillOutputs.create_empty() + swapped_in = SchedulerSwappedInOutputs.create_empty() + + # Decoding should be always scheduled first by fcfs. + running_scheduled = self._schedule_running(budget, + curr_loras, + enable_chunking=True) + + # Schedule swapped out requests. + # If preemption happens, it means we don't have space for swap-in. + if len(running_scheduled.preempted) + len( + running_scheduled.swapped_out) == 0: + swapped_in = self._schedule_swapped(budget, curr_loras) + + # Schedule new prefills. + prefills = self._schedule_prefills(budget, + curr_loras, + enable_chunking=True) + + assert (budget.num_batched_tokens <= + self.scheduler_config.max_num_batched_tokens) + assert budget.num_curr_seqs <= self.scheduler_config.max_num_seqs + + # Update waiting requests. + self.waiting.extendleft(running_scheduled.preempted) + + # Update new running requests. + # By default, vLLM scheduler prioritizes prefills. + # Once chunked prefill is enabled, + # the policy is changed to prioritize decode requests. + self.running.extend( + [s.seq_group for s in swapped_in.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) + self.running.extend([s.seq_group for s in prefills.seq_groups]) + + # Update swapped requests. 
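# A minimal sketch of the batch composition _schedule_chunked_prefill above
# aims for: decodes are budgeted first, and a long prompt joins the same batch
# with only the remaining token budget. The min() below stands in for the
# chunk-size logic in _get_num_new_tokens further down; the numbers are
# illustrative.
from vllm.core.scheduler import SchedulingBudget

budget = SchedulingBudget(token_budget=512, max_num_seqs=16)

# Three running decode requests, one new token each.
for req_id in ("dec-0", "dec-1", "dec-2"):
    budget.add_num_batched_tokens(req_id, 1)
    budget.add_num_seqs(req_id, 1)

# A waiting 1000-token prompt cannot fit whole, but a chunk of up to
# remaining_token_budget() tokens is still scheduled in the same batch.
prompt_len = 1000
chunk = min(prompt_len, budget.remaining_token_budget())
assert chunk == 509
budget.add_num_batched_tokens("prefill-0", chunk)
assert budget.num_batched_tokens == 512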
+ self.swapped.extend(running_scheduled.swapped_out) + return SchedulerOutputs( + scheduled_seq_groups=(prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups), + num_prefill_groups=(len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)), + num_batched_tokens=budget.num_batched_tokens, + blocks_to_swap_in=swapped_in.blocks_to_swap_in, + blocks_to_swap_out=running_scheduled.blocks_to_swap_out, + blocks_to_copy=running_scheduled.blocks_to_copy + + swapped_in.blocks_to_copy, + ignored_seq_groups=prefills.ignored_seq_groups + + swapped_in.infeasible_seq_groups, + num_lookahead_slots=running_scheduled.num_lookahead_slots, + running_queue_size=len(self.running), + preempted=(len(running_scheduled.preempted) + + len(running_scheduled.swapped_out)), + ) + + def _schedule(self) -> SchedulerOutputs: + """Schedule queued requests.""" + if self.scheduler_config.chunked_prefill_enabled: + return self._schedule_chunked_prefill() + else: + return self._schedule_default() + + def _can_append_slots(self, seq_group: SequenceGroup, + enable_chunking: bool) -> bool: + """Determine whether or not we have enough space in the KV cache to + continue generation of the sequence group. + """ + # It is True only for testing case to trigger artificial preemption. + if (self.enable_artificial_preemption + and random.uniform(0, 1) < ARTIFICIAL_PREEMPTION_PROB + and self.artificial_preempt_cnt > 0): + self.artificial_preempt_cnt -= 1 + return False + + is_prefill = seq_group.is_prefill() + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + if is_prefill and num_lookahead_slots > 0: + # Appending prefill slots only happens multi-step and + # chunked-prefill are enabled together. + assert self.scheduler_config.is_multi_step and enable_chunking + + return self.block_manager.can_append_slots( + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + # async_output_proc is allowed only when we have a single sequence + # in the sequence group + no_single_seq = seq_group.sampling_params is None or ( + seq_group.sampling_params.n == 1) + return no_single_seq + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_start_time = time.perf_counter() + + scheduler_outputs: SchedulerOutputs = self._schedule() + now = time.time() + + if not self.cache_config.enable_prefix_caching: + common_computed_block_nums = [] + + allow_async_output_proc: bool = self.use_async_output_proc + + # Create input data structures. 
+ seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for i, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + token_chunk_size = scheduled_seq_group.token_chunk_size + seq_group.maybe_set_first_scheduled_time(now) + + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + + # seq_id -> SequenceData + seq_data: Dict[int, SequenceData] = {} + # seq_id -> physical block numbers + block_tables: Dict[int, List[int]] = {} + + if seq_group.is_encoder_decoder(): + # Encoder associated with SequenceGroup + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data + # Block table for cross-attention + # Also managed at SequenceGroup level + cross_block_table = self.block_manager.get_cross_block_table( + seq_group) + else: + encoder_seq_data = None + cross_block_table = None + + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + self.block_manager.access_all_blocks_in_seq(seq, now) + + if self.cache_config.enable_prefix_caching: + common_computed_block_nums = ( + self.block_manager.get_common_computed_block_ids( + seq_group.get_seqs(status=SequenceStatus.RUNNING))) + + do_sample = True + is_prompt = seq_group.is_prefill() + # We should send the metadata to workers when the first prefill + # is sent. Subsequent requests could be chunked prefill or decode. + is_first_prefill = False + if is_prompt: + seqs = seq_group.get_seqs() + # Prefill has only 1 sequence. + assert len(seqs) == 1 + num_computed_tokens = seqs[0].data.get_num_computed_tokens() + is_first_prefill = num_computed_tokens == 0 + # In the next iteration, all prompt tokens are not computed. + # It means the prefill is chunked, and we don't need sampling. + # NOTE: We use get_len instead of get_prompt_len because when + # a sequence is preempted, prefill includes previous generated + # output tokens. + if (token_chunk_size + num_computed_tokens < + seqs[0].data.get_len()): + do_sample = False + + # It assumes the scheduled_seq_groups is ordered by + # prefill < decoding. + if is_first_prefill or not self.scheduler_config.send_delta_data: + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + do_sample=do_sample, + pooling_params=seq_group.pooling_params, + token_chunk_size=token_chunk_size, + lora_request=seq_group.lora_request, + computed_block_nums=common_computed_block_nums, + encoder_seq_data=encoder_seq_data, + cross_block_table=cross_block_table, + state=seq_group.state, + # `multi_modal_data` will only be present for the 1st comm + # between engine and worker. + # the subsequent comms can still use delta, but + # `multi_modal_data` will be None. + multi_modal_data=seq_group.multi_modal_data + if scheduler_outputs.num_prefill_groups > 0 else None, + mm_processor_kwargs=seq_group.mm_processor_kwargs, + prompt_adapter_request=seq_group.prompt_adapter_request, + ) + else: + # When SPMD mode is enabled, we only send delta data except for + # the first request to reduce serialization cost. 
+ seq_data_delta = {} + for id, data in seq_data.items(): + seq_data_delta[id] = data.get_delta_and_reset() + seq_group_metadata = SequenceGroupMetadataDelta( + seq_data_delta, + seq_group.request_id, + block_tables, + is_prompt, + do_sample=do_sample, + token_chunk_size=token_chunk_size, + computed_block_nums=common_computed_block_nums, + ) + seq_group_metadata_list.append(seq_group_metadata) + + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + + # Now that the batch has been created, we can assume all blocks in the + # batch will have been computed before the next scheduling invocation. + # This is because the engine assumes that a failure in model execution + # will crash the vLLM instance / will not retry. + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + self.block_manager.mark_blocks_as_computed( + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) + + self._seq_group_metadata_cache[self.next_cache_id].reset() + + scheduler_time = time.perf_counter() - scheduler_start_time + # Add this to scheduler time to all the sequences that are currently + # running. This will help estimate if the scheduler is a significant + # component in the e2e latency. + for seq_group in self.running: + if seq_group is not None and seq_group.metrics is not None: + if seq_group.metrics.scheduler_time is not None: + seq_group.metrics.scheduler_time += scheduler_time + else: + seq_group.metrics.scheduler_time = scheduler_time + + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + """Free a sequence from a block table.""" + self.block_manager.free(seq) + + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. 
+ self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + def free_finished_seq_groups(self) -> None: + remaining: Deque[SequenceGroup] = deque() + for seq_group in self.running: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): + remaining.append(seq_group) + + self.running = remaining + + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slots(self, + seq_group: SequenceGroup, + blocks_to_copy: List[Tuple[int, int]], + enable_chunking: bool = False) -> None: + """Appends new slots to the sequences in the given sequence group. + + Args: + seq_group (SequenceGroup): The sequence group containing the + sequences to append slots to. + blocks_to_copy (List[Tuple[int, int]]): A list of tuple of two + ints, the first int is the source block index, and the second + int is the destination block index. This list is updated with + the new source and destination block indices for the appended + slots. + enable_chunking (bool): True if chunked prefill is enabled. + """ + is_prefill: bool = seq_group.is_prefill() + num_lookahead_slots: int = self._get_num_lookahead_slots( + is_prefill, enable_chunking) + + seq_group.init_multi_step_from_lookahead_slots( + num_lookahead_slots, + num_scheduler_steps=self.scheduler_config.num_scheduler_steps, + is_multi_step=self.scheduler_config.is_multi_step, + enable_chunking=enable_chunking) + + seq_status: Optional[SequenceStatus] = SequenceStatus.RUNNING + if self.scheduler_config.is_multi_step and enable_chunking: + # In multi-step chunked-prefill any sequence type can have + # slots appended. + seq_status = None + + for seq in seq_group.get_seqs(status=seq_status): + cows = self.block_manager.append_slots(seq, num_lookahead_slots) + if len(cows) > 0: + blocks_to_copy.extend(cows) + + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + preemption_mode: Optional[PreemptionMode] = None, + ) -> PreemptionMode: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. 
+ if self.user_specified_preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + + elif self.user_specified_preemption_mode == "swap": + preemption_mode = PreemptionMode.SWAP + else: + preemption_mode = PreemptionMode.RECOMPUTE + + if self.num_cumulative_preemption % 50 == 0: + logger.warning( + "Sequence group %s is preempted by %s mode because there is " + "not enough KV cache space. This can affect the end-to-end " + "performance. Increase gpu_memory_utilization or " + "tensor_parallel_size to provide more KV cache memory. " + "total_num_cumulative_preemption=%d", seq_group.request_id, + preemption_mode, self.num_cumulative_preemption + 1) + self.num_cumulative_preemption += 1 + + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + return preemption_mode + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.free_seq(seq) + seq.reset_state_for_recompute() + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: List[Tuple[int, int]], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: List[Tuple[int, int]], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.extend(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED + + def _passed_delay(self, now: float) -> bool: + if self.prev_prompt: + self.last_prompt_latency = now - self.prev_time + self.prev_time, self.prev_prompt = now, False + # Delay scheduling prompts to let waiting queue fill up + if self.scheduler_config.delay_factor > 0 and self.waiting: + earliest_arrival_time = min( + [e.metrics.arrival_time for e in self.waiting]) + passed_delay = ( + (now - earliest_arrival_time) > + (self.scheduler_config.delay_factor * self.last_prompt_latency) + or not self.running) + else: + passed_delay = True + return passed_delay + + def _get_num_lookahead_slots(self, is_prefill: bool, + enable_chunking: bool) -> int: + """The number of slots to allocate per sequence per step, beyond known + token ids. Speculative decoding uses these slots to store KV activations + of tokens which may or may not be accepted. + + Speculative decoding does not yet support prefill, so we do not perform + lookahead allocation for prefill. + + When chunking is enabled with multi-step, we allocate lookahead slots + for the prefills for when the prefills turn into decodes in the first + step. 
+ """ + if is_prefill: + if self.scheduler_config.is_multi_step and enable_chunking: + # num_lookahead_slots was introduced in the context of decodes, + # in Speculative Decoding. + # When the num_scheduler_steps is 8, say, then the + # num_lookahead_slots is 7. Meaning, we are doing a 1-step of + # decode anyways and we wish to do 7 more. + # + # "lookaheads" for prefills, is introduced in support for + # Chunked-Prefill in Multi-Step. + return self.scheduler_config.num_lookahead_slots + 1 + else: + return 0 + + return self.scheduler_config.num_lookahead_slots + + def _get_num_new_tokens(self, seq_group: SequenceGroup, + status: SequenceStatus, enable_chunking: bool, + budget: SchedulingBudget) -> int: + """Get the next new tokens to compute for a given sequence group + that's in a given `status`. + + The API could chunk the number of tokens to compute based on `budget` + if `enable_chunking` is True. If a sequence group has multiple + sequences (e.g., running beam search), it means it is in decoding + phase, so chunking doesn't happen. + + Returns 0 if the new token cannot be computed due to token budget. + """ + num_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) + for seq in seqs: + num_new_tokens += seq.get_num_new_tokens() + assert num_new_tokens > 0 + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + if enable_chunking and len(seqs) == 1: + remaining_token_budget = budget.remaining_token_budget() + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens + elif self.cache_config.enable_prefix_caching: + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. 
+ block_size = self.cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + if remaining_token_budget < num_new_tokens: + num_new_tokens = (remaining_token_budget // + block_size) * block_size + else: + num_new_tokens = min(num_new_tokens, remaining_token_budget) + return num_new_tokens diff --git a/vllm/distributed/__init__.py b/vllm/distributed/__init__.py new file mode 100644 index 0000000..db325cf --- /dev/null +++ b/vllm/distributed/__init__.py @@ -0,0 +1,3 @@ +from .communication_op import * +from .parallel_state import * +from .utils import * diff --git a/vllm/distributed/__pycache__/__init__.cpython-310.pyc b/vllm/distributed/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..96bce1f Binary files /dev/null and b/vllm/distributed/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/distributed/__pycache__/communication_op.cpython-310.pyc b/vllm/distributed/__pycache__/communication_op.cpython-310.pyc new file mode 100644 index 0000000..d3c14b3 Binary files /dev/null and b/vllm/distributed/__pycache__/communication_op.cpython-310.pyc differ diff --git a/vllm/distributed/__pycache__/parallel_state.cpython-310.pyc b/vllm/distributed/__pycache__/parallel_state.cpython-310.pyc new file mode 100644 index 0000000..d3779a1 Binary files /dev/null and b/vllm/distributed/__pycache__/parallel_state.cpython-310.pyc differ diff --git a/vllm/distributed/__pycache__/utils.cpython-310.pyc b/vllm/distributed/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..424b3ac Binary files /dev/null and b/vllm/distributed/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py new file mode 100644 index 0000000..e13505d --- /dev/null +++ b/vllm/distributed/communication_op.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/vllm/distributed/device_communicators/__init__.py b/vllm/distributed/device_communicators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/distributed/device_communicators/__pycache__/__init__.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..58d4309 Binary files /dev/null and 
b/vllm/distributed/device_communicators/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/cuda_wrapper.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/cuda_wrapper.cpython-310.pyc new file mode 100644 index 0000000..f0c4225 Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/cuda_wrapper.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-310.pyc new file mode 100644 index 0000000..17f8c24 Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/custom_all_reduce.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/custom_all_reduce_utils.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/custom_all_reduce_utils.cpython-310.pyc new file mode 100644 index 0000000..e215e78 Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/custom_all_reduce_utils.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-310.pyc new file mode 100644 index 0000000..222c197 Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/pynccl.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/pynccl_wrapper.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/pynccl_wrapper.cpython-310.pyc new file mode 100644 index 0000000..19d5c14 Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/pynccl_wrapper.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/shm_broadcast.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/shm_broadcast.cpython-310.pyc new file mode 100644 index 0000000..b63c55a Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/shm_broadcast.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/__pycache__/tpu_communicator.cpython-310.pyc b/vllm/distributed/device_communicators/__pycache__/tpu_communicator.cpython-310.pyc new file mode 100644 index 0000000..7c438a6 Binary files /dev/null and b/vllm/distributed/device_communicators/__pycache__/tpu_communicator.cpython-310.pyc differ diff --git a/vllm/distributed/device_communicators/cuda_wrapper.py b/vllm/distributed/device_communicators/cuda_wrapper.py new file mode 100644 index 0000000..d5a5338 --- /dev/null +++ b/vllm/distributed/device_communicators/cuda_wrapper.py @@ -0,0 +1,172 @@ +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. 
+""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("cudaIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = 
find_loaded_library("libcudart") + assert so_file is not None, \ + "libcudart is not loaded in the current process" + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py new file mode 100644 index 0000000..7de5b05 --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -0,0 +1,292 @@ +from contextlib import contextmanager +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.distributed.device_communicators.custom_all_reduce_utils import ( + gpu_p2p_access_check) +from vllm.distributed.parallel_state import in_the_same_node_as +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import cuda_device_count_stateless + +try: + ops.meta_size() + custom_ar = True +except Exception: + # For AMD GPUs and CPUs + custom_ar = False + +logger = init_logger(__name__) + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if envs.VLLM_SKIP_P2P_CHECK: + logger.info( + "Skipping 
P2P check and trusting the driver's P2P report.") + return torch.cuda.can_device_access_peer(rank, i) + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__(self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes.") + return + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + assert current_platform.is_cuda() + from vllm.platforms.cuda import CudaPlatform + cuda_platform: CudaPlatform = current_platform + full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. 
To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + if not _can_p2p(rank, world_size): + logger.warning( + "Custom allreduce is disabled because your platform lacks " + "GPU P2P capability or P2P test failed. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.") + return + + self.disabled = False + # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. + self.meta = torch.zeros(ops.meta_size() + max_size, + dtype=torch.uint8, + device=self.device) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device=self.device) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + handles, offsets = self._get_ipc_meta(self.meta) + self.full_nvlink = full_nvlink + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) + self.register_buffer(self.buffer) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def _get_ipc_meta(self, inp: torch.Tensor): + data = inp.untyped_storage()._share_cuda_() + shard_data = ( + data[1], # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. 
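# The gather-by-repeated-broadcast pattern used in _gather_ipc_meta above,
# isolated into a self-contained sketch. It assumes the default (WORLD)
# process group, so group-local ranks equal global ranks; the helper name is
# illustrative.
import torch.distributed as dist

def gather_objects_via_broadcast(obj, group=None):
    """Gather one picklable object from every rank by broadcasting each
    rank's object in turn. Unlike dist.all_gather_object, this also works
    with the gloo backend under inference mode (see the issue linked above)."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    # One single-element list per rank so broadcast_object_list can fill it.
    slots = [[None] for _ in range(world_size)]
    slots[rank][0] = obj
    for src in range(world_size):
        dist.broadcast_object_list(slots[src], src=src, group=group,
                                   device="cpu")
    return [s[0] for s in slots]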
+ + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + ops.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + # when custom allreduce is disabled, this will be None + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce_reg(input) + else: + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + return self.all_reduce_unreg(input) + + return None + + def close(self): + if not self.disabled and self._ptr: + ops.dispose(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py new file mode 100644 index 0000000..983e772 --- /dev/null +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -0,0 +1,255 @@ +import ctypes +import json +import os +import pickle +import subprocess +import sys +import tempfile +from itertools import product +from typing import Dict, List, Optional, Sequence + +import torch.distributed as dist +import torch.multiprocessing as mp + +import vllm.envs as envs +from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary +from vllm.logger import init_logger +from vllm.utils import (cuda_device_count_stateless, + update_environment_variables) + +logger = init_logger(__name__) + + +def producer(batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + 
update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer(batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). 
+ """ # noqa + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process(target=producer, + args=(batch_src, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_tgt = smp.Process(target=consumer, + args=(batch_tgt, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: List[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + logger.warning( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", src, tgt) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + envs.VLLM_CACHE_ROOT, + f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + from vllm.distributed.parallel_state import get_world_group + if ((not is_distributed or get_world_group().local_rank == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache in %s", path) + cache: Dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. 
+ + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps( + (batch_src, batch_tgt, output_file.name)) + returned = subprocess.run([sys.executable, __file__], + input=input_bytes, + capture_output=True) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + logger.info("reading GPU P2P access cache from %s", path) + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py new file mode 100644 index 0000000..7319566 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl.py @@ -0,0 +1,170 @@ +from contextlib import contextmanager +from typing import Optional, Union + +# ===================== import region ===================== +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup, ReduceOp + +from vllm.distributed.device_communicators.pynccl_wrapper import ( + NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, + ncclRedOpTypeEnum, ncclUniqueId) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class PyNcclCommunicator: + + def __init__( + self, + group: ProcessGroup, + device: Union[int, str, torch.device], + library_path: Optional[str] = None, + ): + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the PyNcclCommunicator to. If None, + it will be bind to f"cuda:{local_rank}". + library_path: the path to the NCCL library. If None, it will + use the default library path. + It is the caller's responsibility to make sure each communicator + is bind to a unique device. + """ + assert dist.is_initialized() + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "PyNcclCommunicator should be attached to a non-NCCL group.") + self.group = group + # note: this rank is the rank in the group + self.rank = dist.get_rank(group) + self.world_size = dist.get_world_size(group) + + # if world_size == 1, no need to create communicator + if self.world_size == 1: + self.available = False + self.disabled = True + self.stream = None + return + try: + self.nccl = NCCLLibrary(library_path) + except Exception: + # disable because of missing NCCL library + # e.g. 
in a non-GPU environment + self.available = False + self.disabled = True + self.stream = None + return + + self.available = True + self.disabled = False + + logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion()) + + if self.rank == 0: + # get the unique id from NCCL + self.unique_id = self.nccl.ncclGetUniqueId() + else: + # construct an empty unique id + self.unique_id = ncclUniqueId() + tensor = torch.ByteTensor(list(self.unique_id.internal)) + ranks = dist.get_process_group_ranks(group) + # arg `src` in `broadcast` is the global rank + dist.broadcast(tensor, src=ranks[0], group=group) + byte_list = tensor.tolist() + for i, byte in enumerate(byte_list): + self.unique_id.internal[i] = byte + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + # nccl communicator and stream will use this device + # `torch.cuda.device` is a context manager that changes the + # current cuda device to the specified one + with torch.cuda.device(device): + self.comm: ncclComm_t = self.nccl.ncclCommInitRank( + self.world_size, self.unique_id, self.rank) + self.stream = torch.cuda.Stream() + + # A small all_reduce for warmup. + data = torch.zeros(1, device=device) + self.all_reduce(data) + self.stream.synchronize() + del data + + # by default it is disabled, e.g. in profiling models and prefill phase. + # to use it, use under `with obj.change_state(enable=True)`, usually + # when we are using CUDA graph. + self.disabled = True + + def all_reduce(self, + tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), + buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + + def send(self, tensor: torch.Tensor, dst: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), dst, + self.comm, cudaStream_t(stream.cuda_stream)) + + def recv(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + + @contextmanager + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to change the state of the communicator. 
+ """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py new file mode 100644 index 0000000..7619c98 --- /dev/null +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -0,0 +1,278 @@ +# This file is a pure Python wrapper for the NCCL library. +# The main purpose is to use NCCL combined with CUDA graph. +# Before writing this script, we tried the following approach: +# 1. We tried to use `cupy`, it calls NCCL correctly, but `cupy` itself +# often gets stuck when initializing the NCCL communicator. +# 2. We tried to use `torch.distributed`, but `torch.distributed.all_reduce` +# contains many other potential cuda APIs, that are not allowed during +# capturing the CUDA graph. For further details, please check +# https://discuss.pytorch.org/t/pytorch-cudagraph-with-nccl-operation-failed/ . +# +# Another rejected idea is to write a C/C++ binding for NCCL. It is usually +# doable, but we often encounter issues related with nccl versions, and need +# to switch between different versions of NCCL. See +# https://github.com/NVIDIA/nccl/issues/1234 for more details. +# A C/C++ binding is not flexible enough to handle this. It requires +# recompilation of the code every time we want to switch between different +# versions. This current implementation, with a **pure** Python wrapper, is +# more flexible. We can easily switch between different versions of NCCL by +# changing the environment variable `VLLM_NCCL_SO_PATH`, or the `so_file` +# variable in the code. 
+ +import ctypes +import platform +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +import torch +from torch.distributed import ReduceOp + +from vllm.logger import init_logger +from vllm.utils import find_nccl_library + +logger = init_logger(__name__) + +# === export types and functions from nccl to Python === +# for the original nccl definition, please check +# https://github.com/NVIDIA/nccl/blob/master/src/nccl.h.in + +ncclResult_t = ctypes.c_int +ncclComm_t = ctypes.c_void_p + + +class ncclUniqueId(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +cudaStream_t = ctypes.c_void_p +buffer_type = ctypes.c_void_p + +ncclDataType_t = ctypes.c_int + + +class ncclDataTypeEnum: + ncclInt8 = 0 + ncclChar = 0 + ncclUint8 = 1 + ncclInt32 = 2 + ncclInt = 2 + ncclUint32 = 3 + ncclInt64 = 4 + ncclUint64 = 5 + ncclFloat16 = 6 + ncclHalf = 6 + ncclFloat32 = 7 + ncclFloat = 7 + ncclFloat64 = 8 + ncclDouble = 8 + ncclBfloat16 = 9 + ncclNumTypes = 10 + + @classmethod + def from_torch(cls, dtype: torch.dtype) -> int: + if dtype == torch.int8: + return cls.ncclInt8 + if dtype == torch.uint8: + return cls.ncclUint8 + if dtype == torch.int32: + return cls.ncclInt32 + if dtype == torch.int64: + return cls.ncclInt64 + if dtype == torch.float16: + return cls.ncclFloat16 + if dtype == torch.float32: + return cls.ncclFloat32 + if dtype == torch.float64: + return cls.ncclFloat64 + if dtype == torch.bfloat16: + return cls.ncclBfloat16 + raise ValueError(f"Unsupported dtype: {dtype}") + + +ncclRedOp_t = ctypes.c_int + + +class ncclRedOpTypeEnum: + ncclSum = 0 + ncclProd = 1 + ncclMax = 2 + ncclMin = 3 + ncclAvg = 4 + ncclNumOps = 5 + + @classmethod + def from_torch(cls, op: ReduceOp) -> int: + if op == ReduceOp.SUM: + return cls.ncclSum + if op == ReduceOp.PRODUCT: + return cls.ncclProd + if op == ReduceOp.MAX: + return cls.ncclMax + if op == ReduceOp.MIN: + return cls.ncclMin + if op == ReduceOp.AVG: + return cls.ncclAvg + raise ValueError(f"Unsupported op: {op}") + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +class NCCLLibrary: + exported_functions = [ + # const char* ncclGetErrorString(ncclResult_t result) + Function("ncclGetErrorString", ctypes.c_char_p, [ncclResult_t]), + # ncclResult_t ncclGetVersion(int *version); + Function("ncclGetVersion", ncclResult_t, + [ctypes.POINTER(ctypes.c_int)]), + # ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); + Function("ncclGetUniqueId", ncclResult_t, + [ctypes.POINTER(ncclUniqueId)]), + # ncclResult_t ncclCommInitRank( + # ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); + # note that ncclComm_t is a pointer type, so the first argument + # is a pointer to a pointer + Function("ncclCommInitRank", ncclResult_t, [ + ctypes.POINTER(ncclComm_t), ctypes.c_int, ncclUniqueId, + ctypes.c_int + ]), + # ncclResult_t ncclAllReduce( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllReduce", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclSend( + # const void* sendbuff, size_t count, ncclDataType_t datatype, + # int dest, ncclComm_t comm, cudaStream_t stream); + Function("ncclSend", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + 
]), + + # ncclResult_t ncclRecv( + # void* recvbuff, size_t count, ncclDataType_t datatype, + # int src, ncclComm_t comm, cudaStream_t stream); + Function("ncclRecv", ncclResult_t, [ + buffer_type, ctypes.c_size_t, ncclDataType_t, ctypes.c_int, + ncclComm_t, cudaStream_t + ]), + + # be cautious! this is a collective call, it will block until all + # processes in the communicator have called this function. + # because Python object destruction can happen in random order, + # it is better not to call it at all. + # ncclResult_t ncclCommDestroy(ncclComm_t comm); + Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + + so_file = so_file or find_nccl_library() + + try: + if so_file not in NCCLLibrary.path_to_dict_mapping: + lib = ctypes.CDLL(so_file) + NCCLLibrary.path_to_library_cache[so_file] = lib + self.lib = NCCLLibrary.path_to_library_cache[so_file] + except Exception as e: + logger.error( + "Failed to load NCCL library from %s ." + "It is expected if you are not running on NVIDIA/AMD GPUs." + "Otherwise, the nccl library might not exist, be corrupted " + "or it does not support the current platform %s." + "If you already have the library, please set the " + "environment variable VLLM_NCCL_SO_PATH" + " to point to the correct nccl library path.", so_file, + platform.platform()) + raise e + + if so_file not in NCCLLibrary.path_to_dict_mapping: + _funcs: Dict[str, Any] = {} + for func in NCCLLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + NCCLLibrary.path_to_dict_mapping[so_file] = _funcs + self._funcs = NCCLLibrary.path_to_dict_mapping[so_file] + + def ncclGetErrorString(self, result: ncclResult_t) -> str: + return self._funcs["ncclGetErrorString"](result).decode("utf-8") + + def NCCL_CHECK(self, result: ncclResult_t) -> None: + if result != 0: + error_str = self.ncclGetErrorString(result) + raise RuntimeError(f"NCCL error: {error_str}") + + def ncclGetVersion(self) -> str: + version = ctypes.c_int() + self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version))) + version_str = str(version.value) + # something like 21903 --> "2.19.3" + major = version_str[0].lstrip("0") + minor = version_str[1:3].lstrip("0") + patch = version_str[3:].lstrip("0") + return f"{major}.{minor}.{patch}" + + def ncclGetUniqueId(self) -> ncclUniqueId: + unique_id = ncclUniqueId() + self.NCCL_CHECK(self._funcs["ncclGetUniqueId"]( + ctypes.byref(unique_id))) + return unique_id + + def ncclCommInitRank(self, world_size: int, unique_id: ncclUniqueId, + rank: int) -> ncclComm_t: + comm = ncclComm_t() + self.NCCL_CHECK(self._funcs["ncclCommInitRank"](ctypes.byref(comm), + world_size, unique_id, + rank)) + return comm + + def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + 
self.NCCL_CHECK(self._funcs["ncclAllReduce"](sendbuff, recvbuff, count, + datatype, op, comm, + stream)) + + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, + dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, + dest, comm, stream)) + + def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, + src: int, comm: ncclComm_t, stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, + comm, stream)) + + def ncclCommDestroy(self, comm: ncclComm_t) -> None: + self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) + + +__all__ = [ + "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", + "ncclComm_t", "cudaStream_t", "buffer_type" +] diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py new file mode 100644 index 0000000..7d526b2 --- /dev/null +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -0,0 +1,491 @@ +import pickle +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from multiprocessing import shared_memory +from typing import List, Optional +from unittest.mock import patch + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address + +VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL + +# time to wait if the queue is full or empty +# if we sleep for too short, it will consume too much CPU +# if we sleep for too long, it will slow down the writer/reader +# 0.1 us is a good balance +RINGBUFFER_SLEEP_INTERVAL = 1e-7 + +logger = init_logger(__name__) + + +class ShmRingBuffer: + + def __init__(self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None): + """ + A shared memory ring buffer implementation for broadcast communication. + Essentially, it is a queue where only one will `enqueue` and multiple + will `dequeue`. The max size of each item, together with the max number + of items that can be stored in the buffer are known in advance. + In this case, we don't need to synchronize the access to + the buffer. + + Buffer memory layout: + data metadata + | | + | (current_idx) | (current_idx) + v v + +-------------------------------+----------------------------------------+ + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata | + +-------------------------------+----------------------------------------+ + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | + + metadata memory layout: each byte is a flag, the first byte is the written + flag, and the rest are reader flags. The flags are set to 0 by default. + +--------------+--------------+--------------+-----+--------------+ + | written_flag | reader0_flag | reader1_flag | ... 
| readerN_flag | + +--------------+--------------+--------------+-----+--------------+ + + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + + During creation, `name` is None and the buffer is created. We can pass the + created object to other processes by pickling it. The other processes will + get the name of the shared memory and open it, so that they can access the + same shared memory buffer. + """# noqa + self.n_reader = n_reader + self.metadata_size = 1 + n_reader + self.max_chunk_bytes = max_chunk_bytes + self.max_chunks = max_chunks + self.total_bytes_of_buffer = (self.max_chunk_bytes + + self.metadata_size) * self.max_chunks + self.data_offset = 0 + self.metadata_offset = self.max_chunk_bytes * self.max_chunks + + if name is None: + # we are creating a buffer + self.is_creator = True + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.total_bytes_of_buffer) + # initialize the metadata section to 0 + with memoryview(self.shared_memory.buf[self.metadata_offset:] + ) as metadata_buffer: + torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) + else: + # we are opening an existing buffer + self.is_creator = False + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. 
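# The flag semantics described in the ShmRingBuffer docstring above can be
# restated as two small predicates (a sketch only; `metadata` stands for the
# 1 + n_reader metadata bytes of a single chunk).
def can_read(metadata: bytes, reader_idx: int) -> bool:
    written_flag, reader_flags = metadata[0], metadata[1:]
    # Readable once written (case 2/3), but only if this reader has not
    # marked it as read yet.
    return written_flag == 1 and reader_flags[reader_idx] == 0

def can_write(metadata: bytes) -> bool:
    written_flag, reader_flags = metadata[0], metadata[1:]
    # Writable when never written (case 1) or already read by every reader
    # (case 4).
    return written_flag == 0 or all(flag == 1 for flag in reader_flags)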
+ with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + try: + self.shared_memory = shared_memory.SharedMemory(name=name) + assert ( + self.shared_memory.size == self.total_bytes_of_buffer) + except FileNotFoundError: + # we might deserialize the object in a different node + # in this case, this object is not used, + # and we should suppress the error + pass + + def __reduce__(self): + return ( + self.__class__, + (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name), + ) + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() + + @contextmanager + def get_data(self, current_idx: int): + start = self.data_offset + current_idx * self.max_chunk_bytes + end = start + self.max_chunk_bytes + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + @contextmanager + def get_metadata(self, current_idx: int): + start = self.metadata_offset + current_idx * self.metadata_size + end = start + self.metadata_size + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + +@dataclass +class Handle: + connect_ip: str + local_reader_ranks: List[int] = field(default_factory=list) + + buffer: Optional[ShmRingBuffer] = None + local_subscribe_port: Optional[int] = None + remote_subscribe_port: Optional[int] = None + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[List[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + if connect_ip is None: + connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. 
otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_port = get_open_port() + socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" + logger.debug("Binding to %s", socket_addr) + self.local_socket.bind(socket_addr) + + self.current_idx = 0 + + else: + self.buffer = None # type: ignore + local_subscribe_port = None + self.local_socket = None + self.current_idx = -1 + + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + + else: + remote_subscribe_port = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + + self.handle = Handle( + connect_ip=connect_ip, + local_reader_ranks=local_reader_ranks, + buffer=self.buffer, + local_subscribe_port=local_subscribe_port, + remote_subscribe_port=remote_subscribe_port, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer is not None + self.buffer = handle.buffer + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) + + self.remote_socket = None + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if is_valid_ipv6_address(handle.connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. 
+ """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + # wait for subscription messages from all local readers + self.local_socket.recv() + if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + # wait for subscription messages from all remote readers + self.remote_socket.recv() + if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working + self.remote_socket.send(b"READY") + elif self._is_local_reader: + # wait for the writer to send a message + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + # wait for the writer to send a message + recv = self.remote_socket.recv() + assert recv == b"READY" + + @contextmanager + def acquire_write(self): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # wait for a while + time.sleep(RINGBUFFER_SLEEP_INTERVAL) + + # if we wait for a long time, we should warn the user + if (time.monotonic() - start_time > + VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + @contextmanager + def acquire_read(self): + assert self._is_local_reader, "Only readers can acquire read" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.local_reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # wait for a while + time.sleep(RINGBUFFER_SLEEP_INTERVAL) + + # if we wait for a long time, we should warn the user + if (time.monotonic() - start_time > + VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.warning( + "No available block found in %s second. 
", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.local_reader_rank + 1] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + def enqueue(self, obj): + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write() as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write() as buf: + buf[0] = 0 # not overflow + buf[1:len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) + + def dequeue(self): + if self._is_local_reader: + with self.acquire_read() as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + recv = self.local_socket.recv() + obj = pickle.loads(recv) + elif self._is_remote_reader: + recv = self.remote_socket.recv() + obj = pickle.loads(recv) + else: + raise RuntimeError("Only readers can dequeue") + return obj + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + @staticmethod + def create_from_process_group(pg: ProcessGroup, + max_chunk_bytes, + max_chunks, + writer_rank=0) -> "MessageQueue": + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + global_ranks = dist.get_process_group_ranks(pg) + + from vllm.distributed.parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue + if group_rank == writer_rank: + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + dist.broadcast_object_list([handle], + src=global_ranks[writer_rank], + group=pg) + else: + recv = [None] + dist.broadcast_object_list(recv, + src=global_ranks[writer_rank], + group=pg) + handle = recv[0] # type: ignore + buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py new file mode 100644 index 0000000..765a0f9 --- /dev/null +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -0,0 +1,61 @@ +import os + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from vllm.platforms import current_platform + +if current_platform.is_tpu(): + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + + from vllm.executor import ray_utils + + +class TpuCommunicator: + + def __init__(self, group: ProcessGroup): + if not 
current_platform.is_tpu(): + self.disabled = True + return + self.disabled = False + + # NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node + # must be used together. Therefore, the local rank and world size can + # be simply calculated as follows. + global_rank = dist.get_rank(group) + global_world_size = dist.get_world_size(group) + + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + + local_world_size = global_world_size // num_nodes + local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + + pjrt.initialize_multiprocess(local_rank, local_world_size) + xr._init_world_size_ordinal() + + def all_reduce(self, x: torch.Tensor) -> torch.Tensor: + return xm.all_reduce(xm.REDUCE_SUM, x) + + def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor: + assert dim == -1, "TPUs only support dim=-1 for all-gather." + return xm.all_gather(x, dim=dim) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py new file mode 100644 index 0000000..da1ad93 --- /dev/null +++ b/vllm/distributed/parallel_state.py @@ -0,0 +1,1258 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. 
+""" +import contextlib +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import supports_custom_op +from ixformer.distributed import all_reduce +# from ixformer.distributed import all_reduce, send, recv, gather +import ixformer.distributed as ixfd +import vllm._custom_ops as ops + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + # looks like Python 3.8 does not understand `ReferenceType` + _groups[group.unique_name] = weakref.ref(group) # type: ignore + + +if supports_custom_op(): + + @torch.library.custom_op("vllm::inplace_all_reduce", + mutates_args=["tensor"]) + def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce_in_place(tensor) + + @inplace_all_reduce.register_fake + def _(tensor: torch.Tensor, group_name: str) -> None: + return + + @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) + def outplace_all_reduce(tensor: torch.Tensor, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." 
+ group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + @outplace_all_reduce.register_fake + def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + use_tpu_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + world_size: Optional[str] = None, + ): + use_pynccl = False # we do not use pynccl + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + self.is_multi_node = False if world_size is None else bool(world_size > torch.cuda.device_count()) + + for ranks in group_ranks: + # if self.is_multi_node: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # else: + # device_group = ixfd.new_group( + # ranks, shmsize=1024 if world_size<=1 else None, backend=torch_distributed_backend,) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. 
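The rank bookkeeping in the attribute comments above (rank vs. local_rank vs. rank_in_group) can be reproduced with a few lines. This is only a sketch of the table for one group of size 4 spanning two 2-GPU nodes; in the class itself local_rank is supplied by the launcher rather than derived.

    # one group of size 4 across two nodes with 2 GPUs each
    ranks = [0, 1, 2, 3]
    gpus_per_node = 2
    for rank in ranks:
        local_rank = rank % gpus_per_node    # which GPU to bind on this node
        rank_in_group = ranks.index(rank)    # position inside the coordinator's group
        print(rank, local_rank, rank_in_group)
    # (0, 0, 0), (1, 1, 1), (2, 0, 2), (3, 1, 3) -- matching the table above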
+ cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + if current_platform.is_cuda_alike(): + self.device = torch.device(f"cuda:{local_rank}") + else: + self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator + + # lazy import to avoid documentation build error + from vllm.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce) + from vllm.distributed.device_communicators.pynccl import ( + PyNcclCommunicator) + + self.pynccl_comm: Optional[PyNcclCommunicator] = None + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + from vllm.distributed.device_communicators.tpu_communicator import ( + TpuCommunicator) + self.tpu_communicator: Optional[TpuCommunicator] = None + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + + from vllm.distributed.device_communicators.shm_broadcast import ( + MessageQueue) + self.mq_broadcaster: Optional[MessageQueue] = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext( + ) if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # 
operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if not supports_custom_op(): + self._all_reduce_in_place(input_) + return input_ + + if self.tpu_communicator is not None and \ + not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self.tpu_communicator.all_reduce(input_) + + if self.ca_comm is not None and \ + not self.ca_comm.disabled and \ + self.ca_comm.should_custom_ar(input_): + return torch.ops.vllm.outplace_all_reduce( + input_, group_name=self.unique_name) + else: + torch.ops.vllm.inplace_all_reduce(input_, + group_name=self.unique_name) + return input_ + + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + ca_comm = self.ca_comm + assert ca_comm is not None + assert not ca_comm.disabled + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + + def _all_reduce_in_place(self, input_: torch.Tensor) -> None: + pynccl_comm = self.pynccl_comm + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) + elif input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + else: + # if self.is_multi_node: + # torch.distributed.all_reduce(input_, group=self.device_group) + # else: + # all_reduce(input_, group=self.device_group, async_op=True) + from ixformer.contrib.torch.extension.ixformer_torch.distributed import create_ixformer_group_from_pg + _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP = create_ixformer_group_from_pg(self.device_group) + all_reduce(input_, group=_IXFORMER_TENSOR_MODEL_PARALLEL_GROUP, async_op=True) + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. 
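The all_gather method below concatenates the per-rank tensors along an arbitrary dim by gathering into a new leading dimension and then moving it into place. The arithmetic can be checked on a single process with plain tensor ops; this sketch stands in torch.stack for the real all_gather_into_tensor call and does not touch any process group.

    import torch

    world_size, dim = 2, 1
    input_ = torch.arange(6).reshape(3, 2)              # each rank's tensor: [3, 2]
    input_size = input_.size()
    output = torch.stack([input_, input_ + 100])         # [world_size, 3, 2]
    output = output.movedim(0, dim)                      # [3, world_size, 2]
    output = output.reshape(input_size[:dim] +
                            (world_size * input_size[dim], ) +
                            input_size[dim + 1:])        # [3, 4]
    assert output.shape == (3, 4)
    # row i is rank 0's row i followed by rank 1's row i, i.e. concatenation along dim=1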
+ if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # # Gather. + # if self.is_multi_node: + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group, + ) + # else: + # gather(input_, + # gather_list, + # dst=self.ranks[dst], + # group=self.device_group, + # async_op=True) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + # if self.is_multi_node: + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + # else: + # ops.broadcast(input_, + # src=self.ranks[src], + # group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. 
+ """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." + ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. 
+ continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + # if self.is_multi_node: + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + # else: + # handle = ops.broadcast(tensor, + # src=self.ranks[src], + # group=group, + # async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + # if self.is_multi_node: + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + # else: + # handle = ops.broadcast( + # tensor, + # src=self.ranks[src], + # group=group, + # async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. 
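broadcast_tensor_dict above and the send/recv pair around it all rely on the same two-phase scheme: keys and TensorMetadata travel as pickled Python objects over the CPU (gloo) group, while the tensor payloads travel over the device group (or the CPU group when they are CPU tensors). A minimal local round trip of that split, using _split_tensor_dict and TensorMetadata defined earlier in this file and no process group at all:

    import torch

    d = {"hidden_states": torch.zeros(2, 4), "step": 7}
    metadata_list, tensor_list = _split_tensor_dict(d)
    # metadata_list == [("hidden_states", TensorMetadata("cpu", torch.float32, torch.Size([2, 4]))),
    #                   ("step", 7)]
    # a receiving rank rebuilds empty tensors from the metadata, then receives into them
    rebuilt = {}
    for key, value in metadata_list:
        if isinstance(value, TensorMetadata):
            rebuilt[key] = torch.empty(value.size, dtype=value.dtype, device=value.device)
        else:
            rebuilt[key] = value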
+ if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + # if self.is_multi_node: + torch.distributed.send( + tensor, + dst=self.ranks[dst], + group=group) + # else: + # send(tensor, + # dst=self.ranks[dst], + # group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + else: + # if self.is_multi_node: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + # else: + # recv(tensor, + # src=self.ranks[src], + # group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. 
+ """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + # if self.is_multi_node: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + # else: + # send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + # if self.is_multi_node: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + # else: + # recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, ("world group is not initialized") + return _WORLD + + +def init_world_group(ranks: List[int], local_rank: int, + backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=False, + use_tpu_communicator=False, + group_name="world", + world_size=len(ranks), + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + world_size: Optional[int] = None, +) -> GroupCoordinator: + if use_custom_allreduce is None: + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=True, + use_custom_allreduce=use_custom_allreduce, + use_tpu_communicator=True, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + world_size=world_size, + ) + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, ("tensor model parallel group is not initialized") + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a 
context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + with get_tp_group().graph_capture() as context, get_pp_group( + ).graph_capture(context): + yield context + + +logger = init_logger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + init_method=distributed_init_method, + world_size=world_size, + rank=rank) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + local_rank = envs.LOCAL_RANK + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. 
Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = (world_size // + tensor_model_parallel_size) + global _TP + assert _TP is None, ("tensor model parallel group is already initialized") + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp", + world_size=world_size) + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = (world_size // + pipeline_model_parallel_size) + global _PP + assert _PP is None, ( + "pipeline model parallel group is already initialized") + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + group_name="pp", + world_size=world_size) + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size, backend) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert (pp_world_size == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return (_TP is not None and _PP is not None) + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. 
+ + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). + """ + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + torch.distributed.broadcast_object_list([shm.name], + src=ranks[source_rank], + group=pg) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=ranks[source_rank], + group=pg) + name = recv[0] + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[:len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + torch.distributed.barrier(group=pg) + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + + return [x == 1 for x in is_in_the_same_node.tolist()] diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py new file mode 100644 index 0000000..8c94ef8 --- /dev/null +++ b/vllm/distributed/utils.py @@ -0,0 +1,86 @@ +# Copyright 2023 The vLLM team. 
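The vllm/distributed/utils.py file that starts here ends with get_pp_indices, which maps a pipeline rank to a contiguous [start_layer, end_layer) slice and gives any remainder to the last stage (unless VLLM_PP_LAYER_PARTITION overrides the split). A worked example with hypothetical numbers, mirroring the default branch of that helper:

    # 13 hidden layers over 4 pipeline stages, no VLLM_PP_LAYER_PARTITION override
    num_hidden_layers, pp_size = 13, 4
    layers_per_partition = num_hidden_layers // pp_size     # 3
    for pp_rank in range(pp_size):
        start_layer = pp_rank * layers_per_partition
        end_layer = start_layer + layers_per_partition
        if pp_rank == pp_size - 1:
            end_layer = num_hidden_layers                    # last stage absorbs the remainder
        print(pp_rank, (start_layer, end_layer))
    # (0, 3), (3, 6), (6, 9), (9, 13): the last stage gets 4 layers instead of 3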
+# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from typing import Sequence, Tuple + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +def get_pp_indices(num_hidden_layers: int, pp_rank: int, + pp_size: int) -> Tuple[int, int]: + """Try to evenly distribute layers across partitions. + If the number of layers is not divisible by the number of partitions, + the last partition will have the remaining layers. 
+ """ + partition_list_str = envs.VLLM_PP_LAYER_PARTITION + if partition_list_str is not None: + try: + partitions = [ + int(layer) for layer in partition_list_str.split(",") + ] + except ValueError as err: + raise ValueError("Invalid partition string: {}".format( + partition_list_str)) from err + if len(partitions) != pp_size: + raise ValueError(f"{len(partitions)=} does not match {pp_size=}.") + if sum(partitions) != num_hidden_layers: + raise ValueError( + f"{sum(partitions)=} does not match {num_hidden_layers=}.") + start_layer = sum(partitions[:pp_rank]) + end_layer = start_layer + partitions[pp_rank] + else: + layers_per_partition = num_hidden_layers // pp_size + start_layer = pp_rank * layers_per_partition + end_layer = start_layer + layers_per_partition + + if pp_rank == pp_size - 1: + end_layer = num_hidden_layers + + return (start_layer, end_layer) diff --git a/vllm/engine/__init__.py b/vllm/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/engine/__pycache__/__init__.cpython-310.pyc b/vllm/engine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..1312f8c Binary files /dev/null and b/vllm/engine/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/arg_utils.cpython-310.pyc b/vllm/engine/__pycache__/arg_utils.cpython-310.pyc new file mode 100644 index 0000000..cd5d308 Binary files /dev/null and b/vllm/engine/__pycache__/arg_utils.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/async_llm_engine.cpython-310.pyc b/vllm/engine/__pycache__/async_llm_engine.cpython-310.pyc new file mode 100644 index 0000000..9d8f931 Binary files /dev/null and b/vllm/engine/__pycache__/async_llm_engine.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/async_timeout.cpython-310.pyc b/vllm/engine/__pycache__/async_timeout.cpython-310.pyc new file mode 100644 index 0000000..941df68 Binary files /dev/null and b/vllm/engine/__pycache__/async_timeout.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/llm_engine.cpython-310.pyc b/vllm/engine/__pycache__/llm_engine.cpython-310.pyc new file mode 100644 index 0000000..38aed70 Binary files /dev/null and b/vllm/engine/__pycache__/llm_engine.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/metrics.cpython-310.pyc b/vllm/engine/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000..0964a20 Binary files /dev/null and b/vllm/engine/__pycache__/metrics.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/metrics_types.cpython-310.pyc b/vllm/engine/__pycache__/metrics_types.cpython-310.pyc new file mode 100644 index 0000000..c952429 Binary files /dev/null and b/vllm/engine/__pycache__/metrics_types.cpython-310.pyc differ diff --git a/vllm/engine/__pycache__/protocol.cpython-310.pyc b/vllm/engine/__pycache__/protocol.cpython-310.pyc new file mode 100644 index 0000000..e01213f Binary files /dev/null and b/vllm/engine/__pycache__/protocol.cpython-310.pyc differ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py new file mode 100644 index 0000000..4825f0f --- /dev/null +++ b/vllm/engine/arg_utils.py @@ -0,0 +1,1143 @@ +import argparse +import dataclasses +import json +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional, + Tuple, Type, Union) + +import torch + +import vllm.envs as envs +from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, + DeviceConfig, EngineConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ObservabilityConfig, + 
ParallelConfig, PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig) +from vllm.executor.executor_base import ExecutorBase +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.config import ( + maybe_register_config_serialize_by_value) +from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import FlexibleArgumentParser + +if TYPE_CHECKING: + from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup + +logger = init_logger(__name__) + +ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"] + +DEVICE_OPTIONS = [ + "auto", + "cuda", + "neuron", + "cpu", + "openvino", + "tpu", + "xpu", +] + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +def nullable_kvs(val: str) -> Optional[Mapping[str, int]]: + """Parses a string containing comma separate key [str] to value [int] + pairs into a dictionary. + + Args: + val: String value to be parsed. + + Returns: + Dictionary with parsed values. + """ + if len(val) == 0: + return None + + out_dict: Dict[str, int] = {} + for item in val.split(","): + kv_parts = [part.lower().strip() for part in item.split("=")] + if len(kv_parts) != 2: + raise argparse.ArgumentTypeError( + "Each item should be in the form KEY=VALUE") + key, value = kv_parts + + try: + parsed_value = int(value) + except ValueError as exc: + msg = f"Failed to parse value of item {key}={value}" + raise argparse.ArgumentTypeError(msg) from exc + + if key in out_dict and out_dict[key] != parsed_value: + raise argparse.ArgumentTypeError( + f"Conflicting values specified for key: {key}") + out_dict[key] = parsed_value + + return out_dict + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str = 'facebook/opt-125m' + served_model_name: Optional[Union[str, List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + config_format: str = 'auto' + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None + seed: int = 0 + max_model_len: Optional[int] = None + worker_use_ray: bool = False + # Note: Specifying a custom executor backend by passing a class + # is intended for expert use only. The API may change without + # notice. 
+ distributed_executor_backend: Optional[Union[str, + Type[ExecutorBase]]] = None + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None + block_size: int = 16 + enable_prefix_caching: bool = False + disable_sliding_window: bool = False + use_v2_block_manager: bool = True + swap_space: float = 4 # GiB + cpu_offload_gb: float = 0 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_logprobs: int = 20 # Default value for OpenAI Chat Completions API + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: Optional[bool] = None + max_context_len_to_capture: Optional[int] = None + max_seq_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + tokenizer_pool_size: int = 0 + # Note: Specifying a tokenizer pool by passing a class + # is intended for expert use only. The API may change without + # notice. + tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray" + tokenizer_pool_extra_config: Optional[dict] = None + limit_mm_per_prompt: Optional[Mapping[str, int]] = None + enable_lora: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 0 + fully_sharded_loras: bool = False + lora_extra_vocab_size: int = 256 + long_lora_scaling_factors: Optional[Tuple[float]] = None + lora_dtype: Optional[Union[str, torch.dtype]] = 'auto' + max_cpu_loras: Optional[int] = None + device: str = 'auto' + num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = True + ray_workers_use_nsight: bool = False + num_gpu_blocks_override: Optional[int] = None + num_lookahead_slots: int = 0 + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + preemption_mode: Optional[str] = None + + scheduler_delay_factor: float = 0.0 + enable_chunked_prefill: Optional[bool] = None + + guided_decoding_backend: str = 'outlines' + # Speculative decoding configuration. 
+ speculative_model: Optional[str] = None + speculative_model_quantization: Optional[str] = None + speculative_draft_tensor_parallel_size: Optional[int] = None + num_speculative_tokens: Optional[int] = None + speculative_disable_mqa_scorer: Optional[bool] = False + speculative_max_model_len: Optional[int] = None + speculative_disable_by_batch_size: Optional[int] = None + ngram_prompt_lookup_max: Optional[int] = None + ngram_prompt_lookup_min: Optional[int] = None + spec_decoding_acceptance_method: str = 'rejection_sampler' + typical_acceptance_sampler_posterior_threshold: Optional[float] = None + typical_acceptance_sampler_posterior_alpha: Optional[float] = None + qlora_adapter_name_or_path: Optional[str] = None + disable_logprobs_during_spec_decoding: Optional[bool] = None + + otlp_traces_endpoint: Optional[str] = None + collect_detailed_traces: Optional[str] = None + disable_async_output_proc: bool = False + override_neuron_config: Optional[Dict[str, Any]] = None + mm_processor_kwargs: Optional[Dict[str, Any]] = None + scheduling_policy: Literal["fcfs", "priority"] = "fcfs" + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # Model arguments + parser.add_argument( + '--model', + type=str, + default=EngineArgs.model, + help='Name or path of the huggingface model to use.') + parser.add_argument( + '--tokenizer', + type=nullable_str, + default=EngineArgs.tokenizer, + help='Name or path of the huggingface tokenizer to use. ' + 'If unspecified, model name or path will be used.') + parser.add_argument( + '--skip-tokenizer-init', + action='store_true', + help='Skip initialization of tokenizer and detokenizer') + parser.add_argument( + '--revision', + type=nullable_str, + default=None, + help='The specific model version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument( + '--code-revision', + type=nullable_str, + default=None, + help='The specific revision to use for the model code on ' + 'Hugging Face Hub. It can be a branch name, a tag name, or a ' + 'commit id. If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-revision', + type=nullable_str, + default=None, + help='Revision of the huggingface tokenizer to use. ' + 'It can be a branch name, a tag name, or a commit id. ' + 'If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow', 'mistral'], + help='The tokenizer mode.\n\n* "auto" will use the ' + 'fast tokenizer if available.\n* "slow" will ' + 'always use the slow tokenizer. 
\n* ' + '"mistral" will always use the `mistral_common` tokenizer.') + parser.add_argument('--trust-remote-code', + action='store_true', + help='Trust remote code from huggingface.') + parser.add_argument('--download-dir', + type=nullable_str, + default=EngineArgs.download_dir, + help='Directory to download and load the weights, ' + 'default to the default cache dir of ' + 'huggingface.') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=[f.value for f in LoadFormat], + help='The format of the model weights to load.\n\n' + '* "auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available.\n' + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading.\n' + '* "dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.\n' + '* "tensorizer" will load the weights using tensorizer from ' + 'CoreWeave. See the Tensorize vLLM Model script in the Examples ' + 'section for more information.\n' + '* "bitsandbytes" will load the weights using bitsandbytes ' + 'quantization.\n') + parser.add_argument( + '--config-format', + default=EngineArgs.config_format, + choices=[f.value for f in ConfigFormat], + help='The format of the model config to load.\n\n' + '* "auto" will try to load the config in hf format ' + 'if available else it will try to load in mistral format ') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='Data type for model weights and activations.\n\n' + '* "auto" will use FP16 precision for FP32 and FP16 models, and ' + 'BF16 precision for BF16 models.\n' + '* "half" for FP16. Recommended for AWQ quantization.\n' + '* "float16" is the same as "half".\n' + '* "bfloat16" for a balance between precision and range.\n' + '* "float" is shorthand for FP32 precision.\n' + '* "float32" for FP32 precision.') + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default=EngineArgs.kv_cache_dtype, + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=nullable_str, + default=None, + help='Path to the JSON file containing the KV cache ' + 'scaling factors. This should generally be supplied, when ' + 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' + 'default to 1.0, which may cause accuracy issues. ' + 'FP8_E5M2 (without scaling) is only supported on cuda version' + 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' + 'supported for common inference criteria.') + parser.add_argument('--max-model-len', + type=int, + default=EngineArgs.max_model_len, + help='Model context length. If unspecified, will ' + 'be automatically derived from the model config.') + parser.add_argument( + '--guided-decoding-backend', + type=str, + default='outlines', + choices=['outlines', 'lm-format-enforcer'], + help='Which engine will be used for guided decoding' + ' (JSON schema / regex etc) by default. 
Currently support ' + 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/noamgat/lm-format-enforcer.' + ' Can be overridden per request via guided_decoding_backend' + ' parameter.') + # Parallel arguments + parser.add_argument( + '--distributed-executor-backend', + choices=['ray', 'mp'], + default=EngineArgs.distributed_executor_backend, + help='Backend to use for distributed serving. When more than 1 GPU ' + 'is used, will be automatically set to "ray" if installed ' + 'or "mp" (multiprocessing) otherwise.') + parser.add_argument( + '--worker-use-ray', + action='store_true', + help='Deprecated, use --distributed-executor-backend=ray.') + parser.add_argument('--pipeline-parallel-size', + '-pp', + type=int, + default=EngineArgs.pipeline_parallel_size, + help='Number of pipeline stages.') + parser.add_argument('--tensor-parallel-size', + '-tp', + type=int, + default=EngineArgs.tensor_parallel_size, + help='Number of tensor parallel replicas.') + parser.add_argument( + '--max-parallel-loading-workers', + type=int, + default=EngineArgs.max_parallel_loading_workers, + help='Load model sequentially in multiple batches, ' + 'to avoid RAM OOM when using tensor ' + 'parallel and large models.') + parser.add_argument( + '--ray-workers-use-nsight', + action='store_true', + help='If specified, use nsight to profile Ray workers.') + # KV cache arguments + parser.add_argument('--block-size', + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32], + help='Token block size for contiguous chunks of ' + 'tokens. This is ignored on neuron devices and ' + 'set to max-model-len') + + parser.add_argument('--enable-prefix-caching', + action='store_true', + help='Enables automatic prefix caching.') + parser.add_argument('--disable-sliding-window', + action='store_true', + help='Disables sliding window, ' + 'capping to sliding window size') + parser.add_argument( + '--use-v2-block-manager', + default=EngineArgs.use_v2_block_manager, + action='store_true', + help='Use BlockSpaceMangerV2. By default this is set to True. ' + 'Set to False to use BlockSpaceManagerV1') + parser.add_argument( + '--num-lookahead-slots', + type=int, + default=EngineArgs.num_lookahead_slots, + help='Experimental scheduling config necessary for ' + 'speculative decoding. This will be replaced by ' + 'speculative config in the future; it is present ' + 'to enable correctness tests until then.') + + parser.add_argument('--seed', + type=int, + default=EngineArgs.seed, + help='Random seed for operations.') + parser.add_argument('--swap-space', + type=float, + default=EngineArgs.swap_space, + help='CPU swap space size (GiB) per GPU.') + parser.add_argument( + '--cpu-offload-gb', + type=float, + default=0, + help='The space in GiB to offload to CPU, per GPU. ' + 'Default is 0, which means no offloading. Intuitively, ' + 'this argument can be seen as a virtual way to increase ' + 'the GPU memory size. For example, if you have one 24 GB ' + 'GPU and set this to 10, virtually you can think of it as ' + 'a 34 GB GPU. Then you can load a 13B model with BF16 weight,' + 'which requires at least 26GB GPU memory. Note that this ' + 'requires fast CPU-GPU interconnect, as part of the model is' + 'loaded from CPU memory to GPU memory on the fly in each ' + 'model forward pass.') + parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='The fraction of GPU memory to be used for the model ' + 'executor, which can range from 0 to 1. 
For example, a value of ' + '0.5 would imply 50%% GPU memory utilization. If unspecified, ' + 'will use the default value of 0.9.') + parser.add_argument( + '--num-gpu-blocks-override', + type=int, + default=None, + help='If specified, ignore GPU profiling result and use this number' + 'of GPU blocks. Used for testing preemption.') + parser.add_argument('--max-num-batched-tokens', + type=int, + default=EngineArgs.max_num_batched_tokens, + help='Maximum number of batched tokens per ' + 'iteration.') + parser.add_argument('--max-num-seqs', + type=int, + default=EngineArgs.max_num_seqs, + help='Maximum number of sequences per iteration.') + parser.add_argument( + '--max-logprobs', + type=int, + default=EngineArgs.max_logprobs, + help=('Max number of log probs to return logprobs is specified in' + ' SamplingParams.')) + parser.add_argument('--disable-log-stats', + action='store_true', + help='Disable logging statistics.') + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=nullable_str, + choices=[*QUANTIZATION_METHODS, None], + default=EngineArgs.quantization, + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument('--rope-scaling', + default=None, + type=json.loads, + help='RoPE scaling configuration in JSON format. ' + 'For example, {"type":"dynamic","factor":2.0}') + parser.add_argument('--rope-theta', + default=None, + type=float, + help='RoPE theta. Use with `rope_scaling`. In ' + 'some cases, changing the RoPE theta improves the ' + 'performance of the scaled model.') + parser.add_argument('--enforce-eager', + action='store_true', + help='Always use eager-mode PyTorch. If False, ' + 'will use eager mode and CUDA graph in hybrid ' + 'for maximal performance and flexibility.') + parser.add_argument('--max-context-len-to-capture', + type=int, + default=EngineArgs.max_context_len_to_capture, + help='Maximum context length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + '(DEPRECATED. Use --max-seq-len-to-capture instead' + ')') + parser.add_argument('--max-seq-len-to-capture', + type=int, + default=EngineArgs.max_seq_len_to_capture, + help='Maximum sequence length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode. ' + 'Additionally for encoder-decoder models, if the ' + 'sequence length of the encoder input is larger ' + 'than this, we fall back to the eager mode.') + parser.add_argument('--disable-custom-all-reduce', + action='store_true', + default=EngineArgs.disable_custom_all_reduce, + help='See ParallelConfig.') + parser.add_argument('--tokenizer-pool-size', + type=int, + default=EngineArgs.tokenizer_pool_size, + help='Size of tokenizer pool to use for ' + 'asynchronous tokenization. If 0, will ' + 'use synchronous tokenization.') + parser.add_argument('--tokenizer-pool-type', + type=str, + default=EngineArgs.tokenizer_pool_type, + help='Type of tokenizer pool to use for ' + 'asynchronous tokenization. Ignored ' + 'if tokenizer_pool_size is 0.') + parser.add_argument('--tokenizer-pool-extra-config', + type=nullable_str, + default=EngineArgs.tokenizer_pool_extra_config, + help='Extra config for tokenizer pool. 
' + 'This should be a JSON string that will be ' + 'parsed into a dictionary. Ignored if ' + 'tokenizer_pool_size is 0.') + + # Multimodal related configs + parser.add_argument( + '--limit-mm-per-prompt', + type=nullable_kvs, + default=EngineArgs.limit_mm_per_prompt, + # The default value is given in + # MultiModalRegistry.init_mm_limits_per_prompt + help=('For each multimodal plugin, limit how many ' + 'input instances to allow for each prompt. ' + 'Expects a comma-separated list of items, ' + 'e.g.: `image=16,video=2` allows a maximum of 16 ' + 'images and 2 videos per prompt. Defaults to 1 for ' + 'each modality.')) + parser.add_argument( + '--mm-processor-kwargs', + default=None, + type=json.loads, + help=('Overrides for the multimodal input mapping/processing,' + 'e.g., image processor. For example: {"num_crops": 4}.')) + + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='If True, enable handling of LoRA adapters.') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='Max number of LoRAs in a single batch.') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) + parser.add_argument( + '--long-lora-scaling-factors', + type=nullable_str, + default=EngineArgs.long_lora_scaling_factors, + help=('Specify multiple scaling factors (which can ' + 'be different from base model scaling factor ' + '- see eg. Long LoRA) to allow for multiple ' + 'LoRA adapters trained with those scaling ' + 'factors to be used at the same time. If not ' + 'specified, only adapters trained with the ' + 'base model scaling factor are allowed.')) + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) + parser.add_argument( + '--fully-sharded-loras', + action='store_true', + help=('By default, only half of the LoRA computation is ' + 'sharded with tensor parallelism. ' + 'Enabling this will use the fully sharded layers. 
' + 'At high sequence length, max rank or ' + 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') + parser.add_argument("--device", + type=str, + default=EngineArgs.device, + choices=DEVICE_OPTIONS, + help='Device type for vLLM execution.') + parser.add_argument('--num-scheduler-steps', + type=int, + default=1, + help=('Maximum number of forward steps per ' + 'scheduler call.')) + + parser.add_argument( + '--multi-step-stream-outputs', + action=StoreBoolean, + default=EngineArgs.multi_step_stream_outputs, + nargs="?", + const="True", + help='If False, then multi-step will stream outputs at the end ' + 'of all steps') + parser.add_argument( + '--scheduler-delay-factor', + type=float, + default=EngineArgs.scheduler_delay_factor, + help='Apply a delay (of delay factor multiplied by previous ' + 'prompt latency) before scheduling next prompt.') + parser.add_argument( + '--enable-chunked-prefill', + action=StoreBoolean, + default=EngineArgs.enable_chunked_prefill, + nargs="?", + const="True", + help='If set, the prefill requests can be chunked based on the ' + 'max_num_batched_tokens.') + + parser.add_argument( + '--speculative-model', + type=nullable_str, + default=EngineArgs.speculative_model, + help= + 'The name of the draft model to be used in speculative decoding.') + # Quantization settings for speculative model. + parser.add_argument( + '--speculative-model-quantization', + type=nullable_str, + choices=[*QUANTIZATION_METHODS, None], + default=EngineArgs.speculative_model_quantization, + help='Method used to quantize the weights of speculative model. ' + 'If None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument( + '--num-speculative-tokens', + type=int, + default=EngineArgs.num_speculative_tokens, + help='The number of speculative tokens to sample from ' + 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-disable-mqa-scorer', + action='store_true', + help= + 'If set to True, the MQA scorer will be disabled in speculative ' + ' and fall back to batch expansion') + parser.add_argument( + '--speculative-draft-tensor-parallel-size', + '-spec-draft-tp', + type=int, + default=EngineArgs.speculative_draft_tensor_parallel_size, + help='Number of tensor parallel replicas for ' + 'the draft model in speculative decoding.') + + parser.add_argument( + '--speculative-max-model-len', + type=int, + default=EngineArgs.speculative_max_model_len, + help='The maximum sequence length supported by the ' + 'draft model. 
Sequences over this length will skip ' + 'speculation.') + + parser.add_argument( + '--speculative-disable-by-batch-size', + type=int, + default=EngineArgs.speculative_disable_by_batch_size, + help='Disable speculative decoding for new incoming requests ' + 'if the number of enqueue requests is larger than this value.') + + parser.add_argument( + '--ngram-prompt-lookup-max', + type=int, + default=EngineArgs.ngram_prompt_lookup_max, + help='Max size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--ngram-prompt-lookup-min', + type=int, + default=EngineArgs.ngram_prompt_lookup_min, + help='Min size of window for ngram prompt lookup in speculative ' + 'decoding.') + + parser.add_argument( + '--spec-decoding-acceptance-method', + type=str, + default=EngineArgs.spec_decoding_acceptance_method, + choices=['rejection_sampler', 'typical_acceptance_sampler'], + help='Specify the acceptance method to use during draft token ' + 'verification in speculative decoding. Two types of acceptance ' + 'routines are supported: ' + '1) RejectionSampler which does not allow changing the ' + 'acceptance rate of draft tokens, ' + '2) TypicalAcceptanceSampler which is configurable, allowing for ' + 'a higher acceptance rate at the cost of lower quality, ' + 'and vice versa.') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-threshold', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_threshold, + help='Set the lower bound threshold for the posterior ' + 'probability of a token to be accepted. This threshold is ' + 'used by the TypicalAcceptanceSampler to make sampling decisions ' + 'during speculative decoding. Defaults to 0.09') + + parser.add_argument( + '--typical-acceptance-sampler-posterior-alpha', + type=float, + default=EngineArgs.typical_acceptance_sampler_posterior_alpha, + help='A scaling factor for the entropy-based threshold for token ' + 'acceptance in the TypicalAcceptanceSampler. Typically defaults ' + 'to sqrt of --typical-acceptance-sampler-posterior-threshold ' + 'i.e. 0.3') + + parser.add_argument( + '--disable-logprobs-during-spec-decoding', + action=StoreBoolean, + default=EngineArgs.disable_logprobs_during_spec_decoding, + nargs="?", + const="True", + help='If set to True, token log probabilities are not returned ' + 'during speculative decoding. If set to False, log probabilities ' + 'are returned according to the settings in SamplingParams. If ' + 'not specified, it defaults to True. Disabling log probabilities ' + 'during speculative decoding reduces latency by skipping logprob ' + 'calculation in proposal sampling, target sampling, and after ' + 'accepted tokens are determined.') + + parser.add_argument('--model-loader-extra-config', + type=nullable_str, + default=EngineArgs.model_loader_extra_config, + help='Extra config for model loader. ' + 'This will be passed to the model loader ' + 'corresponding to the chosen load_format. ' + 'This should be a JSON string that will be ' + 'parsed into a dictionary.') + parser.add_argument( + '--ignore-patterns', + action="append", + type=str, + default=[], + help="The pattern(s) to ignore when loading the model." 
+ "Default to 'original/**/*' to avoid repeated loading of llama's " + "checkpoints.") + parser.add_argument( + '--preemption-mode', + type=str, + default=None, + help='If \'recompute\', the engine performs preemption by ' + 'recomputing; If \'swap\', the engine performs preemption by ' + 'block swapping.') + + parser.add_argument( + "--served-model-name", + nargs="+", + type=str, + default=None, + help="The model name(s) used in the API. If multiple " + "names are provided, the server will respond to any " + "of the provided names. The model name in the model " + "field of a response will be the first name in this " + "list. If not specified, the model name will be the " + "same as the `--model` argument. Noted that this name(s)" + "will also be used in `model_name` tag content of " + "prometheus metrics, if multiple names provided, metrics" + "tag will take the first one.") + parser.add_argument('--qlora-adapter-name-or-path', + type=str, + default=None, + help='Name or path of the QLoRA adapter.') + + parser.add_argument( + '--otlp-traces-endpoint', + type=str, + default=None, + help='Target URL to which OpenTelemetry traces will be sent.') + parser.add_argument( + '--collect-detailed-traces', + type=str, + default=None, + help="Valid choices are " + + ",".join(ALLOWED_DETAILED_TRACE_MODULES) + + ". It makes sense to set this only if --otlp-traces-endpoint is" + " set. If set, it will collect detailed traces for the specified " + "modules. This involves use of possibly costly and or blocking " + "operations and hence might have a performance impact.") + + parser.add_argument( + '--disable-async-output-proc', + action='store_true', + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") + parser.add_argument( + '--override-neuron-config', + type=json.loads, + default=None, + help="Override or set neuron device configuration. " + "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") + + parser.add_argument( + '--scheduling-policy', + choices=['fcfs', 'priority'], + default="fcfs", + help='The scheduling policy to use. "fcfs" (first come first served' + ', i.e. requests are handled in order of arrival; default) ' + 'or "priority" (requests are handled based on given ' + 'priority (lower value means earlier handling) and time of ' + 'arrival deciding any ties).') + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace): + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. 
+ engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_model_config(self) -> ModelConfig: + return ModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=True, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + override_neuron_config=self.override_neuron_config, + config_format=self.config_format, + mm_processor_kwargs=self.mm_processor_kwargs, + ) + + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + def create_engine_config(self) -> EngineConfig: + # gguf file needs a specific model loader and doesn't use hf_repo + if check_gguf_file(self.model): + self.quantization = self.load_format = "gguf" + + # bitsandbytes quantization needs a specific model loader + # so we make sure the quant method and the load format are consistent + if (self.quantization == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.load_format != "bitsandbytes": + raise ValueError( + "BitsAndBytes quantization and QLoRA adapter only support " + f"'bitsandbytes' load format, but got {self.load_format}") + + if (self.load_format == "bitsandbytes" or + self.qlora_adapter_name_or_path is not None) and \ + self.quantization != "bitsandbytes": + raise ValueError( + "BitsAndBytes load format and QLoRA adapter only support " + f"'bitsandbytes' quantization, but got {self.quantization}") + + assert self.cpu_offload_gb >= 0, ( + "CPU offload space must be non-negative" + f", but got {self.cpu_offload_gb}") + + device_config = DeviceConfig(device=self.device) + model_config = self.create_model_config() + + if model_config.is_multimodal_model: + if self.enable_prefix_caching: + logger.warning( + "--enable-prefix-caching is currently not " + "supported for multimodal models and has been disabled.") + self.enable_prefix_caching = False + + maybe_register_config_serialize_by_value(self.trust_remote_code) + + cache_config = CacheConfig( + block_size=self.block_size if self.device != "neuron" else + self.max_model_len, # neuron needs block_size = max_model_len + gpu_memory_utilization=self.gpu_memory_utilization, + swap_space=self.swap_space, + cache_dtype=self.kv_cache_dtype, + is_attention_free=model_config.is_attention_free, + num_gpu_blocks_override=self.num_gpu_blocks_override, + sliding_window=model_config.get_sliding_window(), + enable_prefix_caching=self.enable_prefix_caching, + cpu_offload_gb=self.cpu_offload_gb, + ) + parallel_config = ParallelConfig( + pipeline_parallel_size=self.pipeline_parallel_size, + tensor_parallel_size=self.tensor_parallel_size, + 
worker_use_ray=self.worker_use_ray, + max_parallel_loading_workers=self.max_parallel_loading_workers, + disable_custom_all_reduce=True, + tokenizer_pool_config=TokenizerPoolConfig.create_config( + self.tokenizer_pool_size, + self.tokenizer_pool_type, + self.tokenizer_pool_extra_config, + ), + ray_workers_use_nsight=self.ray_workers_use_nsight, + distributed_executor_backend=self.distributed_executor_backend) + + max_model_len = model_config.max_model_len + use_long_context = max_model_len > 32768 + + if self.enable_chunked_prefill is None: + # If not explicitly set, enable chunked prefill by default for + # long context (> 32K) models. This is to avoid OOM errors in the + # initial memory profiling phase. + + # Chunked prefill is currently disabled for multimodal models by + # default. + if use_long_context and not model_config.is_multimodal_model: + is_gpu = device_config.device_type == "cuda" + use_sliding_window = (model_config.get_sliding_window() + is not None) + use_spec_decode = self.speculative_model is not None + if (is_gpu and not use_sliding_window and not use_spec_decode + and not self.enable_lora + and not self.enable_prompt_adapter): + self.enable_chunked_prefill = True + logger.warning( + "Chunked prefill is enabled by default for models with " + "max_model_len > 32K. Currently, chunked prefill might " + "not work with some features or models. If you " + "encounter any issues, please disable chunked prefill " + "by setting --enable-chunked-prefill=False.") + if self.enable_chunked_prefill is None: + self.enable_chunked_prefill = False + + if not self.enable_chunked_prefill and use_long_context: + logger.warning( + "The model has a long context length (%s). This may cause OOM " + "errors during the initial memory profiling phase, or result " + "in low performance due to small KV cache space. Consider " + "setting --max-model-len to a smaller value.", max_model_len) + + if self.num_scheduler_steps > 1 and not self.use_v2_block_manager: + self.use_v2_block_manager = True + logger.warning( + "Enabled BlockSpaceManagerV2 because it is " + "required for multi-step (--num-scheduler-steps > 1)") + + speculative_config = SpeculativeConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + target_dtype=self.dtype, + speculative_model=self.speculative_model, + speculative_model_quantization = \ + self.speculative_model_quantization, + speculative_draft_tensor_parallel_size = \ + self.speculative_draft_tensor_parallel_size, + num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, + speculative_disable_by_batch_size=self. + speculative_disable_by_batch_size, + speculative_max_model_len=self.speculative_max_model_len, + enable_chunked_prefill=self.enable_chunked_prefill, + use_v2_block_manager=self.use_v2_block_manager, + disable_log_stats=self.disable_log_stats, + ngram_prompt_lookup_max=self.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=self.ngram_prompt_lookup_min, + draft_token_acceptance_method=\ + self.spec_decoding_acceptance_method, + typical_acceptance_sampler_posterior_threshold=self. + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=self. 
+ typical_acceptance_sampler_posterior_alpha, + disable_logprobs=self.disable_logprobs_during_spec_decoding, + ) + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if self.num_scheduler_steps > 1: + if speculative_config is not None: + raise ValueError("Speculative decoding is not supported with " + "multi-step (--num-scheduler-steps > 1)") + if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: + raise ValueError("Multi-Step Chunked-Prefill is not supported " + "for pipeline-parallel-size > 1") + + # make sure num_lookahead_slots is set the higher value depending on + # if we are using speculative decoding or multi-step + num_lookahead_slots = max(self.num_lookahead_slots, + self.num_scheduler_steps - 1) + num_lookahead_slots = num_lookahead_slots \ + if speculative_config is None \ + else speculative_config.num_lookahead_slots + + scheduler_config = SchedulerConfig( + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + use_v2_block_manager=self.use_v2_block_manager, + num_lookahead_slots=num_lookahead_slots, + delay_factor=self.scheduler_delay_factor, + enable_chunked_prefill=self.enable_chunked_prefill, + embedding_mode=model_config.embedding_mode, + is_multimodal_model=model_config.is_multimodal_model, + preemption_mode=self.preemption_mode, + num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER + and parallel_config.use_ray), + policy=self.scheduling_policy, + ) + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + long_lora_scaling_factors=self.long_lora_scaling_factors, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + + if self.qlora_adapter_name_or_path is not None and \ + self.qlora_adapter_name_or_path != "": + if self.model_loader_extra_config is None: + self.model_loader_extra_config = {} + self.model_loader_extra_config[ + "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path + + load_config = self.create_load_config() + + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + + decoding_config = DecodingConfig( + guided_decoding_backend=self.guided_decoding_backend) + + detailed_trace_modules = [] + if self.collect_detailed_traces is not None: + detailed_trace_modules = self.collect_detailed_traces.split(",") + for m in detailed_trace_modules: + if m not in ALLOWED_DETAILED_TRACE_MODULES: + raise ValueError( + f"Invalid module {m} in collect_detailed_traces. 
" + f"Valid modules are {ALLOWED_DETAILED_TRACE_MODULES}") + observability_config = ObservabilityConfig( + otlp_traces_endpoint=self.otlp_traces_endpoint, + collect_model_forward_time="model" in detailed_trace_modules + or "all" in detailed_trace_modules, + collect_model_execute_time="worker" in detailed_trace_modules + or "all" in detailed_trace_modules, + ) + + if (model_config.get_sliding_window() is not None + and scheduler_config.chunked_prefill_enabled + and not scheduler_config.use_v2_block_manager): + raise ValueError( + "Chunked prefill is not supported with sliding window. " + "Set --disable-sliding-window to disable sliding window.") + + return EngineConfig( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config, + ) + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + disable_log_requests: bool = False + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser, + async_args_only: bool = False) -> FlexibleArgumentParser: + if not async_args_only: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--disable-log-requests', + action='store_true', + help='Disable logging requests.') + return parser + + +class StoreBoolean(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + if values.lower() == "true": + setattr(namespace, self.dest, True) + elif values.lower() == "false": + setattr(namespace, self.dest, False) + else: + raise ValueError(f"Invalid boolean value: {values}. 
" + "Expected 'true' or 'false'.") + + +# These functions are used by sphinx to build the documentation +def _engine_args_parser(): + return EngineArgs.add_cli_args(FlexibleArgumentParser()) + + +def _async_engine_args_parser(): + return AsyncEngineArgs.add_cli_args(FlexibleArgumentParser(), + async_args_only=True) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py new file mode 100644 index 0000000..aeb48c3 --- /dev/null +++ b/vllm/engine/async_llm_engine.py @@ -0,0 +1,1323 @@ +import asyncio +import time +import weakref +from functools import partial +from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable, + List, Mapping, Optional, Set, Tuple, Type, Union, overload) +from weakref import ReferenceType + +import vllm.envs as envs +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.core.scheduler import SchedulerOutputs +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_timeout import asyncio_timeout +from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState +from vllm.engine.metrics_types import StatLoggerBase +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutorAsync +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptType, TokensPrompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, + RequestOutput) +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import ExecuteModelRequest +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (collect_from_async_generator, deprecate_kwargs, + random_uuid, weak_bind) + +logger = init_logger(__name__) +ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _log_task_completion(task: asyncio.Task, + error_callback: Callable[[Exception], None]) -> None: + """This function is only intended for the `engine.run_engine_loop()` task. + + In particular, that task runs a `while True` loop that can only exit if + there is an exception. + """ + + exception = None + try: + return_value = task.result() + raise AssertionError( + f"The engine background task should never finish without an " + f"exception. {return_value}") + except asyncio.exceptions.CancelledError: + # We assume that if the task is cancelled, we are gracefully shutting + # down. This should only happen on program exit. + logger.info("Engine is gracefully shutting down.") + except Exception as e: + exception = e + logger.error("Engine background task failed", exc_info=e) + error_callback(exception) + raise AsyncEngineDeadError( + "Task finished unexpectedly. This should never happen! " + "Please open an issue on Github. 
See stack trace above for the " + "actual cause.") from e + + +STOP_ITERATION = Exception() # Sentinel + + +class AsyncStream: + """A stream of RequestOutputs or EmbeddingRequestOutputs for a request + that can be iterated over asynchronously via an async generator.""" + + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: + self.request_id = request_id + self._cancel = cancel + self._queue: asyncio.Queue = asyncio.Queue() + self._finished = False + + def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, + Exception]) -> None: + if not self._finished: + self._queue.put_nowait(item) + + def finish( + self, + exception: Optional[Union[BaseException, Type[BaseException]]] = None, + ) -> None: + if not self._finished: + self._finished = True + self._queue.put_nowait( + exception if self._is_raisable(exception) else STOP_ITERATION) + + @property + def finished(self) -> bool: + return self._finished + + async def generator( + self + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + try: + while True: + result = await self._queue.get() + if self._is_raisable(result): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel(self.request_id) + raise asyncio.CancelledError from None + + @staticmethod + def _is_raisable(value: Any): + return isinstance(value, BaseException) or \ + (isinstance(value, type) and \ + issubclass(value, BaseException)) + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, + dict]] = asyncio.Queue() + self.new_requests_event = asyncio.Event() + + def __contains__(self, item): + return item in self._request_streams + + def __len__(self) -> int: + return len(self._request_streams) + + def propagate_exception(self, + exc: Exception, + request_id: Optional[str] = None) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self.abort_request(request_id, exception=exc) + else: + # NB: tuple() used here because self.abort_request pops the stream + # out of self._request_streams, so we can't iterate on it directly + for rid in tuple(self._request_streams.keys()): + self.abort_request(rid, exception=exc) + + def process_request_output(self, + request_output: Union[RequestOutput, + EmbeddingRequestOutput], + *, + verbose: bool = False) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + finished = request_output.finished + + if finished: + stream = self._request_streams.pop(request_id, None) + else: + stream = self._request_streams.get(request_id) + # Guard against a KeyError which can occur if the request was aborted + # while the output was generated + if stream is not None: + stream.put(request_output) + if finished: + stream.finish() + + if verbose and finished: + logger.info("Finished request %s.", request_id) + + def process_exception(self, + request_id: str, + exception: BaseException, + *, + verbose: bool = False) -> None: + """Propagate an exception from the engine.""" + if verbose: + logger.info("Finished request %s.", request_id) + self.abort_request(request_id, exception=exception) + + def add_request(self, + request_id: str, + *, + verbose: bool = False, + **engine_add_request_kwargs) -> AsyncStream: + """Add a request 
to be sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + abort_request = partial(self.abort_request, verbose=verbose) + stream = AsyncStream(request_id, abort_request) + self._new_requests.put_nowait((stream, { + "request_id": request_id, + **engine_add_request_kwargs + })) + + self.new_requests_event.set() + + if verbose: + logger.info("Added request %s.", request_id) + + return stream + + def abort_request(self, + request_id: str, + *, + exception: Optional[Union[BaseException, + Type[BaseException]]] = None, + verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info("Aborted request %s.", request_id) + + self._aborted_requests.put_nowait(request_id) + + stream = self._request_streams.pop(request_id, None) + if stream is not None: + stream.finish(exception=exception) + + def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[str] = set() + + while not self._aborted_requests.empty(): + request_id = self._aborted_requests.get_nowait() + finished_requests.add(request_id) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + request_id = stream.request_id + if request_id in finished_requests: + # The request has already been aborted. + stream.finish(asyncio.CancelledError) + finished_requests.discard(request_id) + else: + self._request_streams[request_id] = stream + new_requests.append(new_request) + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + if not self.has_new_requests(): + await self.new_requests_event.wait() + self.new_requests_event.clear() + + def has_new_requests(self): + return not self._new_requests.empty() + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def step_async( + self, virtual_engine: int + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. 
+ if not self._has_remaining_steps(seq_group_metadata_list): + + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + virtual_engine=virtual_engine, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + # Execute the model. + outputs = await self.model_executor.execute_model_async( + execute_model_req) + + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # Clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + + # is_first_step_output is True only when the num_steps of all + # the sequences are 1. When the num_steps > 1, + # multi_step_model_runner does the first-step output append. 
+ is_first_step_output: bool = False if not seq_group_metadata_list \ + else seq_group_metadata_list[0].state.num_steps == 1 + + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True, + is_first_step_output=is_first_step_output) + + if outputs and allow_async_output_proc: + assert len( + outputs + ) == 1, "Async postprocessor expects only a single output set" + self._advance_to_next_step( + outputs[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(ctx=ctx) + + # Log stats. + self.do_log_stats(scheduler_outputs, outputs) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: + # Multi-step case + return ctx.request_outputs + + if not self.has_unfinished_requests(): + # Drain async postprocessor (if exists) + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + assert len(ctx.output_queue) == 0 + + return ctx.request_outputs + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Stop the remote worker execution loop.""" + await self.model_executor.stop_remote_worker_execution_loop_async() + + @overload # DEPRECATED + async def add_request_async( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @overload + async def add_request_async( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... 
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request_async( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + """Async version of :meth:`add_request`.""" + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + if arrival_time is None: + arrival_time = time.time() + + preprocessed_inputs = await self.input_preprocessor.preprocess_async( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + # Guided decoding has an async implementation for building logits + # processors in a separate threadpool. + # We want to invoke that here instead of using the blocking + # implementation in the LLMEngine + params = await build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=self.get_tokenizer(lora_request), + default_guided_backend=self.decoding_config. + guided_decoding_backend) + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + async def check_health_async(self) -> None: + if self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + +async def build_guided_decoding_logits_processor_async( + sampling_params: SamplingParams, tokenizer: AnyTokenizer, + default_guided_backend: str) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Modifies sampling params in-place and returns + the modified sampling params.""" + if (guided_decoding := sampling_params.guided_decoding) is None: + return sampling_params + + logger.debug("Building guided decoding logits processor. " + "Params: %s", guided_decoding) + + guided_decoding.backend = guided_decoding.backend or default_guided_backend + + processor = await get_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + + if processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append(processor) + + # Unset guided decoding params after constructing the lp from them + sampling_params.guided_decoding = None + + return sampling_params + + +class AsyncLLMEngine: + """An asynchronous wrapper for :class:`LLMEngine`. 
+ + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. + + Args: + log_requests: Whether to log the requests. + start_engine_loop: If True, the background task to run the engine + will be automatically started in the generate call. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. + """ + + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine + + def __init__(self, + *args, + log_requests: bool = True, + start_engine_loop: bool = True, + **kwargs) -> None: + self.log_requests = log_requests + self.engine = self._engine_class(*args, **kwargs) + + # This ensures quick processing of request outputs + # so the append to asyncio queues is not delayed, + # especially for multi-step. + self.use_process_request_outputs_callback = ( + self.engine.model_config.use_async_output_proc) + + if self.use_process_request_outputs_callback: + self.engine.process_request_outputs_callback = \ + weak_bind(self.process_request_outputs) + + self.background_loop: Optional[asyncio.Future] = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded: Optional[asyncio.Task] = None + self.start_engine_loop = start_engine_loop + self._errored_with: Optional[BaseException] = None + + # Lazy initialized fields + self._request_tracker: RequestTracker + + def __del__(self): + if rt := getattr(self, "request_tracker", None): + # Wake up engine loop so that it will exit cleanly + rt.new_requests_event.set() + + @classmethod + def _get_executor_cls( + cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]: + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorAsyncBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorAsyncBase. 
Got {distributed_executor_backend}.") + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutorAsync + executor_class = NeuronExecutorAsync + elif engine_config.device_config.device_type == "tpu": + if distributed_executor_backend == "ray": + from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync + executor_class = RayTPUExecutorAsync + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutorAsync + executor_class = TPUExecutorAsync + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutorAsync + executor_class = CPUExecutorAsync + elif engine_config.device_config.device_type == "openvino": + assert distributed_executor_backend is None, ( + "Distributed execution is not supported with " + "the OpenVINO backend.") + from vllm.executor.openvino_executor import OpenVINOExecutorAsync + executor_class = OpenVINOExecutorAsync + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend is None: + from vllm.executor.xpu_executor import XPUExecutorAsync + executor_class = XPUExecutorAsync + elif distributed_executor_backend == "ray": + from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync + executor_class = RayXPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_xpu_executor import ( + MultiprocessingXPUExecutorAsync) + executor_class = MultiprocessingXPUExecutorAsync + else: + raise RuntimeError( + "Not supported distributed execution model on XPU device.") + elif distributed_executor_backend == "ray": + from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync + executor_class = RayGPUExecutorAsync + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutorAsync) + executor_class = MultiprocessingGPUExecutorAsync + else: + from vllm.executor.gpu_executor import GPUExecutorAsync + executor_class = GPUExecutorAsync + return executor_class + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + if engine_config is None: + engine_config = engine_args.create_engine_config() + + executor_class = cls._get_executor_cls(engine_config) + + if executor_class.uses_ray: + initialize_ray_cluster(engine_config.parallel_config) + + # Create the async LLM engine. 
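+ # engine_config.to_dict() flattens the EngineConfig into the individual + # *_config keyword arguments expected by the underlying LLMEngine constructor.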
+ engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + + @property + def is_running(self) -> bool: + return (self.background_loop is not None + and self._background_loop_unshielded is not None + and not self._background_loop_unshielded.done()) + + @property + def is_stopped(self) -> bool: + return self.errored or (self.background_loop is not None and + self._background_loop_unshielded is not None + and self._background_loop_unshielded.done()) + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + def set_errored(self, exc: Exception) -> None: + self._errored_with = exc + + def _error_callback(self, exc: Exception) -> None: + self.set_errored(exc) + self._request_tracker.propagate_exception(exc) + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return await (self.engine.get_tokenizer_group(). + get_lora_tokenizer_async(lora_request)) + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.errored: + raise AsyncEngineDeadError( + "Background loop has errored already.") from self._errored_with + if self.is_running: + raise RuntimeError("Background loop is already running.") + # Initialize the RequestTracker here so it uses the right event loop. + self._request_tracker = RequestTracker() + + self._background_loop_unshielded = asyncio.get_event_loop( + ).create_task(self.run_engine_loop(weakref.ref(self))) + self._background_loop_unshielded.add_done_callback( + partial(_log_task_completion, error_callback=self._error_callback)) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def shutdown_background_loop(self) -> None: + """ + Shut down the background loop. + + This method needs to be called during cleanup to remove + references to `self` and properly GC the resources held + by the async LLM engine (e.g., the executors as well as + their resources). + """ + if self._background_loop_unshielded is not None: + self._background_loop_unshielded.cancel() + self._background_loop_unshielded = None + self.background_loop = None + + async def engine_step(self, virtual_engine: int) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, aborted_requests = ( + self._request_tracker.get_new_and_aborted_requests()) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + try: + await self.engine.add_request_async(**new_request) + except ValueError as e: + # TODO: use a vLLM specific error for failed validation + self._request_tracker.process_exception( + new_request["request_id"], + e, + verbose=self.log_requests, + ) + + if aborted_requests: + await self._engine_abort(aborted_requests) + + request_outputs = await self.engine.step_async(virtual_engine) + + # Put the outputs into the corresponding streams. 
+ # If used as a callback, then already invoked inside + # LLMEngine's _process_model_outputs + if not self.use_process_request_outputs_callback: + all_finished = self.process_request_outputs(request_outputs) + else: + # For callback case, we only need to detect when all + # requests are finished + all_finished = all(request_output.finished + for request_output in request_outputs) + + return not all_finished + + def process_request_outputs(self, request_outputs) -> bool: + # Put the outputs into the corresponding streams. + all_finished = True + for request_output in request_outputs: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + all_finished = all_finished and request_output.finished + + return all_finished + + async def _engine_abort(self, request_ids: Iterable[str]): + self.engine.abort_request(request_ids) + + @staticmethod + async def run_engine_loop(engine_ref: ReferenceType): + """We use a weakref to the engine so that the running loop + doesn't prevent the engine being garbage collected.""" + engine: Optional["AsyncLLMEngine"] = engine_ref() + if not engine: + return + + pipeline_parallel_size = \ + engine.engine.parallel_config.pipeline_parallel_size + has_requests_in_progress = [False] * pipeline_parallel_size + while True: + if not any(has_requests_in_progress): + logger.debug("Waiting for new requests...") + # Stop the execute model loop in parallel workers until there + # are more requests to process. This avoids waiting + # indefinitely in torch.distributed ops which may otherwise + # timeout, and unblocks the RPC thread in the workers so that + # they can process any other queued control plane messages, + # such as add/remove lora adapters. + await engine.engine.stop_remote_worker_execution_loop_async() + request_tracker = engine._request_tracker + # Allow engine to be garbage collected while + # waiting for new requests + del engine + await asyncio.sleep(0) + if engine_ref() is None: + return + await request_tracker.wait_for_new_requests() + engine = engine_ref() + if not engine: + return + logger.debug("Got new requests!") + # may be we do not use this when on server mode(all vllm server in new process, but still without cuda graph) + # requests_in_progress = [ + # asyncio.create_task(engine.engine_step(ve)) + # for ve in range(pipeline_parallel_size) + # ] + # has_requests_in_progress = [True] * pipeline_parallel_size + requests_in_progress = {0: asyncio.create_task(engine.engine_step(0))} + has_requests_in_progress[0] = True + + # Abort if iteration takes too long due to unrecoverable errors + # (eg. NCCL timeouts). + try: + async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S): + done, _ = await asyncio.wait( + requests_in_progress.values(), + return_when=asyncio.FIRST_COMPLETED) + for _ in range(len(requests_in_progress)): + await asyncio.sleep(0) + add_virtual_engine = set(list(range(pipeline_parallel_size))) + for task in done: + result = task.result() + # virtual_engine = requests_in_progress.index(task) + virtual_engine = [k for k,v in requests_in_progress.items() if v == task][0] + has_unfinished_requests = ( + engine.engine. 
+ has_unfinished_requests_for_virtual_engine( + virtual_engine)) + if result or has_unfinished_requests: + requests_in_progress[virtual_engine] = ( + asyncio.create_task( + engine.engine_step(virtual_engine))) + has_requests_in_progress[virtual_engine] = True + else: + requests_in_progress.pop(virtual_engine) + has_requests_in_progress[virtual_engine] = False + + # Check the remaining virtual engines and start step tasks for any that still have unfinished requests. + for ve in list(add_virtual_engine - requests_in_progress.keys()): + has_unfinished_requests = ( + engine.engine.has_unfinished_requests_for_virtual_engine( + ve)) + if has_unfinished_requests: + requests_in_progress[ve] = asyncio.create_task(engine.engine_step(ve)) + has_requests_in_progress[ve] = True + except asyncio.TimeoutError as exc: + logger.error( + "Engine iteration timed out. This should never happen!") + engine.set_errored(exc) + raise + await asyncio.sleep(0) + + # This method does not need to be async, but kept that way + # for backwards compatibility. + @overload # DEPRECATED + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Coroutine[None, None, AsyncGenerator[Union[ + RequestOutput, EmbeddingRequestOutput], None]]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + async def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. 
If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + if (priority != 0 + and not self.engine.scheduler_config.policy == "priority"): + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + stream = self._request_tracker.add_request( + request_id, + verbose=self.log_requests, + prompt=prompt, + params=params, + arrival_time=arrival_time or time.time(), + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + return stream.generator() + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: The priority of the request. + Only applicable with priority scheduling. + + Yields: + The output `RequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... 
+ """ + async for output in await self.add_request( + request_id, + prompt, + sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ): + yield LLMEngine.validate_output(output, RequestOutput) + + async def beam_search( + self, + prompt: Union[PromptType, List[int]], + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + + tokenizer = await self.get_tokenizer() + tokenizedPrompt = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + tokenizedLength = len(tokenizedPrompt) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + logger.info(output) + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt, + outputs=[ + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens, + index=i, + logprobs=beam.cum_logprob, + ) for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=tokenizedPrompt, + prompt_logprobs=None) + + yield LLMEngine.validate_output(beam_search_output, RequestOutput) + + async def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. 
See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "input": "What is LLM?", + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.encode( + >>> example_input["input"], + >>> PoolingParams(), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + async for output in await self.add_request( + request_id, + prompt, + pooling_params, + lora_request=lora_request, + trace_headers=trace_headers, + priority=priority, + ): + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. 
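+ + Internally this calls ``self._request_tracker.abort_request`` with + ``asyncio.CancelledError`` as the abort exception.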
+ """ + self._request_tracker.abort_request(request_id, + exception=asyncio.CancelledError, + verbose=self.log_requests) + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + return self.engine.get_model_config() + + async def get_parallel_config(self) -> ParallelConfig: + """Get the parallel configuration of the vLLM engine.""" + return self.engine.get_parallel_config() + + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + return self.engine.get_decoding_config() + + async def get_scheduler_config(self) -> SchedulerConfig: + """Get the scheduling configuration of the vLLM engine.""" + return self.engine.get_scheduler_config() + + async def get_lora_config(self) -> LoRAConfig: + """Get the lora configuration of the vLLM engine.""" + return self.engine.get_lora_config() + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None) -> None: + self.engine.do_log_stats() + + async def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + t = time.perf_counter() + logger.debug("Starting health check...") + if self.is_stopped: + raise AsyncEngineDeadError("Background loop is stopped.") + + await self.engine.check_health_async() + logger.debug("Health check took %fs", time.perf_counter() - t) + + async def is_tracing_enabled(self) -> bool: + return self.engine.is_tracing_enabled() + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + self.engine.add_logger(logger_name=logger_name, logger=logger) + + def remove_logger(self, logger_name: str) -> None: + self.engine.remove_logger(logger_name=logger_name) + + async def start_profile(self) -> None: + # using type instead of isinstance to check to avoid capturing + # inherited classes + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") + + async def stop_profile(self) -> None: + # using type instead of isinstance to check to avoid capturing + # inherited classes + if type(self.engine.model_executor) == GPUExecutorAsync: # noqa: E721 + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") \ No newline at end of file diff --git a/vllm/engine/async_timeout.py b/vllm/engine/async_timeout.py new file mode 100644 index 0000000..4b18426 --- /dev/null +++ b/vllm/engine/async_timeout.py @@ -0,0 +1,189 @@ +# Workaround for https://github.com/python/cpython/issues/86296 +# +# From https://github.com/aio-libs/async-timeout/blob/master/async_timeout/__init__.py +# Licensed under the Apache License (Apache-2.0) + +import asyncio +import enum +import sys +import warnings +from types import TracebackType +from typing import Any, Optional, Type + +if sys.version_info[:2] >= (3, 11): + from asyncio import timeout as asyncio_timeout +else: + + def asyncio_timeout(delay: Optional[float]) -> "Timeout": + """timeout context manager. + Useful in cases when you want to apply timeout logic around block + of code or in cases when asyncio.wait_for is not suitable. For example: + >>> async with timeout(0.001): + ... async with aiohttp.get('https://github.com') as r: + ... 
await r.text() + delay - value in seconds or None to disable timeout logic + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + delay if delay is not None else None + return Timeout(deadline, loop) + + class _State(enum.Enum): + INIT = "INIT" + ENTER = "ENTER" + TIMEOUT = "TIMEOUT" + EXIT = "EXIT" + + class Timeout: + # Internal class, please don't instantiate it directly + # Use timeout() and timeout_at() public factories instead. + # + # Implementation note: `async with timeout()` is preferred + # over `with timeout()`. + # While technically the Timeout class implementation + # doesn't need to be async at all, + # the `async with` statement explicitly points that + # the context manager should be used from async function context. + # + # This design allows to avoid many silly misusages. + # + # TimeoutError is raised immediately when scheduled + # if the deadline is passed. + # The purpose is to time out as soon as possible + # without waiting for the next await expression. + + __slots__ = ("_deadline", "_loop", "_state", "_timeout_handler") + + def __init__(self, deadline: Optional[float], + loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._state = _State.INIT + + self._timeout_handler = None # type: Optional[asyncio.Handle] + if deadline is None: + self._deadline = None # type: Optional[float] + else: + self.update(deadline) + + def __enter__(self) -> "Timeout": + warnings.warn( + "with timeout() is deprecated, use async with timeout()", + DeprecationWarning, + stacklevel=2, + ) + self._do_enter() + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + async def __aenter__(self) -> "Timeout": + self._do_enter() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + self._do_exit(exc_type) + return None + + @property + def expired(self) -> bool: + """Is timeout expired during execution?""" + return self._state == _State.TIMEOUT + + @property + def deadline(self) -> Optional[float]: + return self._deadline + + def reject(self) -> None: + """Reject scheduled timeout if any.""" + # cancel is maybe better name but + # task.cancel() raises CancelledError in asyncio world. + if self._state not in (_State.INIT, _State.ENTER): + raise RuntimeError(f"invalid state {self._state.value}") + self._reject() + + def _reject(self) -> None: + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._timeout_handler = None + + def shift(self, delay: float) -> None: + """Advance timeout on delay seconds. + The delay can be negative. + Raise RuntimeError if shift is called when deadline is not scheduled + """ + deadline = self._deadline + if deadline is None: + raise RuntimeError( + "cannot shift timeout if deadline is not scheduled") + self.update(deadline + delay) + + def update(self, deadline: float) -> None: + """Set deadline to absolute value. + deadline argument points on the time in the same clock system + as loop.time(). + If new deadline is in the past the timeout is raised immediately. + Please note: it is not POSIX time but a time with + undefined starting base, e.g. the time of the system power on. 
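+ For example, update(loop.time()+30) reschedules the timeout to expire 30 seconds from now.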
+ """ + if self._state == _State.EXIT: + raise RuntimeError( + "cannot reschedule after exit from context manager") + if self._state == _State.TIMEOUT: + raise RuntimeError("cannot reschedule expired timeout") + if self._timeout_handler is not None: + self._timeout_handler.cancel() + self._deadline = deadline + if self._state != _State.INIT: + self._reschedule() + + def _reschedule(self) -> None: + assert self._state == _State.ENTER + deadline = self._deadline + if deadline is None: + return + + now = self._loop.time() + if self._timeout_handler is not None: + self._timeout_handler.cancel() + + task = asyncio.current_task() + if deadline <= now: + self._timeout_handler = self._loop.call_soon( + self._on_timeout, task) + else: + self._timeout_handler = self._loop.call_at( + deadline, self._on_timeout, task) + + def _do_enter(self) -> None: + if self._state != _State.INIT: + raise RuntimeError(f"invalid state {self._state.value}") + self._state = _State.ENTER + self._reschedule() + + def _do_exit(self, exc_type: Optional[Type[BaseException]]) -> None: + if exc_type is asyncio.CancelledError and \ + self._state == _State.TIMEOUT: + self._timeout_handler = None + raise asyncio.TimeoutError + # timeout has not expired + self._state = _State.EXIT + self._reject() + return None + + def _on_timeout(self, task: "Optional[asyncio.Task[Any]]") -> None: + if task: + task.cancel() + self._state = _State.TIMEOUT + # drop the reference early + self._timeout_handler = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py new file mode 100644 index 0000000..563e52a --- /dev/null +++ b/vllm/engine/llm_engine.py @@ -0,0 +1,1934 @@ +import time +from collections import deque +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, + Iterable, List, Mapping, NamedTuple, Optional) +from typing import Sequence as GenericSequence +from typing import Set, Type, Union, overload + +import torch +from typing_extensions import TypeVar + +import vllm.envs as envs +from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, + EngineConfig, LoadConfig, LoRAConfig, ModelConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, + SchedulerOutputs) +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics_types import StatLoggerBase, Stats +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.engine.output_processor.util import create_output_by_sequence_group +from vllm.entrypoints.openai.logits_processors import get_logits_processors +from vllm.executor.executor_base import ExecutorBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, + InputRegistry, LLMInputs, PromptType) +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_local_guided_decoding_logits_processor) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, + RequestOutputFactory) +from vllm.pooling_params import PoolingParams +from 
vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, + Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) +from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, + init_tracer) +from vllm.transformers_utils.config import try_get_generation_config +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import ( + BaseTokenizerGroup, init_tokenizer_from_configs) +from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, + usage_message) +from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + + +def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: + config = try_get_generation_config( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.revision, + ) + + if config is None: + return {} + + return config.to_diff_dict() + + +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + +@dataclass +class SchedulerOutputState: + """Caches the scheduler outputs for a virtual engine. Used for Multi-Step""" + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None + + +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + # Indicates if this output is from the first step of the + # multi-step. When multi-step is disabled, this is always + # set to True. + # is_first_step_output is invalid when `outputs` has + # outputs from multiple steps. + is_first_step_output: Optional[bool] + skip: List[int] + + +class SchedulerContext: + + def __init__(self, multi_step_stream_outputs: bool = False): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool, + is_first_step_output: Optional[bool]): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + is_first_step_output=is_first_step_output, + skip=[])) + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). 
This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. + + The config arguments are derived from :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + lora_config (Optional): The configuration related to serving multi-LoRA. + speculative_config (Optional): The configuration related to speculative + decoding. + executor_class: The model executor class for managing distributed + execution. + prompt_adapter_config (Optional): The configuration related to serving + prompt adapters. + log_stats: Whether to log statistics. + usage_context: Specified entry point, used for usage info collection. + """ + + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + tokenizer: Optional[BaseTokenizerGroup] + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + decoding_config: Optional[DecodingConfig], + observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + executor_class: Type[ExecutorBase], + log_stats: bool, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, + ) -> None: + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "enforce_eager=%s, kv_cache_dtype=%s, " + 
"quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s)", + VLLM_VERSION, + model_config.model, + speculative_config, + model_config.tokenizer, + model_config.skip_tokenizer_init, + model_config.tokenizer_mode, + model_config.revision, + model_config.override_neuron_config, + model_config.rope_scaling, + model_config.rope_theta, + model_config.tokenizer_revision, + model_config.trust_remote_code, + model_config.dtype, + model_config.max_model_len, + load_config.download_dir, + load_config.load_format, + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + parallel_config.disable_custom_all_reduce, + model_config.quantization, + model_config.enforce_eager, + cache_config.cache_dtype, + model_config.quantization_param_path, + device_config.device, + decoding_config, + observability_config, + model_config.seed, + model_config.served_model_name, + scheduler_config.use_v2_block_manager, + scheduler_config.num_scheduler_steps, + scheduler_config.chunked_prefill_enabled, + scheduler_config.multi_step_stream_outputs, + cache_config.enable_prefix_caching, + model_config.use_async_output_proc, + use_cached_outputs, + model_config.mm_processor_kwargs, + ) + # TODO(woosuk): Print more configs in debug mode. + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.load_config = load_config + self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config or ObservabilityConfig( + ) + self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs + + if not self.model_config.skip_tokenizer_init: + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) + tokenizer_group = self.get_tokenizer_group() + else: + self.tokenizer = None + self.detokenizer = None + tokenizer_group = None + + # Ensure that the function doesn't contain a reference to self, + # to avoid engine GC issues + def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: + assert tokenizer_group, ("tokenizer_group cannot be None, " + "make sure skip_tokenizer_init is False") + return tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + self.seq_counter = Counter() + self.generation_config_fields = _load_generation_config_dict( + model_config) + + self.input_preprocessor = InputPreprocessor(model_config, + self.tokenizer) + + self.input_registry = input_registry + self.input_processor = input_registry.create_input_processor( + model_config) + + self.model_executor = executor_class( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + speculative_config=speculative_config, + load_config=load_config, + prompt_adapter_config=prompt_adapter_config, + observability_config=self.observability_config, + ) + + if not self.model_config.embedding_mode: + self._initialize_kv_caches() + + # If usage stat is enabled, collect 
relevant info. + if is_usage_stats_enabled(): + from vllm.model_executor.model_loader import ( + get_architecture_class_name) + usage_message.report_usage( + get_architecture_class_name(model_config), + usage_context, + extra_kvs={ + # Common configuration + "dtype": + str(model_config.dtype), + "tensor_parallel_size": + parallel_config.tensor_parallel_size, + "block_size": + cache_config.block_size, + "gpu_memory_utilization": + cache_config.gpu_memory_utilization, + + # Quantization + "quantization": + model_config.quantization, + "kv_cache_dtype": + str(cache_config.cache_dtype), + + # Feature flags + "enable_lora": + bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), + "enable_prefix_caching": + cache_config.enable_prefix_caching, + "enforce_eager": + model_config.enforce_eager, + "disable_custom_all_reduce": + parallel_config.disable_custom_all_reduce, + }) + + if self.tokenizer: + # Ping the tokenizer to ensure liveness if it runs in a + # different process. + self.tokenizer.ping() + + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. + multi_step_stream_outputs) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + if model_config.use_async_output_proc: + process_model_outputs = weak_bind(self._process_model_outputs) + + self.async_callbacks = [ + partial(process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + else: + self.async_callbacks = [] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback: Optional[Callable] = None + + # Create the scheduler. + # NOTE: the cache_config here have been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. + self.scheduler = [ + Scheduler( + scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size, + self.async_callbacks[v_id] + if model_config.use_async_output_proc else None) + for v_id in range(parallel_config.pipeline_parallel_size) + ] + + # Metric Logging. + if self.log_stats: + if stat_loggers is not None: + self.stat_loggers = stat_loggers + else: + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from vllm.engine.metrics import (LoggingStatLogger, + PrometheusStatLogger) + + self.stat_loggers = { + "logging": + LoggingStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC), + "prometheus": + PrometheusStatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.served_model_name), + max_model_len=self.model_config.max_model_len), + } + self.stat_loggers["prometheus"].info("cache_config", + self.cache_config) + + self.tracer = None + if self.observability_config.otlp_traces_endpoint: + self.tracer = init_tracer( + "vllm.llm_engine", + self.observability_config.otlp_traces_endpoint) + + # Create sequence output processor, e.g. for beam search or + # speculative decoding. 
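+ # The output processor appends sampled tokens to sequences, detokenizes them + # incrementally, and applies stop conditions via the StopChecker.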
+ self.output_processor = ( + SequenceGroupOutputProcessor.create_output_processor( + self.scheduler_config, + self.detokenizer, + self.scheduler, + self.seq_counter, + get_tokenizer_for_seq, + stop_checker=StopChecker( + self.scheduler_config.max_model_len, + get_tokenizer_for_seq, + ), + )) + + def _initialize_kv_caches(self) -> None: + """Initialize the KV cache in the worker(s). + + The workers will determine the number of blocks in both the GPU cache + and the swap CPU cache. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.model_executor.determine_num_available_blocks()) + + if self.cache_config.num_gpu_blocks_override is not None: + num_gpu_blocks_override = self.cache_config.num_gpu_blocks_override + logger.info( + "Overriding num_gpu_blocks=%d with " + "num_gpu_blocks_override=%d", num_gpu_blocks, + num_gpu_blocks_override) + num_gpu_blocks = num_gpu_blocks_override + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + @classmethod + def _get_executor_cls(cls, + engine_config: EngineConfig) -> Type[ExecutorBase]: + distributed_executor_backend = ( + engine_config.parallel_config.distributed_executor_backend) + # Initialize the cluster and specify the executor class. + if isinstance(distributed_executor_backend, type): + if not issubclass(distributed_executor_backend, ExecutorBase): + raise TypeError( + "distributed_executor_backend must be a subclass of " + f"ExecutorBase. Got {distributed_executor_backend}.") + if distributed_executor_backend.uses_ray: # type: ignore + initialize_ray_cluster(engine_config.parallel_config) + executor_class = distributed_executor_backend + elif engine_config.device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutor + executor_class = NeuronExecutor + elif engine_config.device_config.device_type == "tpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_tpu_executor import RayTPUExecutor + executor_class = RayTPUExecutor + else: + assert distributed_executor_backend is None + from vllm.executor.tpu_executor import TPUExecutor + executor_class = TPUExecutor + elif engine_config.device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + executor_class = CPUExecutor + elif engine_config.device_config.device_type == "openvino": + from vllm.executor.openvino_executor import OpenVINOExecutor + executor_class = OpenVINOExecutor + elif engine_config.device_config.device_type == "xpu": + if distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_xpu_executor import RayXPUExecutor + executor_class = RayXPUExecutor + elif distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. 
+ logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, Please try ray instead.") + else: + from vllm.executor.xpu_executor import XPUExecutor + executor_class = XPUExecutor + elif distributed_executor_backend == "ray": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + executor_class = RayGPUExecutor + elif distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor) + assert not envs.VLLM_USE_RAY_SPMD_WORKER, ( + "multiprocessing distributed executor backend does not " + "support VLLM_USE_RAY_SPMD_WORKER=1") + executor_class = MultiprocessingGPUExecutor + else: + from vllm.executor.gpu_executor import GPUExecutor + executor_class = GPUExecutor + return executor_class + + @classmethod + def from_engine_args( + cls, + engine_args: EngineArgs, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config() + executor_class = cls._get_executor_cls(engine_config) + # Create the LLM engine. + engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + return engine + + def __reduce__(self): + # This is to ensure that the LLMEngine is not referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if model_executor := getattr(self, "model_executor", None): + model_executor.shutdown() + + def get_tokenizer_group( + self, + group_type: Type[_G] = BaseTokenizerGroup, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError("Unable to get tokenizer because " + "skip_tokenizer_init is True") + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. 
" + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") + + return tokenizer_group + + def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self.get_tokenizer_group().get_lora_tokenizer(lora_request) + + def _init_tokenizer(self) -> BaseTokenizerGroup: + return init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=self.scheduler_config, + parallel_config=self.parallel_config, + enable_lora=bool(self.lora_config)) + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) + + def _add_processed_request( + self, + request_id: str, + processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs], + params: Union[SamplingParams, PoolingParams], + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> None: + self._validate_model_inputs(processed_inputs) + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request) + + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request, prompt_adapter_request) + + encoder_seq = None + if 'encoder_prompt_token_ids' in processed_inputs: + encoder_seq = Sequence(seq_id, + processed_inputs, + block_size, + eos_token_id, + lora_request, + prompt_adapter_request, + from_decoder_prompt=False) + + # Create a SequenceGroup based on SamplingParams or PoolingParams + if isinstance(params, SamplingParams): + seq_group = self._create_sequence_group_with_sampling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + elif isinstance(params, PoolingParams): + seq_group = self._create_sequence_group_with_pooling( + request_id, + seq, + params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + else: + raise ValueError( + "Either SamplingParams or PoolingParams must be provided.") + + # Add the sequence group to the scheduler with least unfinished seqs. + costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) + + def stop_remote_worker_execution_loop(self) -> None: + self.model_executor.stop_remote_worker_execution_loop() + + @overload # DEPRECATED + def add_request( + self, + request_id: str, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... 
+ + @overload + def add_request( + self, + request_id: str, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def add_request( + self, + request_id: str, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + params: Parameters for sampling or pooling. + :class:`~vllm.SamplingParams` for text generation. + :class:`~vllm.PoolingParams` for pooling. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + trace_headers: OpenTelemetry trace headers. + priority: The priority of the request. + Only applicable with priority scheduling. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `n` number of :class:`~vllm.Sequence` objects. + - Create a :class:`~vllm.SequenceGroup` object + from the list of :class:`~vllm.Sequence`. + - Add the :class:`~vllm.SequenceGroup` object to the scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if inputs is not None: + prompt = inputs + assert prompt is not None and params is not None + + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + + if priority != 0 and not self.scheduler_config.policy == "priority": + raise ValueError(f"Got priority {priority} but " + "Priority scheduling is not enabled.") + + if arrival_time is None: + arrival_time = time.time() + + preprocessed_inputs = self.input_preprocessor.preprocess( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + processed_inputs = self.input_processor(preprocessed_inputs) + + # This is a bit of a hack - copy the mm_processor_kwargs that were + # used in the input processor to the processed output, since these + # kwargs are presumed to be immutable and the values should be aligned + # between the input processor (here) and the input mapper. 
+ processed_inputs["mm_processor_kwargs"] = preprocessed_inputs.get( + "mm_processor_kwargs") + + self._add_processed_request( + request_id=request_id, + processed_inputs=processed_inputs, + params=params, + arrival_time=arrival_time, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=priority, + ) + + def _create_sequence_group_with_sampling( + self, + request_id: str, + seq: Sequence, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with SamplingParams.""" + max_logprobs = self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + sampling_params = self._build_logits_processors( + sampling_params, lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + + sampling_params.update_from_generation_config( + self.generation_config_fields, seq.eos_token_id) + + # Create the sequence group. + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + sampling_params=sampling_params, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + + return seq_group + + def _create_sequence_group_with_pooling( + self, + request_id: str, + seq: Sequence, + pooling_params: PoolingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + encoder_seq: Optional[Sequence] = None, + priority: int = 0, + ) -> SequenceGroup: + """Creates a SequenceGroup with PoolingParams.""" + # Defensive copy of PoolingParams, which are used by the pooler + pooling_params = pooling_params.clone() + # Create the sequence group. + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request, + encoder_seq=encoder_seq, + priority=priority) + return seq_group + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to the + :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` + from class :class:`~vllm.core.scheduler.Scheduler`. 
+ + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) + """ + for scheduler in self.scheduler: + scheduler.abort_seq_group(request_id) + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_parallel_config(self) -> ParallelConfig: + """Gets the parallel configuration.""" + return self.parallel_config + + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + + def get_scheduler_config(self) -> SchedulerConfig: + """Gets the scheduler configuration.""" + return self.scheduler_config + + def get_lora_config(self) -> LoRAConfig: + """Gets the LoRA configuration.""" + return self.lora_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return sum(scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler) + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return any(scheduler.has_unfinished_seqs() + for scheduler in self.scheduler) + + def has_unfinished_requests_for_virtual_engine( + self, virtual_engine: int) -> bool: + """ + Returns True if there are unfinished requests for the virtual engine. + """ + return self.scheduler[virtual_engine].has_unfinished_seqs() + + @staticmethod + def _process_sequence_group_outputs( + seq_group: SequenceGroup, + outputs: List[EmbeddingSequenceGroupOutput], + ) -> None: + seq_group.embeddings = outputs[0].embeddings + + for seq in seq_group.get_seqs(): + seq.status = SequenceStatus.FINISHED_STOPPED + + return + + def _update_num_computed_tokens_for_multi_step_prefill( + self, seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, + is_first_step_output: Optional[bool]): + """ + This function updates num_computed_tokens for prompt sequences + when Multi-Step is enabled. + + seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group_meta: Metadata of the given SequenceGroup. + is_first_step_output: Optional[bool] - + When available, is_first_step_output indicates if the appended + output token is the output of the first-step in multi-step. + A value of None indicates that outputs from all steps in + in multi-step are submitted in a single burst. + """ + + assert self.scheduler_config.is_multi_step + + if not seq_group_meta.is_prompt: + # num_computed_token updates for multi-step decodes happen after + # the tokens are appended to the sequence. + return + + do_update: bool = False + if self.scheduler_config.chunked_prefill_enabled: + # In multi-step + chunked-prefill case, the prompt sequences + # that are scheduled are fully processed in the first step. + do_update = is_first_step_output is None or is_first_step_output + else: + # Normal multi-step decoding case. In this case prompt-sequences + # are actually single-stepped. Always update in this case. + assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. 
+ + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed + """ + + now = time.time() + + if len(ctx.output_queue) == 0: + return None + + # Get pending async postprocessor + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, is_first_step_output, + skip) = ctx.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + has_multiple_outputs: bool = len(outputs) > 1 + if has_multiple_outputs: + assert self.scheduler_config.is_multi_step or \ + self.speculative_config + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, num_seq_groups=len(seq_group_metadata_list)) + # We have outputs for multiple steps submitted in a single burst, + # so invalidate is_first_step_output. + is_first_step_output = None + else: + outputs_by_sequence_group = outputs + + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + + finished_before: List[int] = [] + finished_now: List[int] = [] + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group: SequenceGroup = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + finished_before.append(i) + continue + + if has_multiple_outputs: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) + + if outputs: + for o in outputs: + if (isinstance(o, SamplerOutput) + and seq_group.metrics is not None): + if seq_group.metrics.model_forward_time is not None: + seq_group.metrics.model_forward_time += ( + o.model_forward_time) + else: + seq_group.metrics.model_forward_time = ( + o.model_forward_time) + if seq_group.metrics.model_execute_time is not None: + seq_group.metrics.model_execute_time += ( + o.model_execute_time) + else: + seq_group.metrics.model_execute_time = ( + o.model_execute_time) + + if self.model_config.embedding_mode: + self._process_sequence_group_outputs(seq_group, output) + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) + + if seq_group.is_finished(): + finished_now.append(i) + + # Generate outputs for the requests that finished this iteration + for i in finished_now: + scheduled_seq_group = 
scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Create the outputs + for i in indices: + if i in skip or i in finished_before or i in finished_now: + continue # Avoids double processing + + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + for seq_group in scheduler_outputs.ignored_seq_groups: + params = seq_group.sampling_params + if params is not None and params.output_kind == ( + RequestOutputKind.DELTA) and not seq_group.is_finished(): + continue + + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) + if request_output: + ctx.request_outputs.append(request_output) + + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly + if is_async: + # Log stats. + self.do_log_stats(scheduler_outputs, outputs, finished_before, + skip) + + # Tracing + self.do_tracing(scheduler_outputs) + + return None + + def _advance_to_next_step( + self, output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. 
+ """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + continue + + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_metadata, + seq_group.state.num_steps == 1) + else: + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) + + if seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "Async output processor expects a single sample" + " (i.e sampling_params.n == 1)") + sample = sequence_group_outputs.samples[0] + + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + + if self.scheduler_config.is_multi_step: + is_prefill_append = seq.data.get_num_uncomputed_tokens( + ) == 0 + seq.append_token_id(sample.output_token, sample.logprobs) + if not is_prefill_append: + seq_group.update_num_computed_tokens(1) + else: + seq.append_token_id(sample.output_token, sample.logprobs) + + def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + """Performs one decoding iteration and returns newly generated results. + + .. figure:: https://i.imgur.com/sv2HssD.png + :alt: Overview of the step function + :align: center + + Overview of the step function. + + Details: + - Step 1: Schedules the sequences to be executed in the next + iteration and the token blocks to be swapped in/out/copy. + + - Depending on the scheduling policy, + sequences may be `preempted/reordered`. + - A Sequence Group (SG) refer to a group of sequences + that are generated from the same prompt. + + - Step 2: Calls the distributed executor to execute the model. + - Step 3: Processes the model output. This mainly includes: + + - Decodes the relevant outputs. + - Updates the scheduled sequence groups with model outputs + based on its `sampling parameters` (`use_beam_search` or not). + - Frees the finished sequence groups. + + - Finally, it creates and returns the newly generated results. + + Example: + >>> # Please see the example/ folder for more detailed examples. + >>> + >>> # initialize engine and request arguments + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> example_inputs = [(0, "What is LLM?", + >>> SamplingParams(temperature=0.0))] + >>> + >>> # Start the engine with an event loop + >>> while True: + >>> if example_inputs: + >>> req_id, prompt, sampling_params = example_inputs.pop(0) + >>> engine.add_request(str(req_id),prompt,sampling_params) + >>> + >>> # continue the request processing + >>> request_outputs = engine.step() + >>> for request_output in request_outputs: + >>> if request_output.finished: + >>> # return or show the request output + >>> + >>> if not (engine.has_unfinished_requests() or example_inputs): + >>> break + """ + if self.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is only supported through AsyncLLMEngine " + "as performance will be severely degraded otherwise.") + + # For llm_engine, there is no pipeline parallel support, so the engine + # used is always 0. + virtual_engine = 0 + + # These are cached outputs from previous iterations. 
None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + + ctx = self.scheduler_contexts[virtual_engine] + + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + + # Skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + # Schedule iteration + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + # Maybe switch from async mode to sync mode + if not allow_async_output_proc and len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None + + if not scheduler_outputs.is_empty(): + finished_requests_ids = self.scheduler[ + virtual_engine].get_and_reset_finished_requests_ids() + + # Check if we have a cached last_output from the previous iteration. + # For supporting PP this is probably the best way to pass the + # sampled_token_ids, as a separate broadcast over all the PP stages + # will cause one virtual engine's microbatch to block the pipeline. + last_sampled_token_ids = \ + self._get_last_sampled_token_ids(virtual_engine) + + execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + num_lookahead_slots=scheduler_outputs.num_lookahead_slots, + running_queue_size=scheduler_outputs.running_queue_size, + finished_requests_ids=finished_requests_ids, + # We use ExecuteModelRequest to pass the last sampled_token_ids + # to each of the non-last PP stages for in-place prepare_input. + last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.async_callback = self.async_callbacks[ + virtual_engine] + + outputs = self.model_executor.execute_model( + execute_model_req=execute_model_req) + + # We need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._update_cached_scheduler_output(virtual_engine, outputs) + else: + # Nothing scheduled => If there is pending async postprocessor, + # then finish it here. + if len(ctx.output_queue) > 0: + self._process_model_outputs(ctx=ctx) + # No outputs in this case + outputs = [] + + # Finish the current step for all the sequence groups. + if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps. 
+            if self.scheduler_config.is_multi_step:
+                self.cached_scheduler_outputs[0] = SchedulerOutputState()
+
+            # is_first_step_output is True only when the num_steps of all
+            # the sequences is 1. When the num_steps > 1,
+            # multi_step_model_runner does the first-step output append.
+            is_first_step_output: bool = False if not seq_group_metadata_list \
+                else seq_group_metadata_list[0].state.num_steps == 1
+
+            # Add results to the output_queue
+            ctx.append_output(outputs=outputs,
+                              seq_group_metadata_list=seq_group_metadata_list,
+                              scheduler_outputs=scheduler_outputs,
+                              is_async=allow_async_output_proc,
+                              is_last_step=True,
+                              is_first_step_output=is_first_step_output)
+
+            if outputs and allow_async_output_proc:
+                assert len(outputs) == 1, (
+                    "Async postprocessor expects only a single output set")
+
+                self._advance_to_next_step(
+                    outputs[0], seq_group_metadata_list,
+                    scheduler_outputs.scheduled_seq_groups)
+
+            # Check if we need to run the usual non-async path
+            if not allow_async_output_proc:
+                self._process_model_outputs(ctx=ctx)
+
+                # Log stats.
+                self.do_log_stats(scheduler_outputs, outputs)
+
+                # Tracing
+                self.do_tracing(scheduler_outputs)
+        else:
+            # Multi-step case
+            return ctx.request_outputs
+
+        if not self.has_unfinished_requests():
+            # Drain async postprocessor (if exists)
+            if len(ctx.output_queue) > 0:
+                self._process_model_outputs(ctx=ctx)
+            assert len(ctx.output_queue) == 0
+
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise timeout, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as add/remove lora adapters.
+            logger.debug("Stopping remote worker execution loop.")
+            self.model_executor.stop_remote_worker_execution_loop()
+
+        return ctx.request_outputs
+
+    def _has_remaining_steps(
+        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
+    ) -> bool:
+        if (not self.scheduler_config.is_multi_step
+                or not seq_group_metadata_list):
+            return False
+
+        # TODO(will) this is a sanity check for now to make sure that all the
+        # seqs are on the same steps. Eventually we will want to do some sort of
+        # dynamic scheduling when doing multi-step decoding. 
+ ref_remaining_steps = seq_group_metadata_list[0].state.remaining_steps + if any([ + seq_group.state.remaining_steps != ref_remaining_steps + for seq_group in seq_group_metadata_list[1:] + ]): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + return ref_remaining_steps > 0 + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + co = self.cached_scheduler_outputs[virtual_engine] + + co.seq_group_metadata_list = seq_group_metadata_list + co.scheduler_outputs = scheduler_outputs + co.allow_async_output_proc = allow_async_output_proc + co.last_output = None + + def _update_cached_scheduler_output( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1 and len(output) > 0 + and output[0] is not None): + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_cpu is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + + def _get_last_sampled_token_ids( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_cpu is not None): + return cached_last_output.sampled_token_ids_cpu + return None + + def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} already exists.") + self.stat_loggers[logger_name] = logger + + def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") + if logger_name not in self.stat_loggers: + raise KeyError(f"Logger with name {logger_name} does not exist.") + del self.stat_loggers[logger_name] + + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> None: + """Forced log when no requests active.""" + if self.log_stats: + stats = self._get_stats(scheduler_outputs, model_output, + finished_before, skip) + for logger in self.stat_loggers.values(): + logger.log(stats) + + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None, + skip: Optional[List[int]] = None) -> Stats: + """Get Stats to be Logged to Prometheus. + + Args: + scheduler_outputs: Optional, used to populate metrics related to + the scheduled batch, + model_output: Optional, used to emit speculative decoding metrics + which are created by the workers. + finished_before: Optional, indices of sequences that were finished + before. These sequences will be ignored. + skip: Optional, indices of sequences that were preempted. These + sequences will be ignored. 
+ """ + now = time.time() + + # System State + # Scheduler State + num_running_sys = sum( + len(scheduler.running) for scheduler in self.scheduler) + num_swapped_sys = sum( + len(scheduler.swapped) for scheduler in self.scheduler) + num_waiting_sys = sum( + len(scheduler.waiting) for scheduler in self.scheduler) + + # KV Cache Usage in % + num_total_gpu = self.cache_config.num_gpu_blocks + gpu_cache_usage_sys = 0. + if num_total_gpu is not None: + num_free_gpu = sum( + scheduler.block_manager.get_num_free_gpu_blocks() + for scheduler in self.scheduler) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) + + num_total_cpu = self.cache_config.num_cpu_blocks + cpu_cache_usage_sys = 0. + if num_total_cpu is not None and num_total_cpu > 0: + num_free_cpu = sum( + scheduler.block_manager.get_num_free_cpu_blocks() + for scheduler in self.scheduler) + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Prefix Cache Hit Rate. Note that we always use + # the cache hit rate of the first virtual engine. + cpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.CPU) + gpu_prefix_cache_hit_rate = self.scheduler[ + 0].get_prefix_cache_hit_rate(Device.GPU) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + num_preemption_iter = (0 if scheduler_outputs is None else + scheduler_outputs.preempted) + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + n_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. + if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + + num_generation_tokens_from_prefill_groups = 0. + # NOTE: if scheduler_outputs.num_prefill_groups > 0 and + # the len of scheduler_outputs.scheduled_seq_groups is != + # scheduler_outputs.num_prefill_groups, this means that + # chunked prefills have been detected. + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + + # Currently, skip == preempted sequences, so we need to skip + # their log stats + if skip and idx in skip: + continue + + group_was_prefill = idx < scheduler_outputs.num_prefill_groups + seq_group = scheduled_seq_group.seq_group + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_latency(now) + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. 
+ latency = seq_group.get_last_latency(now) + time_per_output_tokens_iter.append(latency) + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. + if seq_group.is_finished(): + # Latency timings + time_e2e_requests.append(now - + seq_group.metrics.arrival_time) + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + if seq_group.sampling_params is not None: + n_requests.append(seq_group.sampling_params.n) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + actual_num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) + + # Spec decode, if enabled, emits specialized metrics from the worker in + # sampler output. + if model_output and (model_output[0].spec_decode_worker_metrics + is not None): + spec_decode_metrics = model_output[0].spec_decode_worker_metrics + else: + spec_decode_metrics = None + + return Stats( + now=now, + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + # Prefix Cache Hit Rate + cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate, + gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, + spec_decode_metrics=spec_decode_metrics, + num_preemption_iter=num_preemption_iter, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + n_requests=n_requests, + finished_reason_requests=finished_reason_requests, + ) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_executor.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_executor.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_executor.list_loras() + + def pin_lora(self, lora_id: int) -> bool: + return self.model_executor.pin_lora(lora_id) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + + def check_health(self) -> None: + if 
self.tokenizer: + self.tokenizer.check_health() + self.model_executor.check_health() + + def start_profile(self) -> None: + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.start_profile() + else: + self.model_executor._run_workers("start_profile") + + def stop_profile(self) -> None: + # using type instead of isinstance to check to avoid capturing + # inherited classes (MultiprocessingGPUExecutor) + if type(self.model_executor) == GPUExecutor: # noqa: E721 + self.model_executor.stop_profile() + else: + self.model_executor._run_workers("stop_profile") + + def is_tracing_enabled(self) -> bool: + return self.tracer is not None + + def do_tracing(self, scheduler_outputs: SchedulerOutputs) -> None: + if self.tracer is None: + return + + for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group = scheduled_seq_group.seq_group + if seq_group.is_finished(): + self.create_trace_span(seq_group) + + def create_trace_span(self, seq_group: SequenceGroup) -> None: + if self.tracer is None or seq_group.sampling_params is None: + return + arrival_time_nano_seconds = int(seq_group.metrics.arrival_time * 1e9) + + trace_context = extract_trace_context(seq_group.trace_headers) + + with self.tracer.start_as_current_span( + "llm_request", + kind=SpanKind.SERVER, + context=trace_context, + start_time=arrival_time_nano_seconds) as seq_span: + metrics = seq_group.metrics + ttft = metrics.first_token_time - metrics.arrival_time + e2e_time = metrics.finished_time - metrics.arrival_time + # attribute names are based on + # https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/llm-spans.md + seq_span.set_attribute(SpanAttributes.LLM_RESPONSE_MODEL, + self.model_config.model) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_ID, + seq_group.request_id) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TEMPERATURE, + seq_group.sampling_params.temperature) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_TOP_P, + seq_group.sampling_params.top_p) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS, + seq_group.sampling_params.max_tokens) + seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N, + seq_group.sampling_params.n) + seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES, + seq_group.num_seqs()) + seq_span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, + len(seq_group.prompt_token_ids)) + seq_span.set_attribute( + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, + sum([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ])) + seq_span.set_attribute(SpanAttributes.LLM_LATENCY_TIME_IN_QUEUE, + metrics.time_in_queue) + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft) + seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time) + if metrics.scheduler_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_SCHEDULER, + metrics.scheduler_time) + if metrics.model_forward_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_FORWARD, + metrics.model_forward_time / 1000.0) + if metrics.model_execute_time is not None: + seq_span.set_attribute( + SpanAttributes.LLM_LATENCY_TIME_IN_MODEL_EXECUTE, + metrics.model_execute_time) + + def is_encoder_decoder_model(self): + return self.input_preprocessor.is_encoder_decoder_model() + + def is_embedding_model(self): + return self.model_config.is_embedding_model 
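+
+    # Sketch of what the validation below rejects (hypothetical token counts,
+    # not taken from any real model config): an empty prompt always raises,
+    # and for multimodal models a prompt longer than max_model_len raises too.
+    #
+    #   engine._validate_model_inputs({"prompt_token_ids": []})
+    #   # ValueError: Prompt cannot be empty
+    #
+    #   # multimodal model with max_model_len == 4096:
+    #   engine._validate_model_inputs({"prompt_token_ids": [0] * 5000})
+    #   # ValueError: The prompt (total length 5000) is too long to fit ...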
+ + def _validate_model_inputs(self, inputs: Union[LLMInputs, + EncoderDecoderLLMInputs]): + if self.model_config.is_multimodal_model: + # For encoder-decoder multimodal models, the max_prompt_len + # restricts the decoder prompt length + prompt_ids = inputs.get("prompt_token_ids") + elif self.is_encoder_decoder_model(): + prompt_ids = inputs.get("encoder_prompt_token_ids") + else: + prompt_ids = inputs.get("prompt_token_ids") + + if prompt_ids is None or len(prompt_ids) == 0: + raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens + + def _build_logits_processors( + self, sampling_params: SamplingParams, + lora_request: Optional[LoRARequest]) -> SamplingParams: + """Constructs logits processors based on the guided_decoding, + logits_bias, and allowed_token_ids fields in sampling_params. Deletes + those fields and adds the constructed logits processors to the + logits_processors field. Returns the modified sampling params.""" + + logits_processors = [] + if (guided_decoding := sampling_params.guided_decoding) is not None: + + logger.debug( + "Building guided decoding logits processor in " + "LLMEngine. 
Params: %s", guided_decoding) + + tokenizer = self.get_tokenizer(lora_request=lora_request) + guided_decoding.backend = guided_decoding.backend or \ + self.decoding_config.guided_decoding_backend + + processor = get_local_guided_decoding_logits_processor( + guided_params=guided_decoding, tokenizer=tokenizer) + if processor: + logits_processors.append(processor) + + # Unset so this doesn't get passed down to the model + sampling_params.guided_decoding = None + + if (sampling_params.logit_bias or sampling_params.allowed_token_ids): + tokenizer = self.get_tokenizer(lora_request=lora_request) + + processors = get_logits_processors( + logit_bias=sampling_params.logit_bias, + allowed_token_ids=sampling_params.allowed_token_ids, + tokenizer=tokenizer) + logits_processors.extend(processors) + + # Unset so these don't get passed down to the model + sampling_params.logit_bias = None + sampling_params.allowed_token_ids = None + + if logits_processors: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = logits_processors + else: + sampling_params.logits_processors.extend(logits_processors) + + return sampling_params diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py new file mode 100644 index 0000000..42acd3e --- /dev/null +++ b/vllm/engine/metrics.py @@ -0,0 +1,555 @@ +from typing import TYPE_CHECKING +from typing import Counter as CollectionsCounter +from typing import Dict, List, Optional, Union + +import numpy as np +import prometheus_client + +from vllm.engine.metrics_types import (StatLoggerBase, Stats, + SupportsMetricsInfo) +from vllm.executor.ray_utils import ray +from vllm.logger import init_logger + +if ray is not None: + from ray.util import metrics as ray_metrics +else: + ray_metrics = None + +if TYPE_CHECKING: + from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + +logger = init_logger(__name__) + +prometheus_client.disable_created_metrics() + +# The begin-* and end* here are used by the documentation generator +# to extract the metrics definitions. + + +# begin-metrics-definitions +class Metrics: + """ + vLLM uses a multiprocessing-based frontend for the OpenAI server. + This means that we need to run prometheus_client in multiprocessing mode + See https://prometheus.github.io/client_python/multiprocess/ for more + details on limitations. + """ + labelname_finish_reason = "finished_reason" + _gauge_cls = prometheus_client.Gauge + _counter_cls = prometheus_client.Counter + _histogram_cls = prometheus_client.Histogram + + def __init__(self, labelnames: List[str], max_model_len: int): + # Unregister any existing vLLM collectors (for CI/CD) + self._unregister_vllm_metrics() + + # System stats + # Scheduler State + self.gauge_scheduler_running = self._gauge_cls( + name="vllm:num_requests_running", + documentation="Number of requests currently running on GPU.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_scheduler_waiting = self._gauge_cls( + name="vllm:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames, + multiprocess_mode="sum") + self.gauge_scheduler_swapped = self._gauge_cls( + name="vllm:num_requests_swapped", + documentation="Number of requests swapped to CPU.", + labelnames=labelnames, + multiprocess_mode="sum") + # KV Cache Usage in % + self.gauge_gpu_cache_usage = self._gauge_cls( + name="vllm:gpu_cache_usage_perc", + documentation="GPU KV-cache usage. 
1 means 100 percent usage.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_cpu_cache_usage = self._gauge_cls(
+            name="vllm:cpu_cache_usage_perc",
+            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        # Prefix caching block hit rate
+        self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="vllm:cpu_prefix_cache_hit_rate",
+            documentation="CPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="vllm:gpu_prefix_cache_hit_rate",
+            documentation="GPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+
+        # Iteration stats
+        self.counter_num_preemption = self._counter_cls(
+            name="vllm:num_preemptions_total",
+            documentation="Cumulative number of preemptions from the engine.",
+            labelnames=labelnames)
+        self.counter_prompt_tokens = self._counter_cls(
+            name="vllm:prompt_tokens_total",
+            documentation="Number of prefill tokens processed.",
+            labelnames=labelnames)
+        self.counter_generation_tokens = self._counter_cls(
+            name="vllm:generation_tokens_total",
+            documentation="Number of generation tokens processed.",
+            labelnames=labelnames)
+        self.histogram_time_to_first_token = self._histogram_cls(
+            name="vllm:time_to_first_token_seconds",
+            documentation="Histogram of time to first token in seconds.",
+            labelnames=labelnames,
+            buckets=[
+                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
+                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
+            ])
+        self.histogram_time_per_output_token = self._histogram_cls(
+            name="vllm:time_per_output_token_seconds",
+            documentation="Histogram of time per output token in seconds.",
+            labelnames=labelnames,
+            buckets=[
+                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
+                1.0, 2.5
+            ])
+
+        # Request stats
+        # Latency
+        self.histogram_e2e_time_request = self._histogram_cls(
+            name="vllm:e2e_request_latency_seconds",
+            documentation="Histogram of end to end request latency in seconds.",
+            labelnames=labelnames,
+            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
+        # Metadata
+        self.histogram_num_prompt_tokens_request = self._histogram_cls(
+            name="vllm:request_prompt_tokens",
+            documentation="Number of prefill tokens processed.",
+            labelnames=labelnames,
+            buckets=build_1_2_5_buckets(max_model_len),
+        )
+        self.histogram_num_generation_tokens_request = \
+            self._histogram_cls(
+                name="vllm:request_generation_tokens",
+                documentation="Number of generation tokens processed.",
+                labelnames=labelnames,
+                buckets=build_1_2_5_buckets(max_model_len),
+            )
+        self.histogram_n_request = self._histogram_cls(
+            name="vllm:request_params_n",
+            documentation="Histogram of the n request parameter.",
+            labelnames=labelnames,
+            buckets=[1, 2, 5, 10, 20],
+        )
+        self.counter_request_success = self._counter_cls(
+            name="vllm:request_success_total",
+            documentation="Count of successfully processed requests.",
+            labelnames=labelnames + [Metrics.labelname_finish_reason])
+
+        # Speculative decoding stats
+        self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
+            name="vllm:spec_decode_draft_acceptance_rate",
+            documentation="Speculative token acceptance rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_spec_decode_efficiency = self._gauge_cls(
+            name="vllm:spec_decode_efficiency",
+            documentation="Speculative decoding system efficiency.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
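+
+        # The spec-decode gauges above are assumed to mirror
+        # SpecDecodeWorkerMetrics: the acceptance rate is roughly
+        # accepted_tokens / draft_tokens as computed by the worker, while the
+        # counters below expose the raw accepted/draft/emitted token counts so
+        # dashboards can derive rates themselves, e.g. with PromQL:
+        #   rate(vllm:spec_decode_num_accepted_tokens_total[1m])
+        #     / rate(vllm:spec_decode_num_draft_tokens_total[1m])
+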
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_accepted_tokens_total", + documentation="Number of accepted tokens.", + labelnames=labelnames)) + self.counter_spec_decode_num_draft_tokens = self._counter_cls( + name="vllm:spec_decode_num_draft_tokens_total", + documentation="Number of draft tokens.", + labelnames=labelnames) + self.counter_spec_decode_num_emitted_tokens = (self._counter_cls( + name="vllm:spec_decode_num_emitted_tokens_total", + documentation="Number of emitted tokens.", + labelnames=labelnames)) + + # Deprecated in favor of vllm:prompt_tokens_total + self.gauge_avg_prompt_throughput = self._gauge_cls( + name="vllm:avg_prompt_throughput_toks_per_s", + documentation="Average prefill throughput in tokens/s.", + labelnames=labelnames, + multiprocess_mode="sum", + ) + # Deprecated in favor of vllm:generation_tokens_total + self.gauge_avg_generation_throughput = self._gauge_cls( + name="vllm:avg_generation_throughput_toks_per_s", + documentation="Average generation throughput in tokens/s.", + labelnames=labelnames, + multiprocess_mode="sum", + ) + + +# end-metrics-definitions + + def _unregister_vllm_metrics(self) -> None: + for collector in list(prometheus_client.REGISTRY._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + prometheus_client.REGISTRY.unregister(collector) + + +class _RayGaugeWrapper: + """Wraps around ray.util.metrics.Gauge to provide same API as + prometheus_client.Gauge""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + multiprocess_mode: str = ""): + del multiprocess_mode + labelnames_tuple = tuple(labelnames) if labelnames else None + self._gauge = ray_metrics.Gauge(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._gauge.set_default_tags(labels) + return self + + def set(self, value: Union[int, float]): + return self._gauge.set(value) + + +class _RayCounterWrapper: + """Wraps around ray.util.metrics.Counter to provide same API as + prometheus_client.Counter""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._counter = ray_metrics.Counter(name=name, + description=documentation, + tag_keys=labelnames_tuple) + + def labels(self, **labels): + self._counter.set_default_tags(labels) + return self + + def inc(self, value: Union[int, float] = 1.0): + if value == 0: + return + return self._counter.inc(value) + + +class _RayHistogramWrapper: + """Wraps around ray.util.metrics.Histogram to provide same API as + prometheus_client.Histogram""" + + def __init__(self, + name: str, + documentation: str = "", + labelnames: Optional[List[str]] = None, + buckets: Optional[List[float]] = None): + labelnames_tuple = tuple(labelnames) if labelnames else None + self._histogram = ray_metrics.Histogram(name=name, + description=documentation, + tag_keys=labelnames_tuple, + boundaries=buckets) + + def labels(self, **labels): + self._histogram.set_default_tags(labels) + return self + + def observe(self, value: Union[int, float]): + return self._histogram.observe(value) + + +class RayMetrics(Metrics): + """ + RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics. + Provides the same metrics as Metrics but uses Ray's util.metrics library. 
+ """ + _gauge_cls = _RayGaugeWrapper + _counter_cls = _RayCounterWrapper + _histogram_cls = _RayHistogramWrapper + + def __init__(self, labelnames: List[str], max_model_len: int): + if ray_metrics is None: + raise ImportError("RayMetrics requires Ray to be installed.") + super().__init__(labelnames, max_model_len) + + def _unregister_vllm_metrics(self) -> None: + # No-op on purpose + pass + + +def build_1_2_5_buckets(max_value: int) -> List[int]: + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values (1, 2, 5) until the value exceeds the specified maximum. + + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + mantissa_lst = [1, 2, 5] + exponent = 0 + buckets: List[int] = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + +def local_interval_elapsed(now: float, last_log: float, + local_interval: float) -> bool: + elapsed_time = now - last_log + return elapsed_time > local_interval + + +def get_throughput(tracked_stats: List[int], now: float, + last_log: float) -> float: + return float(np.sum(tracked_stats) / (now - last_log)) + + +class LoggingStatLogger(StatLoggerBase): + """LoggingStatLogger is used in LLMEngine to log to Stdout.""" + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + Logs to Stdout every self.local_interval seconds.""" + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + # Compute summary metrics for tracked stats (and log them + # to promethus if applicable). + prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + + # Log to stdout. + logger.info( + "Avg prompt throughput: %.1f tokens/s, " + "Avg generation throughput: %.1f tokens/s, " + "Running: %d reqs, Swapped: %d reqs, " + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%.", + prompt_throughput, + generation_throughput, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, + ) + if (stats.cpu_prefix_cache_hit_rate >= 0 + or stats.gpu_prefix_cache_hit_rate >= 0): + logger.info( + "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%", + stats.gpu_prefix_cache_hit_rate * 100, + stats.cpu_prefix_cache_hit_rate * 100, + ) + if self.spec_decode_metrics is not None: + logger.info( + self._format_spec_decode_metrics_str( + self.spec_decode_metrics)) + + # Reset tracked stats for next interval. 
+            self.num_prompt_tokens = []
+            self.num_generation_tokens = []
+            self.last_local_log = stats.now
+            self.spec_decode_metrics = None
+
+    def _format_spec_decode_metrics_str(
+            self, metrics: "SpecDecodeWorkerMetrics") -> str:
+
+        return ("Speculative metrics: "
+                f"Draft acceptance rate: {metrics.draft_acceptance_rate:.3f}, "
+                f"System efficiency: {metrics.system_efficiency:.3f}, "
+                f"Number of speculative tokens: {metrics.num_spec_tokens}, "
+                f"Number of accepted tokens: {metrics.accepted_tokens}, "
+                f"Number of draft tokens: {metrics.draft_tokens}, "
+                f"Number of emitted tokens: {metrics.emitted_tokens}.")
+
+    def info(self, type: str, obj: SupportsMetricsInfo) -> None:
+        raise NotImplementedError
+
+
+class PrometheusStatLogger(StatLoggerBase):
+    """PrometheusStatLogger is used by LLMEngine to log to Prometheus."""
+    _metrics_cls = Metrics
+    _gauge_cls = prometheus_client.Gauge
+
+    def __init__(self, local_interval: float, labels: Dict[str, str],
+                 max_model_len: int) -> None:
+        super().__init__(local_interval)
+        # Prometheus metrics
+        self.labels = labels
+        self.metrics = self._metrics_cls(labelnames=list(labels.keys()),
+                                         max_model_len=max_model_len)
+
+    def _log_gauge(self, gauge, data: Union[int, float]) -> None:
+        # Convenience function for logging to gauge.
+        gauge.labels(**self.labels).set(data)
+
+    def _log_counter(self, counter, data: Union[int, float]) -> None:
+        # Convenience function for logging to counter.
+        counter.labels(**self.labels).inc(data)
+
+    def _log_counter_labels(self, counter, data: CollectionsCounter,
+                            label_key: str) -> None:
+        # Convenience function for logging a collections.Counter of labels.
+        for label, count in data.items():
+            counter.labels(**{**self.labels, label_key: label}).inc(count)
+
+    def _log_histogram(self, histogram, data: Union[List[int],
+                                                    List[float]]) -> None:
+        # Convenience function for logging a list to a histogram. 
+ for datum in data: + histogram.labels(**self.labels).observe(datum) + + def _log_prometheus(self, stats: Stats) -> None: + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_swapped, + stats.num_swapped_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_cache_usage, + stats.cpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate, + stats.cpu_prefix_cache_hit_rate) + self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate, + stats.gpu_prefix_cache_hit_rate) + + # Iteration level data + self._log_counter(self.metrics.counter_num_preemption, + stats.num_preemption_iter) + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + + def _log_prometheus_interval(self, prompt_throughput: float, + generation_throughput: float) -> None: + # Logs metrics to prometheus that are computed every logging_interval. + # Support legacy gauge metrics that make throughput calculations on + # the vLLM side. Moving forward, we should use counters like + # counter_prompt_tokens, counter_generation_tokens + # Which log raw data and calculate summaries using rate() on the + # grafana/prometheus side. See + # https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666 + self.metrics.gauge_avg_prompt_throughput.labels( + **self.labels).set(prompt_throughput) + self.metrics.gauge_avg_generation_throughput.labels( + **self.labels).set(generation_throughput) + + def log(self, stats: Stats): + """Logs to prometheus and tracked stats every iteration.""" + # Log to prometheus. + self._log_prometheus(stats) + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) + + # Update spec decode metrics + self.maybe_update_spec_decode_metrics(stats) + + # Log locally every local_interval seconds. + if local_interval_elapsed(stats.now, self.last_local_log, + self.local_interval): + # Compute summary metrics for tracked stats (and log them + # to promethus if applicable). 
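# Grafana/Prometheus-side alternative to the legacy throughput gauges
# referenced above: derive the rates from the raw counters with rate(),
# e.g. (metric names assumed here, adjust to how the counters are exported):
#   rate(vllm:prompt_tokens_total[1m])        # avg prompt tokens/s
#   rate(vllm:generation_tokens_total[1m])    # avg generation tokens/s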
+ prompt_throughput = get_throughput(self.num_prompt_tokens, + now=stats.now, + last_log=self.last_local_log) + generation_throughput = get_throughput( + self.num_generation_tokens, + now=stats.now, + last_log=self.last_local_log) + + self._log_prometheus_interval( + prompt_throughput=prompt_throughput, + generation_throughput=generation_throughput) + + if self.spec_decode_metrics is not None: + self._log_gauge( + self.metrics.gauge_spec_decode_draft_acceptance_rate, + self.spec_decode_metrics.draft_acceptance_rate) + self._log_gauge(self.metrics.gauge_spec_decode_efficiency, + self.spec_decode_metrics.system_efficiency) + self._log_counter( + self.metrics.counter_spec_decode_num_accepted_tokens, + self.spec_decode_metrics.accepted_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_draft_tokens, + self.spec_decode_metrics.draft_tokens) + self._log_counter( + self.metrics.counter_spec_decode_num_emitted_tokens, + self.spec_decode_metrics.emitted_tokens) + + # Reset tracked stats for next interval. + self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now + self.spec_decode_metrics = None + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + # Info type metrics are syntactic sugar for a gauge permanently set to 1 + # Since prometheus multiprocessing mode does not support Info, emulate + # info here with a gauge. + if type == "cache_config": + metrics_info = obj.metrics_info() + info_gauge = self._gauge_cls( + name="vllm:cache_config_info", + documentation="Information of the LLMEngine CacheConfig", + labelnames=metrics_info.keys(), + multiprocess_mode="mostrecent") + info_gauge.labels(**metrics_info).set(1) + + +class RayPrometheusStatLogger(PrometheusStatLogger): + """RayPrometheusStatLogger uses Ray metrics instead.""" + _metrics_cls = RayMetrics + + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + return None diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py new file mode 100644 index 0000000..bafd5fa --- /dev/null +++ b/vllm/engine/metrics_types.py @@ -0,0 +1,87 @@ +""" +These types are defined in this file to avoid importing vllm.engine.metrics +and therefore importing prometheus_client. + +This is required due to usage of Prometheus multiprocess mode to enable +metrics after splitting out the uvicorn process from the engine process. + +Prometheus multiprocess mode requires setting PROMETHEUS_MULTIPROC_DIR +before prometheus_client is imported. Typically, this is done by setting +the env variable before launch, but since we are a library, we need to +do this in Python code and lazily import prometheus_client. 
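A minimal sketch of the constraint described above, assuming the standard prometheus_client multiprocess setup (the temporary directory is a placeholder for illustration, not the path vLLM uses):

import os
import tempfile

# Must be set before prometheus_client is imported anywhere in the process.
os.environ.setdefault("PROMETHEUS_MULTIPROC_DIR", tempfile.mkdtemp())

import prometheus_client  # noqa: E402  (deliberately imported late)
from prometheus_client import CollectorRegistry, multiprocess  # noqa: E402

# Aggregate the per-process metric files written by worker processes.
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry)
print(prometheus_client.generate_latest(registry).decode()[:200])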
+""" + +import time +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, List, Optional, Protocol + +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + # Prefix caching block hit rate + cpu_prefix_cache_hit_rate: float + gpu_prefix_cache_hit_rate: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + num_preemption_iter: int + + # Request stats (should have _requests suffix) + # Latency + time_e2e_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + n_requests: List[int] + finished_reason_requests: List[str] + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + +class SupportsMetricsInfo(Protocol): + + def metrics_info(self) -> Dict[str, str]: + ... + + +class StatLoggerBase(ABC): + """Base class for StatLogger.""" + + def __init__(self, local_interval: float) -> None: + # Tracked stats over current local logging interval. + self.num_prompt_tokens: List[int] = [] + self.num_generation_tokens: List[int] = [] + self.last_local_log = time.time() + self.local_interval = local_interval + self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + + @abstractmethod + def log(self, stats: Stats) -> None: + raise NotImplementedError + + @abstractmethod + def info(self, type: str, obj: SupportsMetricsInfo) -> None: + raise NotImplementedError + + def maybe_update_spec_decode_metrics(self, stats: Stats): + """Save spec decode metrics (since they are unlikely + to be emitted at same time as log interval).""" + if stats.spec_decode_metrics is not None: + self.spec_decode_metrics = stats.spec_decode_metrics diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py new file mode 100644 index 0000000..34c161e --- /dev/null +++ b/vllm/engine/multiprocessing/__init__.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from enum import Enum +from typing import List, Mapping, Optional, Union, overload + +from vllm import PoolingParams +from vllm.inputs import PromptType +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.utils import deprecate_kwargs + +VLLM_RPC_SUCCESS_STR = "SUCCESS" + +IPC_INPUT_EXT = "_input_socket" +IPC_OUTPUT_EXT = "_output_socket" +IPC_HEALTH_EXT = "_health_socket" +IPC_DATA_EXT = "_data_socket" + + +class MQEngineDeadError(RuntimeError): + pass + + +@dataclass +class RPCProcessRequest: + prompt: PromptType + params: Union[SamplingParams, PoolingParams] + request_id: str + lora_request: Optional[LoRARequest] = None + trace_headers: Optional[Mapping[str, str]] = None + prompt_adapter_request: Optional[PromptAdapterRequest] = None + priority: int = 0 + + @overload # DEPRECATED + def __init__( + self, + *, + inputs: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + 
prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @overload + def __init__( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def __init__( + self, + prompt: Optional[PromptType] = None, + params: Optional[Union[SamplingParams, PoolingParams]] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None, # DEPRECATED + ) -> None: + if inputs is not None: + prompt = inputs + assert (prompt is not None and params is not None + and request_id is not None) + + super().__init__() + + self.prompt = prompt + self.params = params + self.request_id = request_id + self.lora_request = lora_request + self.trace_headers = trace_headers + self.prompt_adapter_request = prompt_adapter_request + self.priority = priority + + +@dataclass +class RPCError: + request_id: Optional[str] + is_engine_errored: bool + exception: BaseException + + +@dataclass +class RPCAbortRequest: + request_id: str + + +class RPCStartupRequest(Enum): + IS_SERVER_READY = 1 + + +@dataclass +class RPCStartupResponse: + tracing_enabled: bool + + +class RPCUProfileRequest(Enum): + START_PROFILE = 1 + STOP_PROFILE = 2 + + +RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, + RPCUProfileRequest] + +REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] + + +def ENGINE_DEAD_ERROR( + error: Optional[BaseException] = None) -> MQEngineDeadError: + if error is None: + return MQEngineDeadError( + "Engine loop is not running. Inspect the stacktrace to " + "find the original error") + + return MQEngineDeadError( + "Engine loop is not running. 
Inspect the stacktrace to " + f"find the original error: {repr(error)}.") diff --git a/vllm/engine/multiprocessing/__pycache__/__init__.cpython-310.pyc b/vllm/engine/multiprocessing/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..025eaf5 Binary files /dev/null and b/vllm/engine/multiprocessing/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/engine/multiprocessing/__pycache__/client.cpython-310.pyc b/vllm/engine/multiprocessing/__pycache__/client.cpython-310.pyc new file mode 100644 index 0000000..8b83b75 Binary files /dev/null and b/vllm/engine/multiprocessing/__pycache__/client.cpython-310.pyc differ diff --git a/vllm/engine/multiprocessing/__pycache__/engine.cpython-310.pyc b/vllm/engine/multiprocessing/__pycache__/engine.cpython-310.pyc new file mode 100644 index 0000000..2ebdd9c Binary files /dev/null and b/vllm/engine/multiprocessing/__pycache__/engine.cpython-310.pyc differ diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py new file mode 100644 index 0000000..166906f --- /dev/null +++ b/vllm/engine/multiprocessing/client.py @@ -0,0 +1,704 @@ +import asyncio +import copy +import pickle +from contextlib import contextmanager, suppress +from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping, + Optional, Union, overload) + +import cloudpickle +import zmq +import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket + +from vllm import PoolingParams +from vllm.beam_search import BeamSearchSequence, create_sort_beams_key_function +from vllm.config import DecodingConfig, EngineConfig, ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.async_llm_engine import ( + build_guided_decoding_logits_processor_async) +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, RPC_REQUEST_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.inputs import PromptType, TokensPrompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput, + RequestOutput) +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.utils import (collect_from_async_generator, deprecate_kwargs, + random_uuid) + +logger = init_logger(__name__) + + +class MQClientClosedError(Exception): + """Exception class raised when the client is used post-close. + + The client can be closed, which closes the ZMQ context. This normally + happens on server shutdown. In some cases, methods like abort and + do_log_stats will still be called and then try to open a socket, which + causes a ZMQError and creates a huge stack trace. + So, we throw this error such that we can suppress it. + """ + + +class MQLLMEngineClient: + """A client wrapper for MQLLMEngine that conforms to the + EngineClient protocol. + + MQLLMEngine and MQLLMEngineClient are intended to run in separate + processes communicating via zeromq ipc sockets. + + The entrypoint to MQLLMEngineClient is through the generate() + method. 
On generate() MQLLMEngine does three things: + - Creates an asyncio output queue + - Sends a RPCGenerateRequest to the MQLLMEngine via zmq + - Pulls RequestOutputs from its queue and yields them + + MQLLMEngine runs two background loops: + - output_loop: the output loop pulls List[RequestOutput] + from the MQLLMEngine via zmq (each list is the output + of one engine_step in the LLMEngine). It then parses + the list and pushes individual request_outputs into + the corresponding output_queue such that they can be + consumed by the .generate() method. + - health_loop: the health loop queries the health socket + every N seconds, confirming the engine is healthy + """ + + def __init__(self, ipc_path: str, engine_config: EngineConfig): + self.context = zmq.asyncio.Context() + self._errored_with: Optional[BaseException] = None + + # Get the configs. + self.model_config = engine_config.model_config + self.decoding_config = engine_config.decoding_config + + # Create the tokenizer group. + self.tokenizer = init_tokenizer_from_configs( + model_config=self.model_config, + scheduler_config=engine_config.scheduler_config, + parallel_config=engine_config.parallel_config, + enable_lora=bool(engine_config.lora_config), + ) + + # Send RPCGenerateRequest to the MQLLMEngine. + self.input_socket: Socket = self.context.socket(zmq.constants.PUSH) + self.input_socket.connect(f"{ipc_path}{IPC_INPUT_EXT}") + + # Receive streams of RequestOutput from the MQLLMEngine. + self.output_socket: Socket = self.context.socket(zmq.constants.PULL) + self.output_socket.connect(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # IPC path for acking heartbeats. + self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL) + self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Stream for each individual request. + self.output_queues: Dict[str, asyncio.Queue] = {} + self.output_loop = asyncio.create_task(self.run_output_handler_loop()) + + # Loop to check health of the LLMEngine periodically. + # Started after the MQLLMEngine is ready. + self.health_loop: Optional[asyncio.Task] = None + + @staticmethod + def is_unsupported_config(engine_args: AsyncEngineArgs): + # Pipeline parallel not yet supported + return engine_args.pipeline_parallel_size > 1 + + @contextmanager + def get_data_socket(self) -> Iterator[Socket]: + socket = self.context.socket(zmq.constants.DEALER) + try: + socket.connect(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + async def run_heartbeat_loop(self, timeout: int): + """Background loop that continually listens to the RPCServer for + heartbeats. + """ + try: + while True: + if await self.heartbeat_socket.poll(timeout=timeout) == 0: + # No heartbeat was received. 
Set error and exit the loop + self._set_errored( + TimeoutError("No heartbeat received " + "from MQLLMEngine")) + logger.debug("Shutting down MQLLMEngineClient check " + "health loop due to timeout") + break + + else: + # Heartbeat received- check the message + await self._check_success( + error_message="Heartbeat failed.", + socket=self.heartbeat_socket) + + logger.debug("Heartbeat successful.") + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient check health loop.") + + except Exception as e: + self._set_errored(e) + + async def run_output_handler_loop(self): + """Get RequestOutputs from Engine and stream to Request Queues""" + + try: + while True: + # Poll, checking for ENGINE_DEAD + while await self.output_socket.poll(timeout=VLLM_RPC_TIMEOUT + ) == 0: + logger.debug("Waiting for output from MQLLMEngine.") + + # If errored, alert all running requests. + if self.errored: + for queue_j in tuple(self.output_queues.values()): + queue_j.put_nowait( + ENGINE_DEAD_ERROR(self._errored_with)) + return + + message: Frame = await self.output_socket.recv(copy=False) + request_outputs = pickle.loads(message.buffer) + + is_error = isinstance(request_outputs, + (BaseException, RPCError)) + if is_error: + if isinstance(request_outputs, RPCError): + rpc_error: RPCError = request_outputs + request_id = rpc_error.request_id + exception = rpc_error.exception + is_engine_errored = rpc_error.is_engine_errored + else: + # MPLLMEngine should always return an RPCError to + # the output_socket when an issue arises. + # If we are here, we are in a bad state and + # should shut down the server. + error: BaseException = request_outputs + logger.error( + "Received Exception %s rather than RPCError from " + "MPLLMEngine. This should never happen.", error) + request_id = None + exception = error + is_engine_errored = True + + # Set to error state only on engine critical error + # (and record only the first one) + if is_engine_errored and not self._errored_with: + self._errored_with = exception + + if request_id is None: + for queue_i in tuple(self.output_queues.values()): + queue_i.put_nowait(exception) + else: + queue = self.output_queues.get(request_id) + if queue is not None: + queue.put_nowait(exception) + else: + # Put each output into the appropriate steam. + for request_output in request_outputs: + queue = self.output_queues.get( + request_output.request_id) + if queue is not None: + queue.put_nowait(request_output) + + except asyncio.CancelledError: + logger.debug("Shutting down MQLLMEngineClient output handler.") + + async def setup(self): + """Setup the client before it starts sending server requests.""" + + with self.get_data_socket() as socket: + # Wait until server is ready. + response = await self._wait_for_server_rpc(socket) + + self.tracing_flag = response.tracing_enabled + + # Start health_loop. + self.health_loop = asyncio.create_task( + self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT)) + + def close(self): + """Destroy the ZeroMQ Context.""" + # Close all sockets and terminate the context. + self.context.destroy(linger=0) + + # Cancel background tasks. 
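The client above connects PULL/PUSH sockets to ipc endpoints that the engine binds; a minimal standalone sketch of that wiring (the endpoint name and payload are illustrative only):

import asyncio
import pickle

import zmq
import zmq.asyncio

async def demo(ipc_path: str = "ipc:///tmp/demo_mq") -> None:
    ctx = zmq.asyncio.Context()

    # "Engine" side binds a PUSH socket for outputs...
    push = ctx.socket(zmq.PUSH)
    push.bind(f"{ipc_path}_output_socket")

    # ...and the "client" side connects a PULL socket to the same endpoint.
    pull = ctx.socket(zmq.PULL)
    pull.connect(f"{ipc_path}_output_socket")

    await push.send_multipart((pickle.dumps(["fake-request-output"]),))
    frame = await pull.recv(copy=False)
    print(pickle.loads(frame.buffer))    # ['fake-request-output']

    ctx.destroy(linger=0)

asyncio.run(demo())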
+ if self.health_loop is not None: + self.health_loop.cancel() + self.output_loop.cancel() + + def _set_errored(self, e: BaseException): + logger.exception(repr(e)) + if self._errored_with is None: + self._errored_with = e + + @staticmethod + async def _send_get_data_rpc_request(request: RPCStartupRequest, + expected_type: Any, + error_message: str, + socket: Socket) -> Any: + """Send an RPC request that is expecting data back.""" + + # Ping RPCServer with a request. + await socket.send_multipart((pickle.dumps(request), ), copy=False) + + # Make sure the server responds in time. + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("RPCServer didn't reply within " + f"{VLLM_RPC_TIMEOUT} ms") + + # Await the data from the Server. + frame = await socket.recv(copy=False) + data = pickle.loads(frame.buffer) + + if isinstance(data, BaseException): + raise data + elif not isinstance(data, expected_type): + raise ValueError(error_message) + + return data + + @staticmethod + async def _send_one_way_rpc_request(request: RPC_REQUEST_T, + socket: Socket): + """Send one-way RPC request to trigger an action.""" + + if socket.closed: + raise MQClientClosedError() + + await socket.send_multipart((pickle.dumps(request), )) + + async def _await_ack(self, error_message: str, socket: Socket): + """Await acknowledgement that a request succeeded.""" + + if socket.closed: + raise MQClientClosedError() + + if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: + raise TimeoutError("MQLLMEngine didn't reply within " + f"{VLLM_RPC_TIMEOUT}ms") + + await self._check_success(error_message, socket) + + @staticmethod + async def _check_success(error_message: str, socket: Socket): + """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" + + if socket.closed: + raise MQClientClosedError() + + frame = await socket.recv(copy=False) + response = pickle.loads(frame.buffer) + + # Raise error if unsuccessful + if isinstance(response, BaseException): + raise response + elif (not isinstance(response, str) + or response != VLLM_RPC_SUCCESS_STR): + raise ValueError(error_message) + + async def get_tokenizer(self, lora_request: LoRARequest): + return await self.tokenizer.get_lora_tokenizer_async(lora_request) + + async def get_decoding_config(self) -> DecodingConfig: + return self.decoding_config + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def is_tracing_enabled(self) -> bool: + return self.tracing_flag + + async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: + """Wait for the RPCServer to start up.""" + + return await self._send_get_data_rpc_request( + request=RPCStartupRequest.IS_SERVER_READY, + expected_type=RPCStartupResponse, + error_message="Unable to start RPC Server", + socket=socket) + + async def abort(self, request_id: str): + """Send an ABORT_REQUEST signal to the RPC Server""" + + with suppress(MQClientClosedError): + await self._send_one_way_rpc_request( + request=RPCAbortRequest(request_id), socket=self.input_socket) + + async def do_log_stats(self): + """Ignore do_log_stats (handled on MQLLMEngine polling)""" + pass + + async def check_health(self): + """ + The check health loop probes the health status of the + Engine's health every N seconds and sets _errored_with + if the engine is unhealthy. 
+ """ + if self._errored_with is not None: + raise self._errored_with + + @property + def is_running(self) -> bool: + return not self.errored + + @property + def is_stopped(self) -> bool: + return self.errored + + @property + def errored(self) -> bool: + return self._errored_with is not None + + @property + def dead_error(self) -> BaseException: + return ENGINE_DEAD_ERROR(self._errored_with) + + @overload # DEPRECATED + def generate( + self, + *, + inputs: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @overload + def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> AsyncGenerator[RequestOutput, None]: + ... + + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def generate( + self, + prompt: Optional[PromptType] = None, + sampling_params: Optional[SamplingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None # DEPRECATED + ) -> AsyncGenerator[RequestOutput, None]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the + scheduling policy is not "priority". 
+ """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and sampling_params is not None + and request_id is not None) + + return self._process_request(prompt, sampling_params, request_id, + lora_request, trace_headers, + prompt_adapter_request, priority) + + async def beam_search( + self, + prompt: Union[PromptType, List[int]], + request_id: str, + params: BeamSearchParams, + ) -> AsyncGenerator[RequestOutput, None]: + + beam_width = params.beam_width + max_tokens = params.max_tokens + ignore_eos = params.ignore_eos + temperature = params.temperature + length_penalty = params.length_penalty + + tokenizer = await self.get_tokenizer(lora_request=None) + tokenizedPrompt = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + tokenizedLength = len(tokenizedPrompt) + + sort_beams_key = create_sort_beams_key_function( + tokenizer.eos_token_id, length_penalty) + + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)] + completed = [] + + for _ in range(max_tokens): + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + tasks = [] + + request_id = f"beam_search-{random_uuid()}" + for i, individual_prompt in enumerate(prompts_batch): + request_id_item = f"{request_id}-{i}" + task = asyncio.create_task( + collect_from_async_generator( + self.generate(individual_prompt, beam_search_params, + request_id_item))) + tasks.append(task) + + output = await asyncio.gather(*tasks) + + output = [x[0] for x in output] + + logger.info(output) + + new_beams = [] + for i, current_beam in enumerate(all_beams): + result = output[i] + + if result.outputs[0].logprobs is not None: + logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + completed.append(new_beam) + else: + new_beams.append(new_beam) + + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) + all_beams = sorted_beams[:beam_width] + + completed.extend(all_beams) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens[tokenizedLength:]) + + beam_search_output = RequestOutput( + request_id=request_id, + prompt=prompt, + outputs=[ + CompletionOutput( + text=beam.text, + cumulative_logprob=beam.cum_logprob, + token_ids=beam.tokens, + index=i, + logprobs=beam.cum_logprob, + ) for (i, beam) in enumerate(best_beams) + ], + finished=True, + prompt_token_ids=tokenizedPrompt, + prompt_logprobs=None) + + logger.info(beam_search_output) + + yield beam_search_output + + @overload # DEPRECATED + def encode( + self, + *, + inputs: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... + + @overload + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + ... 
+ + @deprecate_kwargs( + "inputs", + additional_message="Please use the 'prompt' parameter instead.", + ) + def encode( + self, + prompt: Optional[PromptType] = None, + pooling_params: Optional[PoolingParams] = None, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + *, + inputs: Optional[PromptType] = None # DEPRECATED + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: + """Generate outputs for a request from an embedding model. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType` + for more details about the format of each input. + pooling_params: The pooling parameters of the request. + request_id: The unique id of the request. + lora_request: LoRA request to use for generation, if any. + trace_headers: OpenTelemetry trace headers. + + Yields: + The output `EmbeddingRequestOutput` objects from the LLMEngine + for the request. + """ + if inputs is not None: + prompt = inputs + assert (prompt is not None and pooling_params is not None + and request_id is not None) + + return self._process_request(prompt, pooling_params, request_id, + lora_request, trace_headers, None, + priority) + + async def _process_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[ + EmbeddingRequestOutput, None]]: + """Send an RPCGenerateRequest to the RPCServer and stream responses.""" + + # If already dead, error out. + if self._errored_with is not None: + raise ENGINE_DEAD_ERROR(self._errored_with) + + # Constructing guided decoding logits processors is expensive, so we do + # it here to avoid contending with cpu resources and the GIL on the + # backend process. + if isinstance(params, SamplingParams) and \ + params.guided_decoding is not None: + params = await \ + build_guided_decoding_logits_processor_async( + sampling_params=params, + tokenizer=await self.get_tokenizer(lora_request), + default_guided_backend=self.decoding_config.guided_decoding_backend + ) + + # 1) Create output queue for this requests. + queue: asyncio.Queue[Union[RequestOutput, + BaseException]] = asyncio.Queue() + self.output_queues[request_id] = queue + + try: + # 2) Detach logits processors so that they can be pickled + # separately (may require cloudpickle which is slower) + if isinstance(params, SamplingParams) and params.logits_processors: + # Defensive shallow copy + params = copy.copy(params) + logits_processors = params.logits_processors + params.logits_processors = None + lp_bytes = cloudpickle.dumps(logits_processors) + else: + lp_bytes = None + + request_bytes = pickle.dumps( + RPCProcessRequest( + prompt=prompt, + params=params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + )) + + # 3) Send the RPCGenerateRequest to the MQLLMEngine. 
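Step 2 above exists because plain pickle cannot serialize ad-hoc callables such as lambdas, while cloudpickle can (at a speed cost); a small illustration using a stand-in callable rather than a real logits processor:

import pickle

import cloudpickle

demo_processor = lambda token_ids, logits: logits   # stand-in, does nothing

try:
    pickle.dumps(demo_processor)
except (pickle.PicklingError, AttributeError) as exc:
    print("pickle refused the lambda:", exc)

payload = cloudpickle.dumps(demo_processor)          # works
restored = cloudpickle.loads(payload)
print(restored([1, 2, 3], "logits-placeholder"))     # 'logits-placeholder'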
+ parts = (request_bytes, + lp_bytes) if lp_bytes else (request_bytes, ) + await self.input_socket.send_multipart(parts, copy=False) + + # 4) Stream the RequestOutputs from the output queue. Note + # that the output_loop pushes RequestOutput objects to this + # queue after pulling them from the zmq socket. + finished = False + try: + while not finished: + request_output = await queue.get() + + if isinstance(request_output, BaseException): + raise request_output + + finished = request_output.finished + yield request_output + finally: + # Request was canceled by the client. + if not finished and not self.errored: + await self.abort(request_id) + finally: + self.output_queues.pop(request_id) + + async def start_profile(self) -> None: + """Start profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + + async def stop_profile(self) -> None: + """Stop profiling the engine""" + + await self._send_one_way_rpc_request( + request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py new file mode 100644 index 0000000..2bf0ce8 --- /dev/null +++ b/vllm/engine/multiprocessing/engine.py @@ -0,0 +1,395 @@ +import pickle +import signal +import threading +import time +from contextlib import contextmanager +from typing import Iterator, List, Optional, Union + +import cloudpickle +import zmq + +from vllm import AsyncEngineArgs, LLMEngine, SamplingParams +from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, + IPC_HEALTH_EXT, IPC_INPUT_EXT, + IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, + VLLM_RPC_SUCCESS_STR, RPCAbortRequest, + RPCError, RPCProcessRequest, + RPCStartupRequest, RPCStartupResponse, + RPCUProfileRequest) +# yapf: enable +from vllm.envs import VLLM_RPC_TIMEOUT +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.usage.usage_lib import UsageContext + +CONFIG_TYPE = Union[ModelConfig, DecodingConfig, ParallelConfig, + SchedulerConfig, LoRAConfig] + +logger = init_logger(__name__) + +POLLING_TIMEOUT_MS = 10000 +HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) + + +class MQLLMEngine: + """A multiprocessing wrapper for :class:`LLMEngine`. + + This class is used to wrap the :class:`LLMEngine` class to enable use + in concurrnet manner. It runs a background loop and uses zeromq to + receive new requests and stream outputs incrementally via ipc. + + The :class:`LLMEngine` generate or encode process is kicked off when a new + RPCProcessRequest is received by the input_socket. + + The self.engine_loop checks the input_socket for new requests, + adds them to the LLMEngine if there are any, calls the internal + :class:`LLMEngine.step()`, and sends the RequestOutputs back over + the output_socket. + + If use_async_sockets is set, the logic associated with reading new + requests from the socket and sending data to the socket is passed + as a callback to the llm_engine, which calls the logic asynchronously + such that the IPC can be overlapped with the GPU. + + Args: + ipc_path: Base path for zeromq interprocess messaging + use_async_sockets: Whether to make send/recv async with GPU + log_requests: Whether to log the requests. + *args: Arguments for :class:`LLMEngine`. 
+ **kwargs: Arguments for :class:`LLMEngine`. + """ + + def __init__(self, + ipc_path: str, + use_async_sockets: bool, + *args, + log_requests: bool = True, + **kwargs) -> None: + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. + use_cached_outputs = True + + self.engine = LLMEngine(*args, + **kwargs, + use_cached_outputs=use_cached_outputs) + self.log_requests = log_requests + + self.use_async_sockets = use_async_sockets + if self.use_async_sockets: + self.engine.process_request_outputs_callback = \ + self._async_socket_engine_callback + + self.ctx = zmq.Context() # type: ignore[attr-defined] + + # Receive input from the client. + self.input_socket = self.ctx.socket(zmq.constants.PULL) + self.input_socket.bind(f"{ipc_path}{IPC_INPUT_EXT}") + + # Send output stream back to client. + self.output_socket = self.ctx.socket(zmq.constants.PUSH) + self.output_socket.bind(f"{ipc_path}{IPC_OUTPUT_EXT}") + + # Send heartbeats back to client. + self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH) + self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}") + + # IPC path for the data socket. + self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}" + + # Error state. + self._errored_with: Optional[BaseException] = None + + # Heartbeat thread + self.heartbeat_thread = threading.Thread(target=self._heartbeat_loop, + daemon=True) + self._heartbeat_stop_event = threading.Event() + # The heartbeat needs to be faster than what the client will wait for + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.heartbeat_interval_seconds = VLLM_RPC_TIMEOUT / 5000.0 + + self._last_alive_time = time.time() + # The heartbeats can tolerate a long period of the engine chugging + # away at a generation request. + # The VLLM_RPC_TIMEOUT duration is in ms, and we need one in seconds + self.last_alive_threshold = VLLM_RPC_TIMEOUT * 3.0 / 1000.0 + + @property + def dead_error(self) -> BaseException: + if self._errored_with is not None: + return ENGINE_DEAD_ERROR(self._errored_with) + else: + return ENGINE_DEAD_ERROR() + + @classmethod + def from_engine_args(cls, engine_args: AsyncEngineArgs, + usage_context: UsageContext, ipc_path: str): + """Creates an MQLLMEngine from the engine arguments.""" + # Setup plugins for each process + from vllm.plugins import load_general_plugins + load_general_plugins() + + engine_config = engine_args.create_engine_config() + + executor_class = LLMEngine._get_executor_cls(engine_config) + + return cls( + ipc_path=ipc_path, + use_async_sockets=engine_config.model_config.use_async_output_proc, + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context) + + def start(self): + try: + try: + logger.debug("Starting Startup Loop.") + self.run_startup_loop() + logger.debug("Starting heartbeat thread") + self.heartbeat_thread.start() + logger.debug("Starting Engine Loop.") + self.run_engine_loop() + except Exception as e: + logger.exception(repr(e)) + except KeyboardInterrupt: + logger.debug("Shutting down MQLLMEngine.") + finally: + logger.debug("MQLLMEngine is shut down.") + self.cleanup() + + def cleanup(self): + """Cleanup zeromq state on shutdown.""" + # Closes all sockets and destroys context. 
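The heartbeat timing set up in __init__ above folds two unit conversions into single expressions; a small sketch of the arithmetic, assuming a VLLM_RPC_TIMEOUT of 10000 ms (the real value comes from vllm.envs):

VLLM_RPC_TIMEOUT_MS = 10_000    # assumed for illustration; see vllm.envs

# Heartbeats must arrive well inside the client's poll timeout, so the
# engine beats at one fifth of the timeout, converted from ms to seconds.
heartbeat_interval_seconds = VLLM_RPC_TIMEOUT_MS / 5000.0            # 2.0 s

# The engine loop is only declared dead after three full timeouts.
last_alive_threshold_seconds = VLLM_RPC_TIMEOUT_MS * 3.0 / 1000.0    # 30.0 s

print(heartbeat_interval_seconds, last_alive_threshold_seconds)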
+ self._heartbeat_stop_event.set() + self.ctx.destroy(linger=0) + del self.engine + + @contextmanager + def make_data_socket( + self) -> Iterator[zmq.Socket]: # type: ignore[name-defined] + socket = self.ctx.socket(zmq.constants.ROUTER) + try: + socket.bind(self.data_ipc_path) + yield socket + finally: + socket.close(linger=0) + + def run_startup_loop(self) -> None: + """Startup loop for sending data from Engine -> Client.""" + + with self.make_data_socket() as socket: + response: Union[RPCStartupResponse, BaseException] + try: + identity, message = socket.recv_multipart(copy=False) + request: RPCStartupRequest = pickle.loads(message.buffer) + + # Handle the query from the Client. + if request == RPCStartupRequest.IS_SERVER_READY: + tracing_enabled = self.engine.is_tracing_enabled() + response = RPCStartupResponse( + tracing_enabled=tracing_enabled) + + except Exception as e: + response = e + + socket.send_multipart((identity, pickle.dumps(response)), + copy=False) + + def run_engine_loop(self): + """Core busy loop of the LLMEngine.""" + + while True: + self._alive() + if not self.engine.has_unfinished_requests(): + # Poll until there is work to do. + while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0: + self._alive() + self.engine.do_log_stats() + logger.debug("Waiting for new requests in engine loop.") + + # Handle any input from the client. + self.handle_new_input() + + # Engine step. + request_outputs = self.engine_step() + + # Send request outputs (if async, done in engine_step callback). + if not self.use_async_sockets: + self._send_outputs(request_outputs) + + def engine_step(self) -> List[RequestOutput]: + """Engine step wrapper with error handling.""" + try: + return self.engine.step() + except SystemExit: + raise + except BaseException as e: + self._set_errored(e) + rpc_err = RPCError(request_id=None, + is_engine_errored=True, + exception=e) + self._send_outputs(rpc_err) + raise e + + def handle_new_input(self): + """Handle new input from the socket""" + try: + while self.input_socket.poll(timeout=0) != 0: + frames = self.input_socket.recv_multipart(copy=False) + request = pickle.loads(frames[0].buffer) + + if isinstance(request, RPCProcessRequest): + if len(frames) > 1: + # Use cloudpickle for logits processors + assert isinstance(request.params, SamplingParams) + lprocs = cloudpickle.loads(frames[1].buffer) + request.params.logits_processors = lprocs + self._handle_process_request(request) + elif isinstance(request, RPCAbortRequest): + self._handle_abort_request(request) + elif isinstance(request, RPCUProfileRequest): + if request == RPCUProfileRequest.START_PROFILE: + self.start_profile() + else: + self.stop_profile() + else: + raise ValueError("Unknown RPCRequest Type: " + f"{type(request)}") + + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + raise e + + def _handle_process_request(self, request: RPCProcessRequest): + """Handle RPCProcessRequest by adding it to the LLMEngine.""" + request_id = request.request_id + + if self._errored_with is not None: + rpc_err = RPCError(request_id=request_id, + is_engine_errored=True, + exception=ENGINE_DEAD_ERROR(self._errored_with)) + self._send_outputs(rpc_err) + + try: + self.engine.add_request( + request_id=request_id, + prompt=request.prompt, + params=request.params, + lora_request=request.lora_request, + trace_headers=request.trace_headers, + prompt_adapter_request=request.prompt_adapter_request, + priority=request.priority) + + if self.log_requests: + logger.info("Added request %s.", 
request.request_id) + + except Exception as e: + # We do not set self._errored = True here, since the error + # is due to an issue adding this request to the engine, + # rather than an issue with the engine itself. + is_errored = self._errored_with is not None + rpc_err = RPCError(request_id=request_id, + is_engine_errored=is_errored, + exception=e) + self._send_outputs(rpc_err) + + # Remove request from the engine. + self.engine.abort_request(request_id) + + def _handle_abort_request(self, request: RPCAbortRequest): + self.engine.abort_request(request.request_id) + if self.log_requests: + logger.info("Aborted request %s.", request.request_id) + + def _heartbeat_loop(self): + while not self._heartbeat_stop_event.wait( + timeout=self.heartbeat_interval_seconds): + # Loops until the stop event is set + self._heartbeat() + + logger.debug("Exiting MQLLMEngine heartbeat thread") + + def _heartbeat(self): + # Send unhealthy if engine has already errored + if self._errored_with is not None: + self._send_unhealthy(self._errored_with) + + # Check for life of the main loop + elif time.time() - self._last_alive_time > self.last_alive_threshold: + self._send_unhealthy(RuntimeError("Engine loop has died")) + + else: + # Otherwise- check health of the engine + # self.engine.check_health() raises on unhealthy + try: + self.engine.check_health() + self._send_healthy() + except Exception as e: + self._set_errored(e) + self._send_unhealthy(e) + + def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): + """Send List of RequestOutput to RPCClient.""" + if outputs: + output_bytes = pickle.dumps(outputs) + self.output_socket.send_multipart((output_bytes, ), copy=False) + + def _send_healthy(self): + """Send HEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + self.heartbeat_socket.send_multipart(HEALTHY_RESPONSE, copy=False) + + def _send_unhealthy(self, error: BaseException): + """Send UNHEALTHY message to RPCClient.""" + if not self.heartbeat_socket.closed: + error_bytes = pickle.dumps(error) + self.heartbeat_socket.send_multipart((error_bytes, ), copy=False) + + def _async_socket_engine_callback(self, + request_outputs: REQUEST_OUTPUTS_T): + """Callback used by engine to make socket handling async with GPU.""" + self._send_outputs(request_outputs) + self.handle_new_input() + + def _set_errored(self, e: BaseException): + """Log and set errored status if this is the first issue.""" + if self._errored_with is None: + self._errored_with = e + + def _alive(self): + self._last_alive_time = time.time() + + def start_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.start_profile() + else: + self.engine.model_executor._run_workers("start_profile") + + def stop_profile(self) -> None: + if type(self.engine.model_executor) is GPUExecutor: + self.engine.model_executor.stop_profile() + else: + self.engine.model_executor._run_workers("stop_profile") + + +def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, + ipc_path: str): + + def signal_handler(*_) -> None: + # Interrupt server on sigterm + raise KeyboardInterrupt("MQLLMEngine terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + engine = MQLLMEngine.from_engine_args(engine_args=engine_args, + usage_context=usage_context, + ipc_path=ipc_path) + engine.start() diff --git a/vllm/engine/output_processor/__init__.py b/vllm/engine/output_processor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git 
a/vllm/engine/output_processor/__pycache__/__init__.cpython-310.pyc b/vllm/engine/output_processor/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6280127 Binary files /dev/null and b/vllm/engine/output_processor/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/engine/output_processor/__pycache__/interfaces.cpython-310.pyc b/vllm/engine/output_processor/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000..ca6075f Binary files /dev/null and b/vllm/engine/output_processor/__pycache__/interfaces.cpython-310.pyc differ diff --git a/vllm/engine/output_processor/__pycache__/multi_step.cpython-310.pyc b/vllm/engine/output_processor/__pycache__/multi_step.cpython-310.pyc new file mode 100644 index 0000000..5ab7771 Binary files /dev/null and b/vllm/engine/output_processor/__pycache__/multi_step.cpython-310.pyc differ diff --git a/vllm/engine/output_processor/__pycache__/single_step.cpython-310.pyc b/vllm/engine/output_processor/__pycache__/single_step.cpython-310.pyc new file mode 100644 index 0000000..c447d33 Binary files /dev/null and b/vllm/engine/output_processor/__pycache__/single_step.cpython-310.pyc differ diff --git a/vllm/engine/output_processor/__pycache__/stop_checker.cpython-310.pyc b/vllm/engine/output_processor/__pycache__/stop_checker.cpython-310.pyc new file mode 100644 index 0000000..f26ea79 Binary files /dev/null and b/vllm/engine/output_processor/__pycache__/stop_checker.cpython-310.pyc differ diff --git a/vllm/engine/output_processor/__pycache__/util.cpython-310.pyc b/vllm/engine/output_processor/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000..2918745 Binary files /dev/null and b/vllm/engine/output_processor/__pycache__/util.cpython-310.pyc differ diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py new file mode 100644 index 0000000..50adaf4 --- /dev/null +++ b/vllm/engine/output_processor/interfaces.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from typing import Callable, List + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + + +class SequenceGroupOutputProcessor(ABC): + """Interface for logic that processes new token ids in sequence groups, + managing detokenization, stop checking, and freeing/forking sequences with + the scheduler. + + This is highly coupled with the LLMEngine and should be seen as an extension + of it. The logic is separated to simplify the LLMEngine class and allow + separate implementations for single-step decoding (which supports beam + search sequence forking) and multi-step decoding (which does not support + beam search, but does support speculative decoding). + """ + + @staticmethod + def create_output_processor( + scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: "StopChecker", + ): + """Create an output processor. + + This returns a single-step output processor if num_lookahead_slots is + zero, else returns a multi-step output processor. + """ + if scheduler_config.num_lookahead_slots == 0: + # Importing here to avoid cycle. 
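For orientation, a toy sketch of the dispatch rule documented above (the real factory below lazily imports the concrete processors to avoid an import cycle):

def pick_output_processor_kind(num_lookahead_slots: int) -> str:
    # Single-step processing (supports beam search sequence forking) when
    # the scheduler never allocates lookahead slots; multi-step processing
    # (used by speculative decoding) otherwise.
    return "single-step" if num_lookahead_slots == 0 else "multi-step"

assert pick_output_processor_kind(0) == "single-step"
assert pick_output_processor_kind(4) == "multi-step"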
+ from vllm.engine.output_processor.single_step import ( + SingleStepOutputProcessor) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) + else: + # Importing here to avoid cycle. + from vllm.engine.output_processor.multi_step import ( + MultiStepOutputProcessor) + return MultiStepOutputProcessor( + detokenizer, + scheduler, + seq_counter, + get_tokenizer_for_seq, + stop_checker, + ) + + @abstractmethod + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Process new token ids for the sequence group. Handles logic such as + detokenization, stop checking, and freeing/forking sequences in the + scheduler. + """ + pass + + @abstractmethod + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Update prompt logprobs received from outputs to seq_group.""" + pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py new file mode 100644 index 0000000..74ddb25 --- /dev/null +++ b/vllm/engine/output_processor/multi_step.py @@ -0,0 +1,188 @@ +import functools +from typing import Callable, List + +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +class MultiStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles logic related to + detokenization and stopping conditions. It specializes to "multi-step + decoding", where vLLM's worker may generate multiple tokens per invocation. + This is currently mutually exclusive with advanced sampling techniques like + beam search, which motivates the separation of this logic from the single + step output processor. + + This class is responsible for things such as correctly appending all new + token ids to their sequence, detokenizing new token ids, truncating new + output tokens after an eos token, and correctly handling the case where the + number of new output tokens per sequence differs in a single batch. + """ + + def __init__( + self, + detokenizer: Detokenizer, + scheduler: List[Scheduler], + seq_counter: Counter, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], + stop_checker: StopChecker, + ): + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.get_tokenizer_for_seq = get_tokenizer_for_seq + self.stop_checker = stop_checker + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. + + Args: + seq_group: the outputs are associated with this :class:`SequenceGroup` + outputs: the :class:`SequenceGroupOutput`s for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. 
+ single_step_process_prompt_logprob(self, seq_group, output) + + @staticmethod + @functools.lru_cache() + def _log_prompt_logprob_unsupported_warning_once(): + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + logger.warning( + "Prompt logprob is not supported by multi step workers. " + "(e.g., speculative decode uses multi step workers).") + + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: + """Append new tokens in the outputs to sequences in the sequence group. + + This only supports sequence groups of size 1. It supports greater than + one new token per sequence. + + This applies logic like stop condition checking and detokenization. + It also handles cases where there are tokens emitted after + the EOS token. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + # Sequences can be in RUNNING or FINISHED_ABORTED state + # once scheduled, as a sequence is moved to FINSIHED_ABORTED + # if a client disconnects from the api server. + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + if seqs is None: + seqs = sequence_group.get_seqs( + status=SequenceStatus.FINISHED_ABORTED) + + assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" + assert len(seqs) == 1, ( + "Beam search not supported in multi-step decoding.") + seq = seqs[0] + seq_id = seq.seq_id + assert all( + [seq_id == output.samples[0].parent_seq_id for output in outputs]) + + if is_async: + # Async case: We process tokens one by one. Here, we know the token + # was already appended, so we only need to do the rest of the + # postprocessor: Detokenization + stopping logic + self._process_decode_and_stop(seq, sequence_group.sampling_params) + else: + # Standard multi-step case + + # Since there's only one sequence per sequence group, + # we can take the first sample. + samples = [output.samples[0] for output in outputs] + + # entries in sample tokens may be invalid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples + if sample.output_token != VLLM_INVALID_TOKEN_ID + ] + assert valid_samples + + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_decode_and_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + new_char_count = 0 + if sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) + + def _process_seq_outputs(self, seq: Sequence, + valid_samples: List[SequenceOutput], + sampling_params: SamplingParams) -> None: + output_token_ids = [sample.output_token for sample in valid_samples] + output_logprobs = [sample.logprobs for sample in valid_samples] + + # Truncate to max_tokens if necessary. + remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + + len(output_token_ids)) + if remaining_tokens < 0: + output_token_ids = output_token_ids[:remaining_tokens] + + # Truncate any tokens after EOS. 
This is required as spec decode + # generates a fixed number of tokens without evaluating stopping + # conditions within the block. This can cause an eos token to be + # unintentionally ignored. + if not sampling_params.ignore_eos: + eos_token_id = self.get_tokenizer_for_seq(seq).eos_token_id + # Avoiding .index calls as exception throwing in the happy path + # is expensive. + for i in range(len(output_token_ids)): + if output_token_ids[i] == eos_token_id: + output_token_ids = output_token_ids[:i + 1] + break + + is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 + # Incrementally append tokens to the sequence, as if we had only one new + # token. + for output_token_id, output_logprob in zip(output_token_ids, + output_logprobs): + seq.append_token_id( + token_id=output_token_id, + logprobs=output_logprob, + ) + + if is_prefill_sampled_token: + is_prefill_sampled_token = False + else: + # Update num_computed_tokens iff the sampled token is not from + # a prefill step. + seq.data.update_num_computed_tokens(1) + + self._process_decode_and_stop(seq, sampling_params) + + if seq.is_finished(): + break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py new file mode 100644 index 0000000..cfa8407 --- /dev/null +++ b/vllm/engine/output_processor/single_step.py @@ -0,0 +1,215 @@ +from typing import Dict, List, Tuple + +from vllm.config import SchedulerConfig +from vllm.core.scheduler import Scheduler +from vllm.engine.output_processor.interfaces import ( + SequenceGroupOutputProcessor) +from vllm.engine.output_processor.stop_checker import StopChecker +from vllm.logger import init_logger +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, + SequenceOutput, SequenceStatus) +from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.utils import Counter + +logger = init_logger(__name__) + + +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: SequenceGroupOutput) -> None: + """Process prompt logprobs associated with the :class:`SequenceGroupOutput` + for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: :class:`SequenceGroupOutputProcessor` instance + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. 
+ if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + +class SingleStepOutputProcessor(SequenceGroupOutputProcessor): + """SequenceGroupOutputProcessor which handles "output processing" logic, + which happens after the model returns generated token ids and before + scheduling of the next batch. Output processing logic includes + detokenization, and determining if a sequence is finished (e.g. via max len + or eos token). + + The SingleStepOutputProcessor is specialized to the case where the model + emits at most a single token per invocation, which precludes configurations + such as speculative decoding or multi-step decoding. This enables beam + search sampling, which requires forking/finishing/freeing sequences in a way + that is currently difficult to schedule multiple steps ahead of time. + """ + + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): + self.scheduler_config = scheduler_config + self.detokenizer = detokenizer + self.scheduler = scheduler + self.seq_counter = seq_counter + self.stop_checker = stop_checker + + def process_outputs(self, sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: + """Append all new tokens to sequences in the sequence group. Fork any + surviving beam candidates; free any unsurviving ones. + + Invokes detokenizer to detokenize new tokens, and also marks sequences + as finished if they meet stop conditions. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + """ + assert (len(outputs) == 1 + ), f"{type(self)} does not support multiple outputs per step" + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) + + def process_prompt_logprob(self, seq_group: SequenceGroup, + outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. 
+ + Args: + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ + assert len(outputs) == 1, ("Single step should only has 1 output.") + output = outputs[0] + single_step_process_prompt_logprob(self, seq_group, output) + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput, + is_async: bool) -> None: + sampling_params = seq_group.sampling_params + if sampling_params.n == 1: + # only have one output sample + sample = outputs.samples[0] + # only have one sequence + seq = seq_group.seqs[0] + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) + return + + # TODO: Add support for async for beam search + assert not is_async + + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + parent_child_dict: Dict[int, List[SequenceOutput]] = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + # Guard against a KeyError which can occur if the request was + # aborted while the output was generated + if (child_list := + parent_child_dict.get(sample.parent_seq_id)) is not None: + child_list.append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + for scheduler in self.scheduler: + scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id: int = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. 
+ for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + for scheduler in self.scheduler: + scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) + return diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py new file mode 100644 index 0000000..0c5f8fb --- /dev/null +++ b/vllm/engine/output_processor/stop_checker.py @@ -0,0 +1,117 @@ +from typing import Callable, Optional + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.sequence import Sequence, SequenceStatus +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class StopChecker: + """LLMEngine helper class which separates out the logic involving stop + checking. This checks things such as: whether the eos token was emitted, + whether the max_tokens has been consumed, whether a stop string has been + emitted, or if we have exceeded the max model len. + """ + + def __init__(self, max_model_len: int, + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): + # Do not use it directly, but use `self._get_max_model_len`. + self._max_model_len = max_model_len + self.get_tokenizer_for_seq = get_tokenizer_for_seq + + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): + if lora_req and lora_req.long_lora_max_len: + return lora_req.long_lora_max_len + else: + return self._max_model_len + + def maybe_stop_sequence( + self, + seq: Sequence, + new_char_count: int, + sampling_params: SamplingParams, + lora_req: Optional[LoRARequest] = None, + ) -> None: + """Stop the finished sequences. + + new_char_count is the number of chars added to the + sequence's output text for the newly generated token + """ + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) + and seq.get_last_token_id() == seq.eos_token_id): + # Remove the last EOS token unless explicitly specified + # This prevents unintended exposure of the EOS token + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if a stop token was encountered. + # This assumes a single token produced per step. + last_token_id = seq.get_last_token_id() + if last_token_id in sampling_params.stop_token_ids: + if new_char_count and ( + not sampling_params.include_stop_str_in_output): + # Remove last token + seq.output_text = seq.output_text[:-new_char_count] + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = last_token_id + return + + # Check if any stop strings are matched. + stop_str = self._check_stop_strings(seq, new_char_count, + sampling_params) + if stop_str is not None: + seq.status = SequenceStatus.FINISHED_STOPPED + seq.stop_reason = stop_str + return + + # Check if the sequence has reached max_model_len. 
+ if seq.get_len() > self._get_max_model_len(lora_req): + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + @staticmethod + def _check_stop_strings(seq: Sequence, new_char_count: int, + sampling_params: SamplingParams) -> Optional[str]: + """Check if any stop strings are matched and truncate sequence + output text accordingly. + + Returns the stop string if matched or else None. + """ + if not new_char_count: + return None + + for stop_str in sampling_params.stop: + stop_string_len = len(stop_str) + # Avoid searching already-searched text. + stop_index = seq.output_text.find( + stop_str, -new_char_count - stop_string_len) + if stop_index == -1: + continue + + if sampling_params.include_stop_str_in_output: + # Truncate to end of stop string. + stop_index += stop_string_len + if stop_index >= len(seq.output_text): + # No truncation required. + return stop_str + + # Truncate the output text to either the beginning + # or end of the stop string. + seq.output_text = seq.output_text[:stop_index] + return stop_str + return None diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py new file mode 100644 index 0000000..7678288 --- /dev/null +++ b/vllm/engine/output_processor/util.py @@ -0,0 +1,22 @@ +from typing import List +from typing import Sequence as GenericSequence +from typing import Union + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import PoolerOutput, SequenceGroupOutput + + +def create_output_by_sequence_group( + outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]], + num_seq_groups: int) -> List[List[SequenceGroupOutput]]: + """Helper method which transforms a 2d list organized by + [step][sequence group] into [sequence group][step]. + """ + output_by_sequence_group: List[List[SequenceGroupOutput]] = [ + [] for _ in range(num_seq_groups) + ] + for step in outputs: + for i, sequence_group_output in enumerate(step): + output_by_sequence_group[i].append(sequence_group_output) + + return output_by_sequence_group diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py new file mode 100644 index 0000000..d7ff743 --- /dev/null +++ b/vllm/engine/protocol.py @@ -0,0 +1,103 @@ +from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, + runtime_checkable) + +from vllm.config import DecodingConfig, ModelConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.inputs.data import PromptType +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +@runtime_checkable +class EngineClient(Protocol): + """Protocol class for Clients to Engine""" + + @property + def is_running(self) -> bool: + ... + + @property + def is_stopped(self) -> bool: + ... + + @property + def errored(self) -> bool: + ... + + @property + def dead_error(self) -> BaseException: + ... 
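# --- Editor's illustrative sketch (not part of the patch) -------------------
# The `create_output_by_sequence_group` helper added above in
# vllm/engine/output_processor/util.py transposes per-step outputs from
# [step][sequence_group] order into [sequence_group][step] order. A minimal,
# stand-alone version of the same reshaping, using plain strings in place of
# SamplerOutput/SequenceGroupOutput objects (the function name below is
# invented for this sketch only):
from typing import List


def transpose_by_sequence_group(outputs: List[List[str]],
                                num_seq_groups: int) -> List[List[str]]:
    # One inner list per sequence group; each collects that group's
    # outputs across scheduler steps.
    by_group: List[List[str]] = [[] for _ in range(num_seq_groups)]
    for step in outputs:
        for i, seq_group_output in enumerate(step):
            by_group[i].append(seq_group_output)
    return by_group


# Two scheduler steps over three sequence groups:
steps = [["g0-s0", "g1-s0", "g2-s0"], ["g0-s1", "g1-s1", "g2-s1"]]
assert transpose_by_sequence_group(steps, 3) == [
    ["g0-s0", "g0-s1"], ["g1-s0", "g1-s1"], ["g2-s0", "g2-s1"]]
# -----------------------------------------------------------------------------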
+
+    def generate(
+        self,
+        prompt: PromptType,
+        sampling_params: SamplingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[RequestOutput, None]:
+        """Generate outputs for a request."""
+        ...
+
+    def encode(
+        self,
+        prompt: PromptType,
+        pooling_params: PoolingParams,
+        request_id: str,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
+    ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+        """Generate outputs for a request from an embedding model."""
+        ...
+
+    async def abort(self, request_id: str) -> None:
+        """Abort a request.
+
+        Args:
+            request_id: The unique id of the request.
+        """
+
+    async def get_model_config(self) -> ModelConfig:
+        """Get the model configuration of the vLLM engine."""
+        ...
+
+    async def get_decoding_config(self) -> DecodingConfig:
+        """Get the decoding configuration of the vLLM engine."""
+        ...
+
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> AnyTokenizer:
+        """Get the appropriate tokenizer for the request"""
+        ...
+
+    async def is_tracing_enabled(self) -> bool:
+        ...
+
+    async def do_log_stats(
+        self,
+        scheduler_outputs: Optional[SchedulerOutputs] = None,
+        model_output: Optional[List[SamplerOutput]] = None,
+    ) -> None:
+        ...
+
+    async def check_health(self) -> None:
+        """Raise if unhealthy"""
+        ...
+
+    async def start_profile(self) -> None:
+        """Start profiling the engine"""
+        ...
+
+    async def stop_profile(self) -> None:
+        """Stop profiling the engine"""
+        ...
diff --git a/vllm/entrypoints/__init__.py b/vllm/entrypoints/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm/entrypoints/__pycache__/__init__.cpython-310.pyc b/vllm/entrypoints/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..8efbea9
Binary files /dev/null and b/vllm/entrypoints/__pycache__/__init__.cpython-310.pyc differ
diff --git a/vllm/entrypoints/__pycache__/api_server.cpython-310.pyc b/vllm/entrypoints/__pycache__/api_server.cpython-310.pyc
new file mode 100644
index 0000000..0704a52
Binary files /dev/null and b/vllm/entrypoints/__pycache__/api_server.cpython-310.pyc differ
diff --git a/vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc b/vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc
new file mode 100644
index 0000000..a1b01ce
Binary files /dev/null and b/vllm/entrypoints/__pycache__/chat_utils.cpython-310.pyc differ
diff --git a/vllm/entrypoints/__pycache__/launcher.cpython-310.pyc b/vllm/entrypoints/__pycache__/launcher.cpython-310.pyc
new file mode 100644
index 0000000..b8dfa82
Binary files /dev/null and b/vllm/entrypoints/__pycache__/launcher.cpython-310.pyc differ
diff --git a/vllm/entrypoints/__pycache__/llm.cpython-310.pyc b/vllm/entrypoints/__pycache__/llm.cpython-310.pyc
new file mode 100644
index 0000000..0f22fd2
Binary files /dev/null and b/vllm/entrypoints/__pycache__/llm.cpython-310.pyc differ
diff --git a/vllm/entrypoints/__pycache__/logger.cpython-310.pyc b/vllm/entrypoints/__pycache__/logger.cpython-310.pyc
new file mode 100644
index 0000000..0df3711
Binary files /dev/null and b/vllm/entrypoints/__pycache__/logger.cpython-310.pyc differ
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
new file mode 100644
index 0000000..f3e80ca
--- /dev/null
+++ b/vllm/entrypoints/api_server.py
@@ -0,0 +1,163
@@ +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" +import asyncio +import json +import ssl +from argparse import Namespace +from typing import Any, AsyncGenerator, Optional + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.launcher import serve_http +from vllm.logger import init_logger +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation, + random_uuid) +from vllm.version import __version__ as VLLM_VERSION + +logger = init_logger("vllm.entrypoints.api_server") + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + results_generator = iterate_with_cancellation( + results_generator, is_cancelled=request.is_disconnected) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + assert prompt is not None + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) + + assert final_output is not None + prompt = final_output.prompt + assert prompt is not None + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +def build_app(args: Namespace) -> FastAPI: + global app + + app.root_path = args.root_path + return app + + +async def init_app( + args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, +) -> FastAPI: + app = build_app(args) + + global engine + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = (llm_engine + if llm_engine is not None else AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.API_SERVER)) + + return app + + +async def run_server(args: Namespace, + llm_engine: Optional[AsyncLLMEngine] = None, + **uvicorn_kwargs: Any) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + 
logger.info("args: %s", args) + + app = await init_app(args, llm_engine) + assert engine is not None + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + await shutdown_task + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", + type=str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + asyncio.run(run_server(args)) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py new file mode 100644 index 0000000..41354dc --- /dev/null +++ b/vllm/entrypoints/chat_utils.py @@ -0,0 +1,581 @@ +import asyncio +import codecs +import json +from abc import ABC, abstractmethod +from collections import defaultdict +from functools import lru_cache, partial +from pathlib import Path +from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, + Mapping, Optional, Tuple, TypeVar, Union, cast) + +# yapf conflicts with isort for this block +# yapf: disable +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam) +from openai.types.chat import ( + ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) +from openai.types.chat import ( + ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) +# yapf: enable +# pydantic needs the TypedDict from typing_extensions +from pydantic import ConfigDict +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing_extensions import Required, TypeAlias, TypedDict + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.utils import (async_get_and_parse_audio, + async_get_and_parse_image, + get_and_parse_audio, get_and_parse_image) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer + +logger = init_logger(__name__) + + +class AudioURL(TypedDict, total=False): + url: Required[str] + """ + Either a URL of the audio or a data URL with base64 encoded audio data. 
+ """ + + +class ChatCompletionContentPartAudioParam(TypedDict, total=False): + audio_url: Required[AudioURL] + + type: Required[Literal["audio_url"]] + """The type of the content part.""" + + +class CustomChatCompletionContentPartParam(TypedDict, total=False): + __pydantic_config__ = ConfigDict(extra="allow") # type: ignore + + type: Required[str] + """The type of the content part.""" + + +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPartParam] + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, + CustomChatCompletionMessageParam] + + +# TODO: Make fields ReadOnly once mypy supports it +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Optional[str] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ModalityStr = Literal["image", "audio", "video"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. 
+ """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + self._consumed_items = {k: 0 for k in self._allowed_items} + + self._items: List[_T] = [] + + @staticmethod + @lru_cache(maxsize=None) + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: + return tokenizer.decode(token_index) + + def _placeholder_str(self, modality: ModalityStr, + current_count: int) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + hf_config = self._model_config.hf_config + model_type = hf_config.model_type + + if modality == "image": + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type == "minicpmv": + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma", + "pixtral"): + # These models do not use image tokens in the prompt + return None + if model_type == "qwen": + return f"Picture {current_count}: " + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.image_token_index) + if model_type in ("chameleon", "internvl_chat", "NVLM_D"): + return "" + if model_type == "mllama": + return "<|image|>" + if model_type == "qwen2_vl": + return "<|vision_start|><|image_pad|><|vision_end|>" + if model_type == "molmo": + return "" + + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "audio": + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "video": + if model_type == "qwen2_vl": + return "<|vision_start|><|video_pad|><|vision_end|>" + raise TypeError(f"Unknown model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + @staticmethod + def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict: + mm_lists: Mapping[str, List[object]] = defaultdict(list) + + # Merge all the multi-modal items + for single_mm_data in items: + for mm_key, mm_item in single_mm_data.items(): + if isinstance(mm_item, list): + mm_lists[mm_key].extend(mm_item) + else: + mm_lists[mm_key].append(mm_item) + + # Unpack any single item lists for models that don't expect multiple. + return { + mm_key: mm_list[0] if len(mm_list) == 1 else mm_list + for mm_key, mm_list in mm_lists.items() + } + + def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + """ + Add a multi-modal item to the current prompt and returns the + placeholder string to use, if any. 
+ """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = self._consumed_items.get(modality, 0) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._consumed_items[modality] = current_count + self._items.append(item) + + return self._placeholder_str(modality, current_count) + + @abstractmethod + def create_parser(self) -> "BaseMultiModalContentParser": + raise NotImplementedError + + +class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): + + def all_mm_data(self) -> Optional[MultiModalDataDict]: + return self._combine(self._items) if self._items else None + + def create_parser(self) -> "BaseMultiModalContentParser": + return MultiModalContentParser(self) + + +class AsyncMultiModalItemTracker( + BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]): + + async def all_mm_data(self) -> Optional[MultiModalDataDict]: + if self._items: + items = await asyncio.gather(*self._items) + return self._combine(items) + + return None + + def create_parser(self) -> "BaseMultiModalContentParser": + return AsyncMultiModalContentParser(self) + + +class BaseMultiModalContentParser(ABC): + + def __init__(self) -> None: + super().__init__() + + # multimodal placeholder_string : count + self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0) + + def _add_placeholder(self, placeholder: Optional[str]): + if placeholder: + self._placeholder_counts[placeholder] += 1 + + def mm_placeholder_counts(self) -> Dict[str, int]: + return dict(self._placeholder_counts) + + @abstractmethod + def parse_image(self, image_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_audio(self, audio_url: str) -> None: + raise NotImplementedError + + +class MultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: MultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + def parse_image(self, image_url: str) -> None: + image = get_and_parse_image(image_url) + + placeholder = self._tracker.add("image", image) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio = get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + + +class AsyncMultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + def parse_image(self, image_url: str) -> None: + image_coro = async_get_and_parse_image(image_url) + + placeholder = self._tracker.add("image", image_coro) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio_coro = async_get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) + + +def validate_chat_template(chat_template: Optional[Union[Path, str]]): + """Raises if the provided chat template appears invalid.""" + if chat_template is None: + return + + elif isinstance(chat_template, Path) and not chat_template.exists(): + raise FileNotFoundError( + "the supplied chat template path doesn't exist") + + elif isinstance(chat_template, str): + JINJA_CHARS = "{}\n" + if not any(c in chat_template + for c in JINJA_CHARS) and not Path(chat_template).exists(): + raise ValueError( + f"The supplied chat template string ({chat_template}) " + f"appears path-like, but doesn't exist!") + + else: + raise 
TypeError( + f"{type(chat_template)} is not a valid chat template type") + + +def load_chat_template( + chat_template: Optional[Union[Path, str]]) -> Optional[str]: + if chat_template is None: + return None + try: + with open(chat_template, "r") as f: + resolved_chat_template = f.read() + except OSError as e: + if isinstance(chat_template, Path): + raise + + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + resolved_chat_template = codecs.decode(chat_template, "unicode_escape") + + logger.info("Using supplied chat template:\n%s", resolved_chat_template) + return resolved_chat_template + + +# TODO: Let user specify how to insert multimodal tokens into prompt +# (similar to chat template) +def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], + text_prompt: str) -> str: + """Combine multimodal prompts for a multimodal language model.""" + + # Look through the text prompt to check for missing placeholders + missing_placeholders: List[str] = [] + for placeholder in placeholder_counts: + + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") + + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. 
+ return "\n".join(missing_placeholders + [text_prompt]) + + +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) +MODEL_KEEP_MULTI_MODAL_CONTENT = {'mllama'} + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + mm_tracker: BaseMultiModalItemTracker, +) -> List[ConversationMessage]: + texts: List[str] = [] + + mm_parser = mm_tracker.create_parser() + keep_multimodal_content = \ + mm_tracker._model_config.hf_config.model_type in \ + MODEL_KEEP_MULTI_MODAL_CONTENT + + has_image = False + for part in parts: + part_type = part["type"] + if part_type == "text": + text = _TextParser(part)["text"] + texts.append(text) + elif part_type == "image_url": + image_url = _ImageParser(part)["image_url"] + + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") + + mm_parser.parse_image(image_url["url"]) + has_image = True + elif part_type == "audio_url": + audio_url = _AudioParser(part)["audio_url"] + + mm_parser.parse_audio(audio_url["url"]) + elif part_type == "refusal": + text = _RefusalParser(part)["refusal"] + texts.append(text) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + text_prompt = "\n".join(texts) + if keep_multimodal_content: + text_prompt = "\n".join(texts) + role_content = [{'type': 'text', 'text': text_prompt}] + + if has_image: + role_content = [{'type': 'image'}] + role_content + return [ConversationMessage(role=role, + content=role_content)] # type: ignore + else: + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt( + mm_placeholder_counts, text_prompt) + return [ConversationMessage(role=role, content=text_prompt)] + + +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) + + +def _parse_chat_message_content( + message: ChatCompletionMessageParam, + mm_tracker: BaseMultiModalItemTracker, +) -> List[ConversationMessage]: + role = message["role"] + content = message.get("content") + + if content is None: + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] + + result = _parse_chat_message_content_parts( + role, + content, # type: ignore + mm_tracker, + ) + + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + + +def _postprocess_messages(messages: List[ConversationMessage]) -> None: + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from 
openAI format) to dict + for message in messages: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for item in message["tool_calls"]: + item["function"]["arguments"] = json.loads( + item["function"]["arguments"]) + + +def parse_chat_messages( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, +) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: + conversation: List[ConversationMessage] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content(msg, mm_tracker) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def parse_chat_messages_futures( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, +) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: + conversation: List[ConversationMessage] = [] + mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content(msg, mm_tracker) + + conversation.extend(sub_messages) + + _postprocess_messages(conversation) + + return conversation, mm_tracker.all_mm_data() + + +def apply_hf_chat_template( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + conversation: List[ConversationMessage], + chat_template: Optional[str], + *, + tokenize: bool = False, # Different from HF's default + **kwargs: Any, +) -> str: + if chat_template is None and tokenizer.chat_template is None: + raise ValueError( + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one.") + + return tokenizer.apply_chat_template( + conversation=conversation, # type: ignore[arg-type] + chat_template=chat_template, + tokenize=tokenize, + **kwargs, + ) + + +def apply_mistral_chat_template( + tokenizer: MistralTokenizer, + messages: List[ChatCompletionMessageParam], + chat_template: Optional[str] = None, + **kwargs: Any, +) -> List[int]: + if chat_template is not None: + logger.warning( + "'chat_template' cannot be overridden for mistral tokenizer.") + if "add_generation_prompt" in kwargs: + logger.warning( + "'add_generation_prompt' is not supported for mistral tokenizer, " + "so it will be ignored.") + if "continue_final_message" in kwargs: + logger.warning( + "'continue_final_message' is not supported for mistral tokenizer, " + "so it will be ignored.") + + return tokenizer.apply_chat_template( + messages=messages, + **kwargs, + ) diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py new file mode 100644 index 0000000..5dcf50b --- /dev/null +++ b/vllm/entrypoints/launcher.py @@ -0,0 +1,103 @@ +import asyncio +import signal +from http import HTTPStatus +from typing import Any + +import uvicorn +from fastapi import FastAPI, Request, Response + +from vllm import envs +from vllm.engine.async_llm_engine import AsyncEngineDeadError +from vllm.engine.multiprocessing import MQEngineDeadError +from vllm.logger import init_logger +from vllm.utils import find_process_using_port + +logger = init_logger(__name__) + + +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is 
None or path is None: + continue + + logger.info("Route: %s, Methods: %s", path, ', '.join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + server = uvicorn.Server(config) + _add_shutdown_handlers(app, server) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + port = uvicorn_kwargs["port"] + process = find_process_using_port(port) + if process is not None: + logger.debug( + "port %s is used by process %s launched with command:\n%s", + port, process, " ".join(process.cmdline())) + logger.info("Shutting down FastAPI HTTP server.") + return server.shutdown() + + +def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: + """Adds handlers for fatal errors that should crash the server""" + + @app.exception_handler(RuntimeError) + async def runtime_error_handler(request: Request, __): + """On generic runtime error, check to see if the engine has died. + It probably has, in which case the server will no longer be able to + handle requests. Trigger a graceful shutdown with a SIGTERM.""" + engine = request.app.state.engine_client + if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored + and not engine.is_running): + logger.fatal("AsyncLLMEngine has failed, terminating server " + "process") + # See discussions here on shutting down a uvicorn server + # https://github.com/encode/uvicorn/discussions/1103 + # In this case we cannot await the server shutdown here because + # this handler must first return to close the connection for + # this request. + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(AsyncEngineDeadError) + async def async_engine_dead_handler(_, __): + """Kill the server if the async engine is already dead. It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("AsyncLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + + @app.exception_handler(MQEngineDeadError) + async def mq_engine_dead_handler(_, __): + """Kill the server if the mq engine is already dead. 
It will + not handle any further requests.""" + if not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: + logger.fatal("MQLLMEngine is already dead, terminating server " + "process") + server.should_exit = True + + return Response(status_code=HTTPStatus.INTERNAL_SERVER_ERROR) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py new file mode 100644 index 0000000..2010381 --- /dev/null +++ b/vllm/entrypoints/llm.py @@ -0,0 +1,909 @@ +import itertools +import warnings +from contextlib import contextmanager +from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, + Union, cast, overload) + +from tqdm import tqdm + +from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, + BeamSearchSequence, get_beam_search_score) +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, + apply_hf_chat_template, + apply_mistral_chat_template, + parse_chat_messages) +from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding.guided_fields import ( + GuidedDecodingRequest, LLMGuidedOptions) +from vllm.outputs import EmbeddingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, + get_cached_tokenizer) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter, deprecate_kwargs, is_list_of + +logger = init_logger(__name__) + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. Expect valid prompt_token_ids and None for prompt + from the input. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq", and "fp8" (experimental). + If None, we first check the `quantization_config` attribute in the + model config file. 
If that is None, we assume the model weights are + not quantized and use `dtype` to determine the data type of + the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + cpu_offload_gb: The size (GiB) of CPU memory to use for offloading + the model weights. This virtually increases the GPU memory space + you can use to hold the model weights, at the cost of CPU-GPU data + transfer for every forward pass. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). + max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. Additionally for encoder-decoder models, if the + sequence length of the encoder input is larger than this, we fall + back to the eager mode. + disable_custom_all_reduce: See ParallelConfig + **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See + :ref:`engine_args`) + + Note: + This class is intended to be used for offline inference. For online + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + """ + + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @classmethod + @contextmanager + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True + + yield + + cls.DEPRECATE_LEGACY = False + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + engine_args = EngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + mm_processor_kwargs=mm_processor_kwargs, + **kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) + self.request_counter = Counter() + + def get_tokenizer(self) -> AnyTokenizer: + return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + + def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: + tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) + + # While CachedTokenizer is dynamic, have no choice but + # compare class name. Misjudgment will arise from + # user-defined tokenizer started with 'Cached' + if tokenizer.__class__.__name__.startswith("Cached"): + tokenizer_group.tokenizer = tokenizer + else: + tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) + + @overload # LEGACY: single (prompt + optional token ids) + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + def generate( + self, + prompts: List[str], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[RequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[RequestOutput]: + ... 
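# --- Editor's illustrative sketch (not part of the patch) -------------------
# Minimal offline use of the LLM class defined above, following the documented
# generate() API rather than the legacy prompt_token_ids overloads. The model
# name is an assumption; any model supported by this vLLM build can be used.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # small model, chosen only for illustration
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    sampling_params=params,
)
for out in outputs:
    # Each RequestOutput carries the original prompt and its completions.
    print(out.prompt, "->", out.outputs[0].text)
# -----------------------------------------------------------------------------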
+ + @overload + def generate( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def generate( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + guided_options_request: Optional[Union[LLMGuidedOptions, + GuidedDecodingRequest]] = None, + priority: Optional[List[int]] = None, + ) -> List[RequestOutput]: + """Generates the completions for the input prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + When it is a single value, it is applied to every prompt. + When it is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + priority: The priority of the requests, if any. + Only applicable when priority scheduling policy is enabled. + + Returns: + A list of ``RequestOutput`` objects containing the + generated completions in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. + """ + if self.llm_engine.model_config.embedding_mode: + raise ValueError( + "LLM.generate() is only supported for (conditional) generation " + "models (XForCausalLM, XForConditionalGeneration).") + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if isinstance(guided_options_request, dict): + if len(guided_options_request) > 1: + raise ValueError( + "You can only use one guided decoding but multiple is " + f"specified: {guided_options_request}") + guided_options_request = GuidedDecodingRequest( + **guided_options_request) + + if sampling_params is None: + # Use default sampling params. 
+ sampling_params = SamplingParams() + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + guided_options=guided_options_request, + priority=priority) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) + + def beam_search( + self, + prompts: List[Union[str, List[int]]], + params: BeamSearchParams, + ) -> List[BeamSearchOutput]: + """ + Generate sequences using beam search. + + Args: + prompts: A list of prompts. Each prompt can be a string or a list + of token IDs. + params: The beam search parameters. + + TODO: how does beam search work together with length penalty, frequency + penalty, and stopping criteria, etc.? + """ + + beam_width = params.beam_width + max_tokens = params.max_tokens + temperature = params.temperature + ignore_eos = params.ignore_eos + length_penalty = params.length_penalty + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, + tokenizer.eos_token_id, + length_penalty) + + tokenizer = self.get_tokenizer() + # generate 2 * beam_width candidates at each step + # following the huggingface transformers implementation + # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa + beam_search_params = SamplingParams(logprobs=2 * beam_width, + max_tokens=1, + temperature=temperature) + instances: List[BeamSearchInstance] = [] + + for prompt in prompts: + prompt_tokens = prompt if isinstance( + prompt, list) else tokenizer.encode(prompt) + instances.append(BeamSearchInstance(prompt_tokens)) + + for _ in range(max_tokens): + all_beams: List[BeamSearchSequence] = list( + sum((instance.beams for instance in instances), [])) + pos = [0] + list( + itertools.accumulate( + len(instance.beams) for instance in instances)) + instance_start_and_end: List[Tuple[int, int]] = list( + zip(pos[:-1], pos[1:])) + + if len(all_beams) == 0: + break + + prompts_batch = [ + TokensPrompt(prompt_token_ids=beam.tokens) + for beam in all_beams + ] + + # only runs for one step + # we don't need to use tqdm here + output = self.generate(prompts_batch, + sampling_params=beam_search_params, + use_tqdm=False) + + for (start, end), instance in zip(instance_start_and_end, + instances): + instance_new_beams = [] + for i in range(start, end): + current_beam = all_beams[i] + result = output[i] + + if result.outputs[0].logprobs is not None: + # if `result.outputs[0].logprobs` is None, it means + # the sequence is completed because of the max-model-len + # or abortion. we don't need to add it to the new beams. 
+ logprobs = result.outputs[0].logprobs[0] + for token_id, logprob_obj in logprobs.items(): + new_beam = BeamSearchSequence( + tokens=current_beam.tokens + [token_id], + cum_logprob=current_beam.cum_logprob + + logprob_obj.logprob) + + if token_id == tokenizer.eos_token_id and \ + not ignore_eos: + instance.completed.append(new_beam) + else: + instance_new_beams.append(new_beam) + sorted_beams = sorted(instance_new_beams, + key=sort_beams_key, + reverse=True) + instance.beams = sorted_beams[:beam_width] + + outputs = [] + for instance in instances: + instance.completed.extend(instance.beams) + sorted_completed = sorted(instance.completed, + key=sort_beams_key, + reverse=True) + best_beams = sorted_completed[:beam_width] + + for beam in best_beams: + beam.text = tokenizer.decode(beam.tokens) + outputs.append(BeamSearchOutput(sequences=best_beams)) + + return outputs + + def chat( + self, + messages: Union[List[ChatCompletionMessageParam], + List[List[ChatCompletionMessageParam]]], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + chat_template: Optional[str] = None, + add_generation_prompt: bool = True, + continue_final_message: bool = False, + tools: Optional[List[Dict[str, Any]]] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ) -> List[RequestOutput]: + """ + Generate responses for a chat conversation. + + The chat conversation is converted into a text prompt using the + tokenizer and calls the :meth:`generate` method to generate the + responses. + + Multi-modal inputs can be passed in the same way you would pass them + to the OpenAI API. + + Args: + messages: A list of conversations or a single conversation. + - Each conversation is represented as a list of messages. + - Each message is a dictionary with 'role' and 'content' keys. + sampling_params: The sampling parameters for text generation. + If None, we use the default sampling parameters. When it + is a single value, it is applied to every prompt. When it + is a list, the list must have the same length as the + prompts and it is paired one by one with the prompt. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + chat_template: The template to use for structuring the chat. + If not provided, the model's default chat template will be used. + add_generation_prompt: If True, adds a generation template + to each message. + continue_final_message: If True, continues the final message in + the conversation instead of starting a new one. Cannot be `True` + if `add_generation_prompt` is also `True`. + mm_processor_kwargs: Multimodal processor kwarg overrides for this + chat request. Only used for offline requests. + + Returns: + A list of ``RequestOutput`` objects containing the generated + responses in the same order as the input messages. + """ + list_of_messages: List[List[ChatCompletionMessageParam]] + + # Handle multi and single conversations + if is_list_of(messages, list): + # messages is List[List[...]] + list_of_messages = cast(List[List[ChatCompletionMessageParam]], + messages) + else: + # messages is List[...] 
+ list_of_messages = [ + cast(List[ChatCompletionMessageParam], messages) + ] + + prompts: List[Union[TokensPrompt, TextPrompt]] = [] + + for msgs in list_of_messages: + tokenizer = self.get_tokenizer() + model_config = self.llm_engine.get_model_config() + + # NOTE: _parse_chat_message_content_parts() currently doesn't + # handle mm_processor_kwargs, since there is no implementation in + # the chat message parsing for it. + conversation, mm_data = parse_chat_messages( + msgs, model_config, tokenizer) + + prompt_data: Union[str, List[int]] + if isinstance(tokenizer, MistralTokenizer): + prompt_data = apply_mistral_chat_template( + tokenizer, + messages=msgs, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + else: + prompt_data = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + tools=tools, + ) + + prompt: Union[TokensPrompt, TextPrompt] + if is_list_of(prompt_data, int): + prompt = TokensPrompt(prompt_token_ids=prompt_data) + else: + prompt = TextPrompt(prompt=prompt_data) + + if mm_data is not None: + prompt["multi_modal_data"] = mm_data + + if mm_processor_kwargs is not None: + prompt["mm_processor_kwargs"] = mm_processor_kwargs + + prompts.append(prompt) + + return self.generate( + prompts, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + @overload # LEGACY: single (prompt + optional token ids) + def encode( + self, + prompts: str, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (prompt + optional token ids) + def encode( + self, + prompts: List[str], + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single (token ids + optional prompt) + def encode( + self, + prompts: Optional[str] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: multi (token ids + optional prompt) + def encode( + self, + prompts: Optional[List[str]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @overload # LEGACY: single or multi token ids [pos-only] + def encode( + self, + prompts: None, + pooling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[EmbeddingRequestOutput]: + ... 
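The chat() method above applies the tokenizer's chat template to a conversation and then delegates to generate(). A minimal usage sketch, assuming a chat-tuned model that ships a chat template (the model name and messages are placeholders):

from vllm import LLM, SamplingParams

# Hypothetical chat model; replace with the model actually being served.
llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize paged attention in one sentence."},
]

# chat() renders the messages with the chat template, then calls generate().
outputs = llm.chat(conversation,
                   sampling_params=SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)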
+ + @overload + def encode( + self, + prompts: Union[PromptType, Sequence[PromptType]], + /, + *, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + ) -> List[EmbeddingRequestOutput]: + ... + + @deprecate_kwargs( + "prompt_token_ids", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'prompts' parameter instead.", + ) + def encode( + self, + prompts: Union[Union[PromptType, Sequence[PromptType]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[PoolingParams, + Sequence[PoolingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates the completions for the input prompts. + + This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: The prompts to the LLM. You may pass a sequence of prompts + for batch inference. See :class:`~vllm.inputs.PromptType` + for more details about the format of each prompts. + pooling_params: The pooling parameters for pooling. If None, we + use the default pooling parameters. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of `EmbeddingRequestOutput` objects containing the + generated embeddings in the same order as the input prompts. + + Note: + Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is + considered legacy and may be deprecated in the future. You should + instead pass them via the ``inputs`` parameter. + """ + if not self.llm_engine.model_config.embedding_mode: + raise ValueError( + "LLM.encode() is only supported for embedding models (XModel)." + ) + + if prompt_token_ids is not None: + parsed_prompts = self._convert_v1_inputs( + prompts=cast(Optional[Union[str, List[str]]], prompts), + prompt_token_ids=prompt_token_ids, + ) + else: + parsed_prompts = cast(Union[PromptType, Sequence[PromptType]], + prompts) + + if pooling_params is None: + # Use default pooling params. 
+ pooling_params = PoolingParams() + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) + + def start_profile(self) -> None: + self.llm_engine.start_profile() + + def stop_profile(self) -> None: + self.llm_engine.stop_profile() + + # LEGACY + def _convert_v1_inputs( + self, + prompts: Optional[Union[str, List[str]]], + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + ): + # skip_tokenizer_init is now checked in engine + + if prompts is not None: + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + + num_requests = None + if prompts is not None: + num_requests = len(prompts) + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + parsed_prompts: List[PromptType] = [] + for i in range(num_requests): + item: PromptType + + if prompts is not None: + item = TextPrompt(prompt=prompts[i]) + elif prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) + else: + raise AssertionError + + parsed_prompts.append(item) + + return parsed_prompts + + def _validate_and_add_requests( + self, + prompts: Union[PromptType, Sequence[PromptType]], + params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, + Sequence[PoolingParams]], + lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], + guided_options: Optional[GuidedDecodingRequest] = None, + priority: Optional[List[int]] = None, + ) -> None: + if guided_options is not None: + warnings.warn( + "guided_options_request is deprecated, use " + "SamplingParams.guided_decoding instead", + DeprecationWarning, + stacklevel=2, + ) + + if isinstance(prompts, (str, dict)): + # Convert a single prompt to a list. + prompts = [prompts] + + num_requests = len(prompts) + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + if isinstance(lora_request, + list) and len(lora_request) != num_requests: + raise ValueError("The lengths of prompts and lora_request " + "must be the same.") + + for sp in params if isinstance(params, list) else (params, ): + if isinstance(sp, SamplingParams): + self._add_guided_params(sp, guided_options) + + # We only care about the final output + sp.output_kind = RequestOutputKind.FINAL_ONLY + + # Add requests to the engine. 
+ for i, prompt in enumerate(prompts): + self._add_request( + prompt, + params[i] if isinstance(params, Sequence) else params, + lora_request=lora_request[i] if isinstance( + lora_request, Sequence) else lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority[i] if priority else 0, + ) + + def _add_request( + self, + prompt: PromptType, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request( + request_id, + prompt, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + priority=priority, + ) + + def _add_guided_params( + self, + params: SamplingParams, + guided_options: Optional[GuidedDecodingRequest] = None): + if guided_options is None: + return params + + if params.guided_decoding is not None: + raise ValueError("Cannot set both guided_options_request and " + "params.guided_decoding.") + + params.guided_decoding = GuidedDecodingParams( + json=guided_options.guided_json, + regex=guided_options.guided_regex, + choice=guided_options.guided_choice, + grammar=guided_options.guided_grammar, + json_object=guided_options.guided_json_object, + backend=guided_options.guided_decoding_backend, + whitespace_pattern=guided_options.guided_whitespace_pattern) + return params + + def _run_engine( + self, *, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + + # Run the engine. + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + total_in_toks = 0 + total_out_toks = 0 + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + if isinstance(output, RequestOutput): + # Calculate tokens only for RequestOutput + assert output.prompt_token_ids is not None + total_in_toks += len(output.prompt_token_ids) + in_spd = total_in_toks / pbar.format_dict["elapsed"] + total_out_toks += sum( + len(stp.token_ids) for stp in output.outputs) + out_spd = (total_out_toks / + pbar.format_dict["elapsed"]) + pbar.postfix = ( + f"est. speed input: {in_spd:.2f} toks/s, " + f"output: {out_spd:.2f} toks/s") + pbar.update(1) + + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # their preceding requests.
+ return sorted(outputs, key=lambda x: int(x.request_id)) + + def _is_encoder_decoder_model(self): + return self.llm_engine.is_encoder_decoder_model() + + def _is_embedding_model(self): + return self.llm_engine.is_embedding_model() diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py new file mode 100644 index 0000000..584ee0d --- /dev/null +++ b/vllm/entrypoints/logger.py @@ -0,0 +1,42 @@ +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams + +logger = init_logger(__name__) + + +class RequestLogger: + + def __init__(self, *, max_log_len: Optional[int]) -> None: + super().__init__() + + self.max_log_len = max_log_len + + def log_inputs( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if prompt is not None: + prompt = prompt[:max_log_len] + + if prompt_token_ids is not None: + prompt_token_ids = prompt_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, lora_request, + prompt_adapter_request) diff --git a/vllm/entrypoints/openai/__init__.py b/vllm/entrypoints/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..84bcaef Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc new file mode 100644 index 0000000..59e7f05 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc new file mode 100644 index 0000000..d67a7cd Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/cli_args.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-310.pyc new file mode 100644 index 0000000..f28c2a5 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/logits_processors.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc new file mode 100644 index 0000000..ddf2de4 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc new file mode 100644 index 0000000..69572f5 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/run_batch.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc 
b/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc new file mode 100644 index 0000000..c9e2f78 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-310.pyc new file mode 100644 index 0000000..1928b66 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/serving_completion.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-310.pyc new file mode 100644 index 0000000..32a12ca Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/serving_embedding.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-310.pyc new file mode 100644 index 0000000..8920d02 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/serving_engine.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-310.pyc b/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-310.pyc new file mode 100644 index 0000000..258edf8 Binary files /dev/null and b/vllm/entrypoints/openai/__pycache__/serving_tokenization.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py new file mode 100644 index 0000000..ae44b26 --- /dev/null +++ b/vllm/entrypoints/openai/api_server.py @@ -0,0 +1,585 @@ +import asyncio +import importlib +import inspect +import multiprocessing +import os +import re +import signal +import socket +import tempfile +from argparse import Namespace +from contextlib import asynccontextmanager +from functools import partial +from http import HTTPStatus +from typing import AsyncIterator, Set + +import uvloop +from fastapi import APIRouter, FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response, StreamingResponse +from starlette.datastructures import State +from starlette.routing import Mount +from typing_extensions import assert_never + +import vllm.envs as envs +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.multiprocessing.engine import run_mp_engine +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + CompletionResponse, + DetokenizeRequest, + DetokenizeResponse, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, + TokenizeRequest, + TokenizeResponse, + UnloadLoraAdapterRequest) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from 
vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.logger import init_logger +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path +from vllm.version import __version__ as VLLM_VERSION + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +prometheus_multiproc_dir: tempfile.TemporaryDirectory + +# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) +logger = init_logger('vllm.entrypoints.openai.api_server') + +_running_tasks: Set[asyncio.Task] = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + try: + if app.state.log_stats: + engine_client: EngineClient = app.state.engine_client + + async def _force_log(): + while True: + await asyncio.sleep(10.) + await engine_client.do_log_stats() + + task = asyncio.create_task(_force_log()) + _running_tasks.add(task) + task.add_done_callback(_running_tasks.remove) + else: + task = None + try: + yield + finally: + if task is not None: + task.cancel() + finally: + # Ensure app state including engine ref is gc'd + del app.state + + +@asynccontextmanager +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[EngineClient]: + + # Context manager to handle engine_client lifecycle + # Ensures everything is shutdown and cleaned up on error/exit + engine_args = AsyncEngineArgs.from_cli_args(args) + + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[EngineClient]: + """ + Create EngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + + # Fall back + # TODO: fill out feature matrix. + if (MQLLMEngineClient.is_unsupported_config(engine_args) + or disable_frontend_multiprocessing): + engine_config = engine_args.create_engine_config() + uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config), + "uses_ray", False) + + build_engine = partial(AsyncLLMEngine.from_engine_args, + engine_args=engine_args, + engine_config=engine_config, + usage_context=UsageContext.OPENAI_API_SERVER) + if uses_ray: + # Must run in main thread with ray for its signal handlers to work + engine_client = build_engine() + else: + engine_client = await asyncio.get_running_loop().run_in_executor( + None, build_engine) + + yield engine_client + return + + # Otherwise, use the multiprocessing AsyncLLMEngine. + else: + if "PROMETHEUS_MULTIPROC_DIR" not in os.environ: + # Make TemporaryDirectory for prometheus multiprocessing + # Note: global TemporaryDirectory will be automatically + # cleaned up upon exit. + global prometheus_multiproc_dir + prometheus_multiproc_dir = tempfile.TemporaryDirectory() + os.environ[ + "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name + else: + logger.warning( + "Found PROMETHEUS_MULTIPROC_DIR was set by user. " + "This directory must be wiped between vLLM runs or " + "you will find inaccurate metrics. Unset the variable " + "and vLLM will properly handle cleanup.") + + # Select random path for IPC. 
+ ipc_path = get_open_zmq_ipc_path() + logger.info("Multiprocessing frontend to use %s for IPC Path.", + ipc_path) + + # Start RPCServer in separate process (holds the LLMEngine). + # the current process might have CUDA context, + # so we need to spawn a new process + context = multiprocessing.get_context("spawn") + + engine_process = context.Process(target=run_mp_engine, + args=(engine_args, + UsageContext.OPENAI_API_SERVER, + ipc_path)) + engine_process.start() + logger.info("Started engine process with PID %d", engine_process.pid) + + # Build RPCClient, which conforms to EngineClient Protocol. + # NOTE: Actually, this is not true yet. We still need to support + # embedding models via RPC (see TODO above) + engine_config = engine_args.create_engine_config() + mp_engine_client = MQLLMEngineClient(ipc_path, engine_config) + + try: + while True: + try: + await mp_engine_client.setup() + break + except TimeoutError: + if not engine_process.is_alive(): + raise RuntimeError( + "Engine process failed to start") from None + + yield mp_engine_client # type: ignore[misc] + finally: + # Ensure rpc server process was terminated + engine_process.terminate() + + # Close all open connections to the backend + mp_engine_client.close() + + # Wait for engine process to join + engine_process.join(4) + if engine_process.exitcode is None: + # Kill if taking longer than 5 seconds to stop + engine_process.kill() + + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. + # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import multiprocess + multiprocess.mark_process_dead(engine_process.pid) + + +router = APIRouter() + + +def mount_metrics(app: FastAPI): + # Lazy import for prometheus multiprocessing. + # We need to set PROMETHEUS_MULTIPROC_DIR environment variable + # before prometheus_client is imported. 
+ # See https://prometheus.github.io/client_python/multiprocess/ + from prometheus_client import (CollectorRegistry, make_asgi_app, + multiprocess) + + prometheus_multiproc_dir_path = os.getenv("PROMETHEUS_MULTIPROC_DIR", None) + if prometheus_multiproc_dir_path is not None: + logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR", + prometheus_multiproc_dir_path) + registry = CollectorRegistry() + multiprocess.MultiProcessCollector(registry) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + else: + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app()) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$") + app.routes.append(metrics_route) + + +def chat(request: Request) -> OpenAIServingChat: + return request.app.state.openai_serving_chat + + +def completion(request: Request) -> OpenAIServingCompletion: + return request.app.state.openai_serving_completion + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def embedding(request: Request) -> OpenAIServingEmbedding: + return request.app.state.openai_serving_embedding + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health") +async def health(raw_request: Request) -> Response: + """Health check.""" + await engine_client(raw_request).check_health() + return Response(status_code=200) + + +@router.post("/tokenize") +async def tokenize(request: TokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_tokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post("/detokenize") +async def detokenize(request: DetokenizeRequest, raw_request: Request): + generator = await tokenization(raw_request).create_detokenize(request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.get("/v1/models") +async def show_available_models(raw_request: Request): + models = await completion(raw_request).show_available_models() + return JSONResponse(content=models.model_dump()) + + +@router.get("/version") +async def show_version(): + ver = {"version": VLLM_VERSION} + return JSONResponse(content=ver) + + +@router.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + + generator = await chat(raw_request).create_chat_completion( + request, raw_request) + + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + + elif isinstance(generator, ChatCompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/completions") +async def create_completion(request: CompletionRequest, raw_request: Request): + generator = await completion(raw_request).create_completion( + request, raw_request) + if 
isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, CompletionResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +@router.post("/v1/embeddings") +async def create_embedding(request: EmbeddingRequest, raw_request: Request): + generator = await embedding(raw_request).create_embedding( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, EmbeddingResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!") + + @router.post("/start_profile") + async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + @router.post("/stop_profile") + async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await completion(raw_request).load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest, + raw_request: Request): + response = await chat(raw_request).unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await completion(raw_request).unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + +def build_app(args: Namespace) -> FastAPI: + if args.disable_fastapi_docs: + app = FastAPI(openapi_url=None, + docs_url=None, + redoc_url=None, + lifespan=lifespan) + else: + app = FastAPI(lifespan=lifespan) + app.include_router(router) + app.root_path = args.root_path + + mount_metrics(app) + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(_, exc): + chat = app.state.openai_serving_chat + err = chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), + status_code=HTTPStatus.BAD_REQUEST) + + if token := envs.VLLM_API_KEY or 
args.api_key: + + @app.middleware("http") + async def authentication(request: Request, call_next): + root_path = "" if args.root_path is None else args.root_path + if request.method == "OPTIONS": + return await call_next(request) + if not request.url.path.startswith(f"{root_path}/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError(f"Invalid middleware {middleware}. " + f"Must be a function or a class.") + + return app + + +def init_app_state( + engine_client: EngineClient, + model_config: ModelConfig, + state: State, + args: Namespace, +) -> None: + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + state.engine_client = engine_client + state.log_stats = not args.disable_log_stats + + state.openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + base_model_paths, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser) + state.openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + base_model_paths, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=args.return_tokens_as_token_ids, + ) + state.openai_serving_embedding = OpenAIServingEmbedding( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger, + ) + state.openai_serving_tokenization = OpenAIServingTokenization( + engine_client, + model_config, + base_model_paths, + lora_modules=args.lora_modules, + request_logger=request_logger, + chat_template=args.chat_template, + ) + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valide_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valide_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valide_tool_parses)} }})") + + # workaround to make sure that we bind the port before the engine is set up. + # This avoids race conditions with ray. 
+ # see https://github.com/vllm-project/vllm/issues/8204 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + fd=sock.fileno(), + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + +if __name__ == "__main__": + # NOTE(simon): + # This section should be in sync with vllm/scripts.py for CLI entrypoints. + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser = make_arg_parser(parser) + args = parser.parse_args() + validate_parsed_serve_args(args) + + uvloop.run(run_server(args)) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py new file mode 100644 index 0000000..a089985 --- /dev/null +++ b/vllm/entrypoints/openai/cli_args.py @@ -0,0 +1,252 @@ +""" +This file contains the command line arguments for the vLLM's +OpenAI-compatible server. It is kept in a separate file for documentation +purposes. +""" + +import argparse +import json +import ssl +from typing import List, Optional, Sequence, Union + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.entrypoints.chat_utils import validate_chat_template +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager +from vllm.utils import FlexibleArgumentParser + + +class LoRAParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: List[LoRAModulePath] = [] + for item in values: + if item in [None, '']: # Skip if item is None or empty string + continue + if '=' in item and ',' not in item: # Old format: name=path + name, path = item.split('=') + lora_list.append(LoRAModulePath(name, path)) + else: # Assume JSON format + try: + lora_dict = json.loads(item) + lora = LoRAModulePath(**lora_dict) + lora_list.append(lora) + except json.JSONDecodeError: + parser.error( + f"Invalid JSON format for --lora-modules: {item}") + except TypeError as e: + parser.error( + f"Invalid fields for --lora-modules: {item} - {str(e)}" + ) + setattr(namespace, self.dest, lora_list) + + +class PromptAdapterParserAction(argparse.Action): + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: List[PromptAdapterPath] = [] + for item in values: + name, path = 
item.split('=') + adapter_list.append(PromptAdapterPath(name, path)) + setattr(namespace, self.dest, adapter_list) + + +def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument("--host", + type=nullable_str, + default=None, + help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument( + "--uvicorn-log-level", + type=str, + default="info", + choices=['debug', 'info', 'warning', 'error', 'critical', 'trace'], + help="log level for uvicorn") + parser.add_argument("--allow-credentials", + action="store_true", + help="allow credentials") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="allowed origins") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="allowed methods") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="allowed headers") + parser.add_argument("--api-key", + type=nullable_str, + default=None, + help="If provided, the server will require this key " + "to be presented in the header.") + parser.add_argument( + "--lora-modules", + type=nullable_str, + default=None, + nargs='+', + action=LoRAParserAction, + help="LoRA module configurations in either 'name=path' format" + "or JSON format. " + "Example (old format): 'name=path' " + "Example (new format): " + "'{\"name\": \"name\", \"local_path\": \"path\", " + "\"base_model_name\": \"id\"}'") + parser.add_argument( + "--prompt-adapters", + type=nullable_str, + default=None, + nargs='+', + action=PromptAdapterParserAction, + help="Prompt adapter configurations in the format name=path. " + "Multiple adapters can be specified.") + parser.add_argument("--chat-template", + type=nullable_str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + parser.add_argument("--ssl-keyfile", + type=nullable_str, + default=None, + help="The file path to the SSL key file") + parser.add_argument("--ssl-certfile", + type=nullable_str, + default=None, + help="The file path to the SSL cert file") + parser.add_argument("--ssl-ca-certs", + type=nullable_str, + default=None, + help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)" + ) + parser.add_argument( + "--root-path", + type=nullable_str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--middleware", + type=nullable_str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server " + "using @app.middleware('http'). " + "If a class is provided, vLLM will add it to the server " + "using app.add_middleware(). 
") + parser.add_argument( + "--return-tokens-as-token-ids", + action="store_true", + help="When --max-logprobs is specified, represents single tokens as " + "strings of the form 'token_id:{token_id}' so that tokens that " + "are not JSON-encodable can be identified.") + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + help="If specified, will run the OpenAI frontend server in the same " + "process as the model serving engine.") + + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") + + valid_tool_parsers = ToolParserManager.tool_parsers.keys() + parser.add_argument( + "--tool-call-parser", + type=str, + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice.") + + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Special the tool parser plugin write to parse the model-generated tool" + " into OpenAI API format, the name register in this plugin can be used " + "in --tool-call-parser.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + + parser.add_argument( + "--disable-fastapi-docs", + action='store_true', + default=False, + help="Disable FastAPI's OpenAPI schema, Swagger UI, and ReDoc endpoint" + ) + + return parser + + +def validate_parsed_serve_args(args: argparse.Namespace): + """Quick checks for model serve args that raise prior to loading.""" + if hasattr(args, "subparser") and args.subparser != "serve": + return + + # Ensure that the chat template is valid; raises if it likely isn't + validate_chat_template(args.chat_template) + + # Enable auto tool needs a tool call parser to be valid + if args.enable_auto_tool_choice and not args.tool_call_parser: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + + +def create_parser_for_docs() -> FlexibleArgumentParser: + parser_for_docs = FlexibleArgumentParser( + prog="-m vllm.entrypoints.openai.api_server") + return make_arg_parser(parser_for_docs) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py new file mode 100644 index 0000000..7913f87 --- /dev/null +++ b/vllm/entrypoints/openai/logits_processors.py @@ -0,0 +1,86 @@ +from functools import lru_cache, partial +from typing import Dict, FrozenSet, Iterable, List, Optional, Union + +import torch + +from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class AllowedTokenIdsLogitsProcessor: + """Logits processor for constraining generated tokens to a + specific set of token ids.""" + + def __init__(self, allowed_ids: Iterable[int]): + self.allowed_ids: Optional[List[int]] = list(allowed_ids) + self.mask: Optional[torch.Tensor] = None + + def __call__(self, token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + if self.mask is None: + self.mask = torch.ones((logits.shape[-1], ), + dtype=torch.bool, + device=logits.device) + self.mask[self.allowed_ids] = False + 
self.allowed_ids = None + logits.masked_fill_(self.mask, float("-inf")) + return logits + + +@lru_cache(maxsize=32) +def _get_allowed_token_ids_logits_processor( + allowed_token_ids: FrozenSet[int], + vocab_size: int, +) -> LogitsProcessor: + if not allowed_token_ids: + raise ValueError("Empty allowed_token_ids provided") + if not all(0 <= tid < vocab_size for tid in allowed_token_ids): + raise ValueError("allowed_token_ids contains " + "out-of-vocab token id") + return AllowedTokenIdsLogitsProcessor(allowed_token_ids) + + +def logit_bias_logits_processor( + logit_bias: Dict[int, float], + token_ids: List[int], + logits: torch.Tensor, +) -> torch.Tensor: + for token_id, bias in logit_bias.items(): + logits[token_id] += bias + return logits + + +def get_logits_processors( + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], + allowed_token_ids: Optional[List[int]], + tokenizer: AnyTokenizer, +) -> List[LogitsProcessor]: + logits_processors: List[LogitsProcessor] = [] + if logit_bias: + try: + # Convert token_id to integer + # Clamp the bias between -100 and 100 per OpenAI API spec + clamped_logit_bias: Dict[int, float] = { + int(token_id): min(100.0, max(-100.0, bias)) + for token_id, bias in logit_bias.items() + } + except ValueError as exc: + raise ValueError( + "Found token_id in logit_bias that is not " + "an integer or string representing an integer") from exc + + # Check if token_id is within the vocab size + for token_id, bias in clamped_logit_bias.items(): + if token_id < 0 or token_id >= tokenizer.vocab_size: + raise ValueError(f"token_id {token_id} in logit_bias contains " + "out-of-vocab token id") + + logits_processors.append( + partial(logit_bias_logits_processor, clamped_logit_bias)) + + if allowed_token_ids is not None: + logits_processors.append( + _get_allowed_token_ids_logits_processor( + frozenset(allowed_token_ids), tokenizer.vocab_size)) + + return logits_processors diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py new file mode 100644 index 0000000..6f1135f --- /dev/null +++ b/vllm/entrypoints/openai/protocol.py @@ -0,0 +1,992 @@ +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from argparse import Namespace +from typing import Any, Dict, List, Literal, Optional, Union + +import torch +from openai.types.chat import ChatCompletionContentPartParam +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Annotated, Required, TypedDict + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, + RequestOutputKind, SamplingParams) +from vllm.sequence import Logprob +from vllm.utils import random_uuid + +# torch is mocked during docs generation, +# so we have to provide the values as literals +_MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) +_LONG_INFO: Union["torch.iinfo", Namespace] + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if isinstance(torch, _MockModule): + _LONG_INFO = _MOCK_LONG_INFO + else: + _LONG_INFO = torch.iinfo(torch.long) +except ModuleNotFoundError: + _LONG_INFO = torch.iinfo(torch.long) + +assert _LONG_INFO.min == _MOCK_LONG_INFO.min +assert _LONG_INFO.max == _MOCK_LONG_INFO.max + + +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the 
Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + + +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + max_model_len: Optional[int] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class RequestResponseMetadata(BaseModel): + request_id: str + final_usage_info: Optional[UsageInfo] = None + + +class JsonSchemaResponseFormat(OpenAIBaseModel): + name: str + description: Optional[str] = None + # schema is the field in openai but that causes conflicts with pydantic so + # instead use json_schema with an alias + json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema') + strict: Optional[bool] = None + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_schema", "json_object" or "text" + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None + + +class StreamOptions(OpenAIBaseModel): + include_usage: Optional[bool] = True + continuous_usage_stats: Optional[bool] = True + + +class FunctionDefinition(OpenAIBaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class ChatCompletionToolsParam(OpenAIBaseModel): + type: Literal["function"] = "function" + function: FunctionDefinition + + +class ChatCompletionNamedFunction(OpenAIBaseModel): + name: str + + +class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): + function: ChatCompletionNamedFunction + type: Literal["function"] = "function" + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = 0 + max_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: 
Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + tools: Optional[List[ChatCompletionToolsParam]] = None + tool_choice: Optional[Union[Literal["none"], Literal["auto"], + ChatCompletionNamedToolChoiceParam]] = "none" + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + prompt_logprobs: Optional[int] = None + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: bool = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: bool = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + continue_final_message: bool = Field( + default=False, + description= + ("If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + "This allows you to \"prefill\" part of the model's response for it. " + "Cannot be used at the same time as `add_generation_prompt`."), + ) + add_special_tokens: bool = Field( + default=False, + description=( + "If true, special tokens (e.g. BOS) will be added to the prompt " + "on top of what is added by the chat template. " + "For most models, the chat template takes care of adding the " + "special tokens so this should be set to false (as is the " + "default)."), + ) + documents: Optional[List[Dict[str, str]]] = Field( + default=None, + description= + ("A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + "\"title\" and \"text\" keys."), + ) + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the tokenizer " + "does not define one."), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=("Additional kwargs to pass to the template renderer. 
" + "Will be accessible by the chat template."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-chat-completion-extra-params + + def to_beam_search_params(self, + default_max_tokens: int) -> BeamSearchParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + n = self.n if self.n is not None else 1 + temperature = self.temperature if self.temperature is not None else 0.0 + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + ) + + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self._get_guided_json_from_tool() or self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=prompt_logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens, + min_tokens=self.min_tokens, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + guided_decoding=guided_decoding, + logit_bias=self.logit_bias) + + def 
_get_guided_json_from_tool( + self) -> Optional[Union[str, dict, BaseModel]]: + # user has chosen to not use any tool + if self.tool_choice == "none" or self.tools is None: + return None + + # user has chosen to use a named tool + if type(self.tool_choice) is ChatCompletionNamedToolChoiceParam: + tool_name = self.tool_choice.function.name + tools = {tool.function.name: tool.function for tool in self.tools} + if tool_name not in tools: + raise ValueError( + f"Tool '{tool_name}' has not been passed in `tools`.") + tool = tools[tool_name] + return tool.parameters + + return None + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + # you can only use one kind of guided decoding + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + # you can only either use guided decoding or tools, not both + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): + raise ValueError( + "You can only either use guided decoding or tools, not both.") + return data + + @model_validator(mode="before") + @classmethod + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and data.get("tools"): + data["tool_choice"] = "auto" + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present + if "tools" not in data or data["tools"] is None: + raise ValueError( + "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): + raise ValueError( + "`tool_choice` must either be a named tool or \"auto\". " + "`tool_choice=\"none\" is not supported.") + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"]["function"] + if not specified_function: + raise ValueError( + "Incorrectly formatted `tool_choice`. 
Should be like " + "`{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function["name"] + if not specified_function_name: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") + return data + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + stream_options: Optional[StreamOptions] = None + suffix: Optional[str] = None + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + user: Optional[str] = None + + # doc: begin-completion-sampling-params + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + allowed_token_ids: Optional[List[int]] = None + prompt_logprobs: Optional[int] = None + # doc: end-completion-sampling-params + + # doc: begin-completion-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description= + ("Similar to chat completion, this parameter specifies the format of " + "output. 
Only {'type': 'json_object'} or {'type': 'text' } is " + "supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description="If specified, the output will follow the JSON schema.", + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-completion-extra-params + + def to_beam_search_params(self, + default_max_tokens: int) -> BeamSearchParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + n = self.n if self.n is not None else 1 + temperature = self.temperature if self.temperature is not None else 0.0 + + return BeamSearchParams( + beam_width=n, + max_tokens=max_tokens, + ignore_eos=self.ignore_eos, + temperature=temperature, + length_penalty=self.length_penalty, + ) + + def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + + echo_without_generation = self.echo and self.max_tokens == 0 + + guided_json_object = None + if (self.response_format is not None + and self.response_format.type == "json_object"): + guided_json_object = True + + guided_decoding = GuidedDecodingParams.from_optional( + json=self.guided_json, + regex=self.guided_regex, + choice=self.guided_choice, + grammar=self.guided_grammar, + json_object=guided_json_object, + backend=self.guided_decoding_backend, + whitespace_pattern=self.guided_whitespace_pattern) + + return SamplingParams.from_optional( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, + ignore_eos=self.ignore_eos, + max_tokens=max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, + prompt_logprobs=prompt_logprobs, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + truncate_prompt_tokens=self.truncate_prompt_tokens, + output_kind=RequestOutputKind.DELTA if self.stream \ + else RequestOutputKind.FINAL_ONLY, + 
guided_decoding=guided_decoding, + logit_bias=self.logit_bias, + allowed_token_ids=self.allowed_token_ids) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + + return data + + @model_validator(mode="before") + @classmethod + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): + raise ValueError( + "Stream options can only be defined when `stream=True`.") + + return data + + +class EmbeddingRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/embeddings + model: str + input: Union[List[int], List[List[int]], str, List[str]] + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = None + user: Optional[str] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-embedding-pooling-params + additional_data: Optional[Any] = None + + # doc: end-embedding-pooling-params + + # doc: begin-embedding-extra-params + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). 
Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + # doc: end-embedding-extra-params + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + +class CompletionLogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: List[Optional[Dict[str, + float]]] = Field(default_factory=list) + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[CompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class EmbeddingResponseData(OpenAIBaseModel): + index: int + object: str = "embedding" + embedding: Union[List[float], str] + + +class EmbeddingResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[EmbeddingResponseData] + usage: UsageInfo + + +class FunctionCall(OpenAIBaseModel): + name: str + arguments: str + + +class ToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + function: FunctionCall + + +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + +class ChatMessage(OpenAIBaseModel): + role: str + content: Optional[str] = None + tool_calls: List[ToolCall] = 
Field(default_factory=list) + + +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[ChatCompletionLogProbs] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion"] = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + prompt_logprobs: Optional[List[Optional[Dict[int, Logprob]]]] = None + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + tool_calls: List[DeltaToolCall] = Field(default_factory=list) + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: Literal["chat.completion.chunk"] = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class BatchRequestInput(OpenAIBaseModel): + """ + The per-line object of the batch input file. + + NOTE: Currently only the `/v1/chat/completions` endpoint is supported. + """ + + # A developer-provided per-request id that will be used to match outputs to + # inputs. Must be unique for each request in a batch. + custom_id: str + + # The HTTP method to be used for the request. Currently only POST is + # supported. + method: str + + # The OpenAI API relative URL to be used for the request. Currently + # /v1/chat/completions is supported. + url: str + + # The parameters of the request. + body: Union[ChatCompletionRequest, EmbeddingRequest] + + +class BatchResponseData(OpenAIBaseModel): + # HTTP status code of the response. + status_code: int = 200 + + # An unique identifier for the API request. + request_id: str + + # The body of the response. + body: Optional[Union[ChatCompletionResponse, EmbeddingResponse]] = None + + +class BatchRequestOutput(OpenAIBaseModel): + """ + The per-line object of the batch output and error files + """ + + id: str + + # A developer-provided per-request id that will be used to match outputs to + # inputs. + custom_id: str + + response: Optional[BatchResponseData] + + # For requests that failed with a non-HTTP error, this will contain more + # information on the cause of the failure. 
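+    # As a rough illustration only (all values hypothetical), one serialized line
+    # of the batch output file produced by run_batch.py looks like:
+    #   {"id": "vllm-<uuid>", "custom_id": "req-1",
+    #    "response": {"status_code": 200, "request_id": "vllm-batch-<uuid>",
+    #                 "body": {...}},
+    #    "error": null}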
+ error: Optional[Any] + + +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + add_special_tokens: bool = Field(default=True) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + + add_generation_prompt: bool = Field(default=True) + continue_final_message: bool = Field(default=False) + add_special_tokens: bool = Field(default=False) + + @model_validator(mode="before") + @classmethod + def check_generation_prompt(cls, data): + if data.get("continue_final_message") and data.get( + "add_generation_prompt"): + raise ValueError("Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True.") + return data + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] + + +class TokenizeResponse(OpenAIBaseModel): + count: int + max_model_len: int + tokens: List[int] + + +class DetokenizeRequest(OpenAIBaseModel): + model: str + tokens: List[int] + + +class DetokenizeResponse(OpenAIBaseModel): + prompt: str + + +class LoadLoraAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoraAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py new file mode 100644 index 0000000..f5249a0 --- /dev/null +++ b/vllm/entrypoints/openai/run_batch.py @@ -0,0 +1,285 @@ +import asyncio +from http import HTTPStatus +from io import StringIO +from typing import Awaitable, Callable, List, Optional + +import aiohttp +import torch +from prometheus_client import start_http_server +from tqdm import tqdm + +from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger, logger +# yapf: disable +from vllm.entrypoints.openai.protocol import (BatchRequestInput, + BatchRequestOutput, + BatchResponseData, + ChatCompletionResponse, + EmbeddingResponse, ErrorResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.usage.usage_lib import UsageContext +from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION + + +def parse_args(): + parser = FlexibleArgumentParser( + description="vLLM OpenAI-Compatible batch runner.") + parser.add_argument( + "-i", + "--input-file", + required=True, + type=str, + help= + "The path or url to a single input file. Currently supports local file " + "paths, or the http protocol (http or https). If a URL is specified, " + "the file should be available via HTTP GET.") + parser.add_argument( + "-o", + "--output-file", + required=True, + type=str, + help="The path or url to a single output file. Currently supports " + "local file paths, or web (http or https) urls. If a URL is specified," + " the file should be available via HTTP PUT.") + parser.add_argument("--response-role", + type=nullable_str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=True`.") + + parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' 
+ '\n\nDefault: Unlimited') + + parser.add_argument("--enable-metrics", + action="store_true", + help="Enable Prometheus metrics") + parser.add_argument( + "--url", + type=str, + default="0.0.0.0", + help="URL to the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="Port number for the Prometheus metrics server " + "(only needed if enable-metrics is set).", + ) + + return parser.parse_args() + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + +async def read_file(path_or_url: str) -> str: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.get(path_or_url) as resp: + return await resp.text() + else: + with open(path_or_url, "r", encoding="utf-8") as f: + return f.read() + + +async def write_file(path_or_url: str, data: str) -> None: + if path_or_url.startswith("http://") or path_or_url.startswith("https://"): + async with aiohttp.ClientSession() as session, \ + session.put(path_or_url, data=data.encode("utf-8")): + pass + else: + # We should make this async, but as long as this is always run as a + # standalone program, blocking the event loop won't affect performance + # in this particular case. 
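+            # A possible non-blocking variant (an untested sketch, assuming
+            # Python >= 3.9) would push the blocking write onto a worker thread:
+            #     def _sync_write() -> None:
+            #         with open(path_or_url, "w", encoding="utf-8") as f:
+            #             f.write(data)
+            #     await asyncio.to_thread(_sync_write)
+            # The synchronous write below is sufficient for this standalone runner.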
+ with open(path_or_url, "w", encoding="utf-8") as f: + f.write(data) + + +def make_error_request_output(request: BatchRequestInput, + error_msg: str) -> BatchRequestOutput: + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=HTTPStatus.BAD_REQUEST, + request_id=f"vllm-batch-{random_uuid()}", + ), + error=error_msg, + ) + return batch_output + + +async def make_async_error_request_output( + request: BatchRequestInput, error_msg: str) -> BatchRequestOutput: + return make_error_request_output(request, error_msg) + + +async def run_request(serving_engine_func: Callable, + request: BatchRequestInput, + tracker: BatchProgressTracker) -> BatchRequestOutput: + response = await serving_engine_func(request.body) + + if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + body=response, request_id=f"vllm-batch-{random_uuid()}"), + error=None, + ) + elif isinstance(response, ErrorResponse): + batch_output = BatchRequestOutput( + id=f"vllm-{random_uuid()}", + custom_id=request.custom_id, + response=BatchResponseData( + status_code=response.code, + request_id=f"vllm-batch-{random_uuid()}"), + error=response, + ) + else: + batch_output = make_error_request_output( + request, error_msg="Request must not be sent in stream mode") + + tracker.completed() + return batch_output + + +async def main(args): + if args.served_model_name is not None: + served_model_names = args.served_model_name + else: + served_model_names = [args.model] + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER) + + model_config = await engine.get_model_config() + base_model_paths = [ + BaseModelPath(name=name, model_path=args.model) + for name in served_model_names + ] + + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + + # Create the openai serving objects. + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + base_model_paths, + args.response_role, + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger, + chat_template=None, + ) + openai_serving_embedding = OpenAIServingEmbedding( + engine, + model_config, + base_model_paths, + request_logger=request_logger, + ) + + tracker = BatchProgressTracker() + logger.info("Reading batch from %s...", args.input_file) + + # Submit all requests in the file to the engine "concurrently". + response_futures: List[Awaitable[BatchRequestOutput]] = [] + for request_json in (await read_file(args.input_file)).strip().split("\n"): + # Skip empty lines. + request_json = request_json.strip() + if not request_json: + continue + + request = BatchRequestInput.model_validate_json(request_json) + + # Determine the type of request and run it. 
+ if request.url == "/v1/chat/completions": + response_futures.append( + run_request(openai_serving_chat.create_chat_completion, + request, tracker)) + tracker.submitted() + elif request.url == "/v1/embeddings": + response_futures.append( + run_request(openai_serving_embedding.create_embedding, request, + tracker)) + tracker.submitted() + else: + response_futures.append( + make_async_error_request_output( + request, + error_msg="Only /v1/chat/completions and " + "/v1/embeddings are supported in the batch endpoint.", + )) + + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) + + output_buffer = StringIO() + for response in responses: + print(response.model_dump_json(), file=output_buffer) + + output_buffer.seek(0) + await write_file(args.output_file, output_buffer.read().strip()) + + +if __name__ == "__main__": + args = parse_args() + + logger.info("vLLM batch processing API version %s", VLLM_VERSION) + logger.info("args: %s", args) + + # Start the Prometheus metrics server. LLMEngine uses the Prometheus client + # to publish metrics at the /metrics endpoint. + if args.enable_metrics: + logger.info("Prometheus metrics enabled") + start_http_server(port=args.port, addr=args.url) + else: + logger.info("Prometheus metrics disabled") + + asyncio.run(main(args)) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py new file mode 100644 index 0000000..4931195 --- /dev/null +++ b/vllm/entrypoints/openai/serving_chat.py @@ -0,0 +1,891 @@ +import asyncio +import json +import time +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, + Optional) +from typing import Sequence as GenericSequence +from typing import Union + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (ConversationMessage, + apply_hf_chat_template, + apply_mistral_chat_template, + load_chat_template, + parse_chat_messages_futures) +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, RequestResponseMetadata, + ToolCall, UsageInfo) +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, + OpenAIServing, + PromptAdapterPath, + TextTokensPrompt) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager +from vllm.inputs import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import iterate_with_cancellation, random_uuid + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__(self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: 
List[BaseModelPath], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None): + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + prompt_adapters=prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + self.response_role = response_role + self.use_tool_use_model_template = False + self.chat_template = load_chat_template(chat_template) + + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled; please note that while" + " the parallel_tool_calls client option is present for " + "compatibility reasons, it will be ignored.") + + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + try: + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: + raise TypeError("Error: --enable-auto-tool-choice requires " + f"tool_parser:'{tool_parser}' which has not " + "been registered") from e + + async def create_chat_completion( + self, + request: ChatCompletionRequest, + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, + ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI + ChatCompletion API. + + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). 
+ if self.engine_client.errored: + raise self.engine_client.dead_error + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + model_config = self.model_config + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + conversation, mm_data_future = parse_chat_messages_futures( + request.messages, model_config, tokenizer) + + tool_dicts = None if request.tools is None else [ + tool.model_dump() for tool in request.tools + ] + + prompt: Union[str, List[int]] + is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer) + if is_mistral_tokenizer: + prompt = apply_mistral_chat_template( + tokenizer, + messages=request.messages, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tools=tool_dicts, + documents=request.documents, + **(request.chat_template_kwargs or {}), + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=request.chat_template or self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + tools=tool_dicts, + documents=request.documents, + **(request.chat_template_kwargs or {}), + ) + except Exception as e: + logger.exception("Error in applying chat template from request") + return self.create_error_response(str(e)) + + try: + mm_data = await mm_data_future + except Exception as e: + logger.exception("Error in loading multi-modal data") + return self.create_error_response(str(e)) + + # validation for OpenAI tools + # tool_choice = "required" is not supported + if request.tool_choice == "required": + return self.create_error_response( + "tool_choice = \"required\" is not supported!") + + if not is_mistral_tokenizer and request.tool_choice == "auto" and not ( + self.enable_auto_tools and self.tool_parser is not None): + # for hf tokenizers, "auto" tools requires + # --enable-auto-tool-choice and --tool-call-parser + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set") + + request_id = f"chat-{random_uuid()}" + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + try: + if self.enable_auto_tools and self.tool_parser: + request = self.tool_parser(tokenizer).adjust_request( + request=request) + + if isinstance(prompt, str): + prompt_inputs = self._tokenize_prompt_input( + request, + tokenizer, + prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + assert isinstance(prompt, list) and isinstance( + prompt[0], int + ), "Prompt has to be either a string or a list of token ids" + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(prompt), prompt_token_ids=prompt) + + assert prompt_inputs is not None + + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + prompt_inputs["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens) + else: + sampling_params = request.to_sampling_params( + default_max_tokens) + + self._log_inputs(request_id, + prompt_inputs, + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + engine_inputs = TokensPrompt( + 
prompt_token_ids=prompt_inputs["prompt_token_ids"]) + if mm_data is not None: + engine_inputs["multi_modal_data"] = mm_data + + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) + trace_headers = None + if is_tracing_enabled and raw_request: + trace_headers = extract_trace_headers(raw_request.headers) + if (not is_tracing_enabled and raw_request + and contains_trace_headers(raw_request.headers)): + log_tracing_disabled_warning() + + if isinstance(sampling_params, BeamSearchParams): + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." + result_generator = self.engine_client.beam_search( + engine_inputs['prompt_token_ids'], + request_id, + sampling_params, + ) + else: + result_generator = self.engine_client.generate( + engine_inputs, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + priority=request.priority, + ) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + if raw_request: + result_generator = iterate_with_cancellation( + result_generator, raw_request.is_disconnected) + + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, result_generator, request_id, conversation, tokenizer, + request_metadata) + + try: + return await self.chat_completion_full_generator( + request, result_generator, request_id, conversation, tokenizer, + request_metadata) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + return request.messages[-1]["role"] + + async def chat_completion_stream_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> AsyncGenerator[str, None]: + model_name = self.base_model_paths[0].name + created_time = int(time.time()) + chunk_object_type: Final = "chat.completion.chunk" + first_iteration = True + + # Send response for each token for each request.n (index) + num_choices = 1 if request.n is None else request.n + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + num_prompt_tokens = 0 + + if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): + tool_choice_function_name = request.tool_choice.function.name + else: + tool_choice_function_name = None + + # Determine whether tools are in use with "auto" tool choice + tool_choice_auto = ( + not tool_choice_function_name + and self._should_stream_with_auto_tool_parsing(request)) + + all_previous_token_ids: Optional[List[List[int]]] + if tool_choice_auto: + # These are only required in "auto" tool choice case + previous_texts = [""] * num_choices + all_previous_token_ids = [[]] * num_choices + else: + previous_texts, all_previous_token_ids = None, None + + # Prepare the tool parser if it's needed + try: + if tool_choice_auto and self.tool_parser: + tool_parsers: List[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) + ] * num_choices + else: + tool_parsers = [None] * num_choices + except RuntimeError as e: + logger.error("Error in tool parser creation: 
%s", e) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + + try: + async for res in result_generator: + if res.prompt_token_ids is not None: + num_prompt_tokens = len(res.prompt_token_ids) + if res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(res.encoder_prompt_token_ids) + + # We need to do it here, because if there are exceptions in + # the result_generator, it needs to be sent as the FIRST + # response (by the try...catch). + if first_iteration: + # Send first response for each request.n (index) with + # the role + role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request + for i in range(num_choices): + tool_parser = tool_parsers[i] + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + role=role, + content="", + ), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # if usage should be included + if (request.stream_options + and request.stream_options.include_usage): + # if continuous usage stats are requested, add it + if request.stream_options.continuous_usage_stats: + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + chunk.usage = usage + # otherwise don't + else: + chunk.usage = None + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the + # last message + if request.echo or request.continue_final_message: + last_msg_content: str = "" + if conversation and "content" in conversation[ + -1] and conversation[-1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + if last_msg_content: + for i in range(num_choices): + choice_data = ( + ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage( + content=last_msg_content), + logprobs=None, + finish_reason=None)) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options and + request.stream_options.include_usage): + if (request.stream_options. 
+ continuous_usage_stats): + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=0, + total_tokens=num_prompt_tokens) + chunk.usage = usage + else: + chunk.usage = None + + data = chunk.model_dump_json( + exclude_unset=True) + yield f"data: {data}\n\n" + first_iteration = False + + for output in res.outputs: + i = output.index + tool_parser = tool_parsers[i] + + if finish_reason_sent[i]: + continue + + if request.logprobs and request.top_logprobs is not None: + assert output.logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_chat_logprobs( + token_ids=output.token_ids, + top_logprobs=output.logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.top_logprobs, + ) + else: + logprobs = None + + delta_text = output.text + delta_message: Optional[DeltaMessage] + + # handle streaming deltas for tools with named tool_choice + if tool_choice_function_name: + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(function=DeltaFunctionCall( + name=tool_choice_function_name, + arguments=delta_text), + index=i) + ]) + + # handle streaming deltas for tools with "auto" tool choice + elif tool_choice_auto: + assert previous_texts is not None + assert all_previous_token_ids is not None + assert tool_parser is not None + #TODO optimize manipulation of these lists + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + previous_token_ids=previous_token_ids, + current_token_ids=current_token_ids, + delta_token_ids=output.token_ids, + request=request)) + + # update the previous values for the next iteration + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids + + # handle streaming just a content delta + else: + delta_message = DeltaMessage(content=delta_text) + + # set the previous values for the next iteration + previous_num_tokens[i] += len(output.token_ids) + + # if the message delta is None (e.g. 
because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + + if output.finish_reason is None: + # Send token-by-token response for each request.n + + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + + # handle usage stats if requested & if continuous + if (request.stream_options + and request.stream_options.include_usage): + if request.stream_options.continuous_usage_stats: + completion_tokens = len(output.token_ids) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens, + ) + chunk.usage = usage + else: + chunk.usage = None + + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # if the model is finished generating + else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + auto_tools_called = False + if tool_parser: + auto_tools_called = len( + tool_parser.prev_tool_call_arr) > 0 + index = len(tool_parser.prev_tool_call_arr + ) - 1 if auto_tools_called else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {})) + + # get what we've streamed so far for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). 
+ model_dump(exclude_none=True)) + ]) + + # Send the finish response for each request.n only once + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=delta_message, + logprobs=logprobs, + finish_reason=output.finish_reason + if not auto_tools_called else "tool_calls", + stop_reason=output.stop_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if (request.stream_options + and request.stream_options.include_usage): + if request.stream_options.continuous_usage_stats: + completion_tokens = len(output.token_ids) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + + completion_tokens, + ) + chunk.usage = usage + else: + chunk.usage = None + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage + if (request.stream_options + and request.stream_options.include_usage): + completion_tokens = previous_num_tokens[i] + final_usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=num_prompt_tokens + completion_tokens, + ) + + final_usage_chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[], + model=model_name, + usage=final_usage) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=True, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + num_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_completion_tokens, + total_tokens=num_prompt_tokens + num_completion_tokens) + + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + logger.error("error in chat completion stream generator: %s", e) + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + conversation: List[ConversationMessage], + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> Union[ErrorResponse, ChatCompletionResponse]: + + model_name = self.base_model_paths[0].name + created_time = int(time.time()) + final_res: Optional[RequestOutput] = None + + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + + assert final_res is not None + + choices: List[ChatCompletionResponseChoice] = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs and request.top_logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_chat_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.top_logprobs, + tokenizer=tokenizer, + ) + else: + logprobs = None + + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool 
+ # call. The same is not true for named function calls + auto_tools_called = False + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if (not self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): + message = ChatMessage(role=role, content=output.text) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( + request.tool_choice) is ChatCompletionNamedToolChoiceParam: + + message = ChatMessage( + role=role, + content="", + tool_calls=[ + ToolCall(function=FunctionCall( + name=request.tool_choice.function.name, + arguments=output.text)) + ]) + + # if the request doesn't use tool choice + # OR specifies to not use a tool + elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, content=output.text) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + try: + tool_parser = self.tool_parser(tokenizer) + except RuntimeError as e: + logger.error("Error in tool parser creation: %s", e) + return self.create_error_response(str(e)) + + tool_call_info = tool_parser.extract_tool_calls( + output.text, request=request) + # In the OpenAI API the finish_reason is "tools_called" + # if the tool choice is auto and the model produced a tool + # call. The same is not true for named function calls + auto_tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage(role=role, content=output.text) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. 
Returning a standard chat " + "completion.") + message = ChatMessage(role=role, content=output.text) + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=message, + logprobs=logprobs, + finish_reason="tool_calls" if auto_tools_called else + output.finish_reason if output.finish_reason else "stop", + stop_reason=output.stop_reason) + choices.append(choice_data) + + if request.echo or request.continue_final_message: + last_msg_content = "" + if conversation and "content" in conversation[-1] and conversation[ + -1].get("role") == role: + last_msg_content = conversation[-1]["content"] or "" + + for choice in choices: + full_message = last_msg_content + (choice.message.content + or "") + choice.message.content = full_message + + assert final_res.prompt_token_ids is not None + num_prompt_tokens = len(final_res.prompt_token_ids) + if final_res.encoder_prompt_token_ids is not None: + num_prompt_tokens += len(final_res.encoder_prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + request_metadata.final_usage_info = usage + + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + prompt_logprobs=final_res.prompt_logprobs, + ) + + return response + + def _get_top_logprobs( + self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], + tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb(token=(token := self._get_decoded_token( + p[1], + p[0], + tokenizer, + return_as_token_id=self.return_tokens_as_token_ids)), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + token.encode("utf-8", errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + tokenizer: AnyTokenizer, + num_output_top_logprobs: Optional[int] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + logprobs_content: List[ChatCompletionLogProbsContent] = [] + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = tokenizer.decode(token_id) + if self.return_tokens_as_token_ids: + token = f"token_id:{token_id}" + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=token, + bytes=list(token.encode("utf-8", errors="replace")), + )) + else: + step_token = step_top_logprobs[token_id] + step_decoded = step_token.decoded_token + + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self._get_decoded_token( + step_token, + token_id, + tokenizer, + self.return_tokens_as_token_ids, + ), + logprob=max(step_token.logprob, -9999.0), + bytes=None if step_decoded is None else list( + step_decoded.encode("utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + ), + )) + + return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. 
+ + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + output.finish_reason is not None + and self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + ) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py new file mode 100644 index 0000000..077312d --- /dev/null +++ b/vllm/entrypoints/openai/serving_completion.py @@ -0,0 +1,554 @@ +import asyncio +import time +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, + Optional) +from typing import Sequence as GenericSequence +from typing import Tuple, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.multiprocessing.client import MQLLMEngineClient +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + ErrorResponse, + RequestResponseMetadata, + UsageInfo) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, + OpenAIServing, + PromptAdapterPath) +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.tracing import (contains_trace_headers, extract_trace_headers, + log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] +TypeTopLogProbs = List[Optional[Dict[int, float]]] +TypeCreateLogProbsFn = Callable[ + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + prompt_adapters=prompt_adapters, + request_logger=request_logger, + return_tokens_as_token_ids=return_tokens_as_token_ids) + + async def 
create_completion( + self, + request: CompletionRequest, + raw_request: Request, + ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following feature: + - suffix (the language models we currently support do not support + suffix) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # If the engine is dead, raise the engine's DEAD_ERROR. + # This is required for the streaming case, where we return a + # success status before we actually start generating text :). + if self.engine_client.errored: + raise self.engine_client.dead_error + + # Return error for unsupported features. + if request.suffix is not None: + return self.create_error_response( + "suffix is not currently supported") + + model_name = self.base_model_paths[0].name + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.time()) + + request_metadata = RequestResponseMetadata(request_id=request_id) + if raw_request: + raw_request.state.request_metadata = request_metadata + + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[RequestOutput, None]] = [] + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + prompts = list( + self._tokenize_prompt_input_or_inputs( + request, + tokenizer, + request.prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + )) + + for i, prompt_inputs in enumerate(prompts): + sampling_params: Union[SamplingParams, BeamSearchParams] + default_max_tokens = self.max_model_len - len( + prompt_inputs["prompt_token_ids"]) + if request.use_beam_search: + sampling_params = request.to_beam_search_params( + default_max_tokens) + else: + sampling_params = request.to_sampling_params( + default_max_tokens) + + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + prompt_inputs, + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + is_tracing_enabled = (await + self.engine_client.is_tracing_enabled()) + trace_headers = None + if is_tracing_enabled: + trace_headers = extract_trace_headers(raw_request.headers) + if not is_tracing_enabled and contains_trace_headers( + raw_request.headers): + log_tracing_disabled_warning() + + if isinstance(sampling_params, BeamSearchParams): + assert isinstance(self.engine_client, + (AsyncLLMEngine, + MQLLMEngineClient)), \ + "Beam search is only supported with" \ + "AsyncLLMEngine and MQLLMEngineClient." 
+ generator = self.engine_client.beam_search( + prompt_inputs["prompt_token_ids"], + request_id_item, + sampling_params, + ) + else: + generator = self.engine_client.generate( + { + "prompt_token_ids": + prompt_inputs["prompt_token_ids"] + }, + sampling_params, + request_id_item, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators( + *generators, is_cancelled=raw_request.is_disconnected) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use + # beam search. + stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + # Streaming response + if stream: + return self.completion_stream_generator( + request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts), + tokenizer=tokenizer, + request_metadata=request_metadata) + + # Non-streaming response + final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) + try: + async for i, res in result_generator: + final_res_batch[i] = res + + for i, final_res in enumerate(final_res_batch): + assert final_res is not None + + # The output should contain the input text + # We did not pass it into vLLM engine to avoid being redundant + # with the inputs token IDs + if final_res.prompt is None: + final_res.prompt = prompts[i]["prompt"] + + final_res_batch_checked = cast(List[RequestOutput], + final_res_batch) + + response = self.request_output_to_completion_response( + final_res_batch_checked, + request, + request_id, + created_time, + model_name, + tokenizer, + request_metadata, + ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. 
+ if request.stream: + response_json = response.model_dump_json() + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response + + async def completion_stream_generator( + self, + request: CompletionRequest, + result_generator: AsyncIterator[Tuple[int, RequestOutput]], + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> AsyncGenerator[str, None]: + num_choices = 1 if request.n is None else request.n + previous_text_lens = [0] * num_choices * num_prompts + previous_num_tokens = [0] * num_choices * num_prompts + has_echoed = [False] * num_choices * num_prompts + num_prompt_tokens = [0] * num_prompts + + try: + async for prompt_idx, res in result_generator: + prompt_token_ids = res.prompt_token_ids + prompt_logprobs = res.prompt_logprobs + prompt_text = res.prompt + + # Prompt details are excluded from later streamed outputs + if res.prompt_token_ids is not None: + num_prompt_tokens[prompt_idx] = len(res.prompt_token_ids) + + delta_token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[ + int, Logprob]]]] + + for output in res.outputs: + i = output.index + prompt_idx * num_choices + # TODO(simon): optimize the performance by avoiding full + # text O(n^2) sending. + + assert request.max_tokens is not None + if request.echo and request.max_tokens == 0: + assert prompt_token_ids is not None + assert prompt_text is not None + # only return the prompt + delta_text = prompt_text + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + has_echoed[i] = True + elif (request.echo and request.max_tokens > 0 + and not has_echoed[i]): + assert prompt_token_ids is not None + assert prompt_text is not None + assert prompt_logprobs is not None + # echo the prompt and first token + delta_text = prompt_text + output.text + delta_token_ids = [ + *prompt_token_ids, *output.token_ids + ] + out_logprobs = [ + *prompt_logprobs, + *(output.logprobs or []), + ] + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text + delta_token_ids = output.token_ids + out_logprobs = output.logprobs + + if request.logprobs is not None: + assert out_logprobs is not None, ( + "Did not output logprobs") + logprobs = self._create_completion_logprobs( + token_ids=delta_token_ids, + top_logprobs=out_logprobs, + num_output_top_logprobs=request.logprobs, + tokenizer=tokenizer, + initial_text_offset=previous_text_lens[i], + ) + else: + logprobs = None + + previous_text_lens[i] += len(output.text) + previous_num_tokens[i] += len(output.token_ids) + finish_reason = output.finish_reason + stop_reason = output.stop_reason + + chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + stop_reason=stop_reason, + ) + ]) + if (request.stream_options + and request.stream_options.include_usage): + if (request.stream_options.continuous_usage_stats + or output.finish_reason is not None): + prompt_tokens = num_prompt_tokens[prompt_idx] + completion_tokens = previous_num_tokens[i] + usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + if request.stream_options.continuous_usage_stats: + chunk.usage = 
usage + else: + chunk.usage = None + + response_json = chunk.model_dump_json(exclude_unset=False) + yield f"data: {response_json}\n\n" + + if (request.stream_options + and request.stream_options.include_usage): + final_usage_chunk = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[], + usage=usage, + ) + final_usage_data = (final_usage_chunk.model_dump_json( + exclude_unset=False, exclude_none=True)) + yield f"data: {final_usage_data}\n\n" + + # report to FastAPI middleware aggregate usage across all choices + total_prompt_tokens = sum(num_prompt_tokens) + total_completion_tokens = sum(previous_num_tokens) + request_metadata.final_usage_info = UsageInfo( + prompt_tokens=total_prompt_tokens, + completion_tokens=total_completion_tokens, + total_tokens=total_prompt_tokens + total_completion_tokens) + + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + + def request_output_to_completion_response( + self, + final_res_batch: List[RequestOutput], + request: CompletionRequest, + request_id: str, + created_time: int, + model_name: str, + tokenizer: AnyTokenizer, + request_metadata: RequestResponseMetadata, + ) -> CompletionResponse: + choices: List[CompletionResponseChoice] = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + + for final_res in final_res_batch: + prompt_token_ids = final_res.prompt_token_ids + assert prompt_token_ids is not None + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + + token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[int, + Logprob]]]] + + for output in final_res.outputs: + assert request.max_tokens is not None + if request.echo and request.max_tokens == 0: + assert prompt_text is not None + token_ids = prompt_token_ids + out_logprobs = prompt_logprobs + output_text = prompt_text + elif request.echo and request.max_tokens > 0: + assert prompt_text is not None + token_ids = [*prompt_token_ids, *output.token_ids] + + if request.logprobs is None: + out_logprobs = None + else: + assert prompt_logprobs is not None + assert output.logprobs is not None + out_logprobs = [ + *prompt_logprobs, + *output.logprobs, + ] + + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + out_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + assert out_logprobs is not None, "Did not output logprobs" + logprobs = self._create_completion_logprobs( + token_ids=token_ids, + top_logprobs=out_logprobs, + tokenizer=tokenizer, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + stop_reason=output.stop_reason, + prompt_logprobs=final_res.prompt_logprobs, + ) + choices.append(choice_data) + + num_generated_tokens += len(output.token_ids) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + request_metadata.final_usage_info = usage + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + def _create_completion_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: 
GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: int, + tokenizer: AnyTokenizer, + initial_text_offset: int = 0, + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: List[int] = [] + out_token_logprobs: List[Optional[float]] = [] + out_tokens: List[str] = [] + out_top_logprobs: List[Optional[Dict[str, float]]] = [] + + last_token_len = 0 + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = tokenizer.decode(token_id) + if self.return_tokens_as_token_ids: + token = f"token_id:{token_id}" + + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) + else: + step_token = step_top_logprobs[token_id] + + token = self._get_decoded_token( + step_token, + token_id, + tokenizer, + return_as_token_id=self.return_tokens_as_token_ids, + ) + token_logprob = max(step_token.logprob, -9999.0) + + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token( + top_lp[1], + top_lp[0], + tokenizer, + return_as_token_id=self.return_tokens_as_token_ids): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) + else: + out_text_offset.append(out_text_offset[-1] + last_token_len) + last_token_len = len(token) + + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py new file mode 100644 index 0000000..e9504cf --- /dev/null +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -0,0 +1,203 @@ +import asyncio +import base64 +import time +from typing import AsyncGenerator, List, Literal, Optional, Union, cast + +import numpy as np +from fastapi import Request +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (EmbeddingRequest, + EmbeddingResponse, + EmbeddingResponseData, + ErrorResponse, UsageInfo) +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.logger import init_logger +from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] + + +def _get_embedding( + output: EmbeddingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[List[float], str]: + if encoding_format == "float": + return output.embedding + elif encoding_format == "base64": + # Force to use float32 for base64 encoding + # to match the OpenAI python client behavior + embedding_bytes = np.array(output.embedding, dtype="float32").tobytes() + return base64.b64encode(embedding_bytes).decode("utf-8") + + assert_never(encoding_format) + + +def request_output_to_embedding_response( + final_res_batch: 
List[EmbeddingRequestOutput], request_id: str, + created_time: int, model_name: str, + encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: + data: List[EmbeddingResponseData] = [] + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + prompt_token_ids = final_res.prompt_token_ids + embedding = _get_embedding(final_res.outputs, encoding_format) + embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) + data.append(embedding_data) + + num_prompt_tokens += len(prompt_token_ids) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return EmbeddingResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +class OpenAIServingEmbedding(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + request_logger: Optional[RequestLogger], + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger) + self._enabled = self._check_embedding_mode(model_config.embedding_mode) + + async def create_embedding( + self, + request: EmbeddingRequest, + raw_request: Optional[Request] = None, + ) -> Union[EmbeddingResponse, ErrorResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/embeddings/create + for the API specification. This API mimics the OpenAI Embedding API. + """ + if not self._enabled: + return self.create_error_response("Embedding API disabled") + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + encoding_format = request.encoding_format + if request.dimensions is not None: + return self.create_error_response( + "dimensions is currently not supported") + + model_name = request.model + request_id = f"embd-{random_uuid()}" + created_time = int(time.monotonic()) + + truncate_prompt_tokens = None + + if request.truncate_prompt_tokens is not None: + if request.truncate_prompt_tokens <= self.max_model_len: + truncate_prompt_tokens = request.truncate_prompt_tokens + else: + return self.create_error_response( + "truncate_prompt_tokens value is " + "greater than max_model_len." + " Please, select a smaller truncation size.") + + # Schedule the request and get the result generator. 
+ generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + pooling_params = request.to_pooling_params() + + prompts = list( + self._tokenize_prompt_input_or_inputs(request, tokenizer, + request.input, + truncate_prompt_tokens)) + + for i, prompt_inputs in enumerate(prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + prompt_inputs, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported " + "for embedding models") + + generator = self.engine_client.encode( + {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, + pooling_params, + request_id_item, + lora_request=lora_request, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators( + *generators, + is_cancelled=raw_request.is_disconnected if raw_request else None, + ) + + # Non-streaming response + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + final_res_batch[i] = res + + for final_res in final_res_batch: + assert final_res is not None + + final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch) + + response = request_output_to_embedding_response( + final_res_batch_checked, request_id, created_time, model_name, + encoding_format) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response + + def _check_embedding_mode(self, embedding_mode: bool) -> bool: + if not embedding_mode: + logger.warning( + "embedding_mode is False. 
Embedding API will not work.") + else: + logger.info("Activating the server engine with embedding enabled.") + return embedding_mode diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py new file mode 100644 index 0000000..e6d2ab9 --- /dev/null +++ b/vllm/entrypoints/openai/serving_engine.py @@ -0,0 +1,487 @@ +import json +import pathlib +from dataclasses import dataclass +from http import HTTPStatus +from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union + +from pydantic import Field +from typing_extensions import Annotated + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + CompletionRequest, + DetokenizeRequest, + EmbeddingRequest, ErrorResponse, + LoadLoraAdapterRequest, + ModelCard, ModelList, + ModelPermission, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeRequest, + UnloadLoraAdapterRequest) +# yapf: enable +from vllm.inputs.parse import parse_and_batch_prompt +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import BeamSearchParams, SamplingParams +from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AtomicCounter + +logger = init_logger(__name__) + + +@dataclass +class BaseModelPath: + name: str + model_path: str + + +@dataclass +class PromptAdapterPath: + name: str + local_path: str + + +@dataclass +class LoRAModulePath: + name: str + path: str + base_model_name: Optional[str] = None + + +AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, + EmbeddingRequest, TokenizeRequest] + + +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: List[int] + + +class OpenAIServing: + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + return_tokens_as_token_ids: bool = False, + ): + super().__init__() + + self.engine_client = engine_client + self.model_config = model_config + self.max_model_len = model_config.max_model_len + + self.base_model_paths = base_model_paths + + self.lora_id_counter = AtomicCounter(0) + self.lora_requests = [] + if lora_modules is not None: + self.lora_requests = [ + LoRARequest(lora_name=lora.name, + lora_int_id=i, + lora_path=lora.path, + base_model_name=lora.base_model_name + if lora.base_model_name + and self._is_model_supported(lora.base_model_name) + else self.base_model_paths[0].name) + for i, lora in enumerate(lora_modules, start=1) + ] + + self.prompt_adapter_requests = [] + if prompt_adapters is not None: + for i, prompt_adapter in enumerate(prompt_adapters, start=1): + with pathlib.Path(prompt_adapter.local_path, + "adapter_config.json").open() as f: + adapter_config = json.load(f) + num_virtual_tokens = adapter_config["num_virtual_tokens"] + self.prompt_adapter_requests.append( + PromptAdapterRequest( + prompt_adapter_name=prompt_adapter.name, + prompt_adapter_id=i, + prompt_adapter_local_path=prompt_adapter.local_path, + 
prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + + self.request_logger = request_logger + self.return_tokens_as_token_ids = return_tokens_as_token_ids + + async def show_available_models(self) -> ModelList: + """Show available models. Right now we only have one model.""" + model_cards = [ + ModelCard(id=base_model.name, + max_model_len=self.max_model_len, + root=base_model.model_path, + permission=[ModelPermission()]) + for base_model in self.base_model_paths + ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=lora.local_path, + parent=lora.base_model_name if lora.base_model_name else + self.base_model_paths[0].name, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + prompt_adapter_cards = [ + ModelCard(id=prompt_adapter.prompt_adapter_name, + root=self.base_model_paths[0].name, + permission=[ModelPermission()]) + for prompt_adapter in self.prompt_adapter_requests + ] + model_cards.extend(lora_cards) + model_cards.extend(prompt_adapter_cards) + return ModelList(data=model_cards) + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) + + def create_streaming_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> str: + json_str = json.dumps({ + "error": + self.create_error_response(message=message, + err_type=err_type, + status_code=status_code).model_dump() + }) + return json_str + + async def _check_model( + self, + request: AnyRequest, + ) -> Optional[ErrorResponse]: + if self._is_model_supported(request.model): + return None + if request.model in [lora.lora_name for lora in self.lora_requests]: + return None + if request.model in [ + prompt_adapter.prompt_adapter_name + for prompt_adapter in self.prompt_adapter_requests + ]: + return None + return self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + def _maybe_get_adapters( + self, request: AnyRequest + ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ + None, PromptAdapterRequest]]: + if self._is_model_supported(request.model): + return None, None + for lora in self.lora_requests: + if request.model == lora.lora_name: + return lora, None + for prompt_adapter in self.prompt_adapter_requests: + if request.model == prompt_adapter.prompt_adapter_name: + return None, prompt_adapter + # if _check_model has been called earlier, this will be unreachable + raise ValueError(f"The model `{request.model}` does not exist.") + + def _normalize_prompt_text_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt: str, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + add_special_tokens: bool, + ) -> TextTokensPrompt: + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) + else: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + input_ids = encoded.input_ids + + input_text = prompt + + return self._validate_input(request, input_ids, input_text) + + def _normalize_prompt_tokens_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_ids: List[int], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + ) -> TextTokensPrompt: + if 
truncate_prompt_tokens is None: + input_ids = prompt_ids + else: + input_ids = prompt_ids[-truncate_prompt_tokens:] + + input_text = tokenizer.decode(input_ids) + + return self._validate_input(request, input_ids, input_text) + + def _validate_input( + self, + request: AnyRequest, + input_ids: List[int], + input_text: str, + ) -> TextTokensPrompt: + token_num = len(input_ids) + + # Note: EmbeddingRequest doesn't have max_tokens + if isinstance(request, EmbeddingRequest): + if token_num > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the input for embedding " + f"generation. Please reduce the length of the input.") + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens + # and does not require model context length validation + if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest)): + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) + + if request.max_tokens is None: + if token_num >= self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{token_num} tokens in the messages, " + f"Please reduce the length of the messages.") + elif token_num + request.max_tokens > self.max_model_len: + raise ValueError( + f"This model's maximum context length is " + f"{self.max_model_len} tokens. However, you requested " + f"{request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). " + f"Please reduce the length of the messages or completion.") + + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + + def _tokenize_prompt_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_input: Union[str, List[int]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> TextTokensPrompt: + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes single input. + """ + return next( + self._tokenize_prompt_inputs( + request, + tokenizer, + [prompt_input], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + )) + + def _tokenize_prompt_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_inputs: Iterable[Union[str, List[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Iterator[TextTokensPrompt]: + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes multiple inputs. 
+ """ + for text in prompt_inputs: + if isinstance(text, str): + yield self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=text, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + yield self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=text, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _tokenize_prompt_input_or_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Iterator[TextTokensPrompt]: + """ + Tokenize/detokenize depending on the input format. + + According to `OpenAI API `_ + , each input can be a string or array of tokens. Note that each request + can pass one or more inputs. + """ + for prompt_input in parse_and_batch_prompt(input_or_inputs): + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + if prompt_input["is_tokens"] is False: + yield self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + yield self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _log_inputs( + self, + request_id: str, + inputs: Union[str, List[int], TextTokensPrompt], + params: Optional[Union[SamplingParams, PoolingParams, + BeamSearchParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + if self.request_logger is None: + return + + if isinstance(inputs, str): + prompt = inputs + prompt_token_ids = None + elif isinstance(inputs, list): + prompt = None + prompt_token_ids = inputs + else: + prompt = inputs["prompt"] + prompt_token_ids = inputs["prompt_token_ids"] + + self.request_logger.log_inputs( + request_id, + prompt, + prompt_token_ids, + params=params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + @staticmethod + def _get_decoded_token(logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + return_as_token_id: bool = False) -> str: + if return_as_token_id: + return f"token_id:{token_id}" + + if logprob.decoded_token is not None: + return logprob.decoded_token + return tokenizer.decode(token_id) + + async def _check_load_lora_adapter_request( + self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return self.create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return self.create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been" + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoraAdapterRequest) -> 
Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return self.create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return self.create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def load_lora_adapter( + self, + request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + self.lora_requests.append( + LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path)) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + return f"Success: LoRA adapter '{lora_name}' removed successfully." + + def _is_model_supported(self, model_name): + return any(model.name == model_name for model in self.base_model_paths) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py new file mode 100644 index 0000000..a269c94 --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -0,0 +1,157 @@ +from typing import List, Optional, Union + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.chat_utils import (apply_hf_chat_template, + apply_mistral_chat_template, + load_chat_template, + parse_chat_messages_futures) +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.entrypoints.openai.protocol import (DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, + TokenizeResponse) +# yapf: enable +from vllm.entrypoints.openai.serving_engine import (BaseModelPath, + LoRAModulePath, + OpenAIServing) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +class OpenAIServingTokenization(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + lora_modules: Optional[List[LoRAModulePath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + ): + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=lora_modules, + prompt_adapters=None, + request_logger=request_logger) + + # If this is None we use the tokenizer's default chat template + # the list of commonly-used chat template names for HF named templates + 
hf_chat_templates: List[str] = ['default', 'tool_use'] + self.chat_template = chat_template \ + if chat_template in hf_chat_templates \ + else load_chat_template(chat_template) + + async def create_tokenize( + self, + request: TokenizeRequest, + ) -> Union[TokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{random_uuid()}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + prompt: Union[str, List[int]] + if isinstance(request, TokenizeChatRequest): + model_config = self.model_config + + conversation, mm_data_future = parse_chat_messages_futures( + request.messages, model_config, tokenizer) + + mm_data = await mm_data_future + if mm_data: + logger.warning( + "Multi-modal inputs are ignored during tokenization") + + if isinstance(tokenizer, MistralTokenizer): + prompt = apply_mistral_chat_template( + tokenizer, + messages=request.messages, + chat_template=self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + ) + else: + prompt = apply_hf_chat_template( + tokenizer, + conversation=conversation, + chat_template=self.chat_template, + add_generation_prompt=request.add_generation_prompt, + continue_final_message=request.continue_final_message, + ) + else: + prompt = request.prompt + + self._log_inputs(request_id, + prompt, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + # Silently ignore prompt adapter since it does not affect tokenization + + prompt_input = self._tokenize_prompt_input( + request, + tokenizer, + prompt, + add_special_tokens=request.add_special_tokens, + ) + input_ids = prompt_input["prompt_token_ids"] + + return TokenizeResponse(tokens=input_ids, + count=len(input_ids), + max_model_len=self.max_model_len) + + async def create_detokenize( + self, + request: DetokenizeRequest, + ) -> Union[DetokenizeResponse, ErrorResponse]: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + request_id = f"tokn-{random_uuid()}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + self._log_inputs(request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for tokenization") + + prompt_input = self._tokenize_prompt_input( + request, + tokenizer, + request.tokens, + ) + input_text = prompt_input["prompt"] + + return DetokenizeResponse(prompt=input_text) diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000..309d9be --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,10 @@ +from .abstract_tool_parser import ToolParser, ToolParserManager +from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser +from .llama_tool_parser import Llama3JsonToolParser +from .mistral_tool_parser import MistralToolParser + +__all__ = [ + "ToolParser", "ToolParserManager", "Hermes2ProToolParser", + "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser" +] 
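
Note for readers of this patch (the paragraph and snippet below are editorial and not part of the diff): the package `__init__.py` above re-exports the tool-parser registry, so downstream code typically resolves a parser class by its registered name through `ToolParserManager` and then runs `extract_tool_calls` over the finished model output. The following is a minimal usage sketch under the assumption that `tokenizer` (an `AnyTokenizer`), `model_output` (the generated text), and `chat_request` (a `ChatCompletionRequest`) are already available from the serving layer; those three names are placeholders, not objects defined in this patch.

    # Illustrative sketch only -- not part of this patch.
    # Assumes `tokenizer`, `model_output`, and `chat_request` already exist.
    from vllm.entrypoints.openai.tool_parsers import ToolParserManager

    # Look up the parser class registered under "hermes"; raises KeyError
    # if no parser was registered with that name.
    parser_cls = ToolParserManager.get_tool_parser("hermes")
    tool_parser = parser_cls(tokenizer)

    # Non-streaming path: parse tool calls out of the complete model output.
    info = tool_parser.extract_tool_calls(model_output, request=chat_request)
    if info.tools_called:
        for call in info.tool_calls:
            # `arguments` is a JSON-encoded string, mirroring the OpenAI API.
            print(call.function.name, call.function.arguments)
    else:
        print(info.content)
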
diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..79baa12 Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-310.pyc new file mode 100644 index 0000000..ae2f3f3 Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/abstract_tool_parser.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-310.pyc new file mode 100644 index 0000000..82faaba Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/hermes_tool_parser.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-310.pyc new file mode 100644 index 0000000..e4db40f Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/internlm2_tool_parser.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-310.pyc new file mode 100644 index 0000000..ce23591 Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/llama_tool_parser.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-310.pyc new file mode 100644 index 0000000..b3cdc69 Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/mistral_tool_parser.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-310.pyc b/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..6d33704 Binary files /dev/null and b/vllm/entrypoints/openai/tool_parsers/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000..5ce31bd --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,161 @@ +import importlib +import importlib.util +import os +from functools import cached_property +from typing import Callable, Dict, List, Optional, Sequence, Type, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import is_list_of + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. 
+ """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [] + + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> Dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Static method that used to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") + + +class ToolParserManager: + tool_parsers: Dict[str, Type] = {} + + @classmethod + def get_tool_parser(cls, name) -> Type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). 
+ """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user defined tool parser by the path of the tool parser define + file. + """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + spec = importlib.util.spec_from_file_location(module_name, plugin_path) + if spec is None or spec.loader is None: + logger.error("load %s from %s failed.", module_name, plugin_path) + return + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000..bcbcda3 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,338 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module("hermes") +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id = self.vocab.get( + self.tool_call_start_token) + self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token) + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: 
ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", + e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): + logger.debug("Generating text content! skipping tool parsing.") + if delta_text != self.tool_call_end_token: + return DeltaMessage(content=delta_text) + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. 
exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], "") + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. 
partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + args_delta_start_loc = cur_arguments_json.index(delta_text) \ + + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between\n%s", cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += argument_diff + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + return None # do not stream a delta. skip this token ID. 
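For readers skimming the Hermes parser above: its non-streaming path boils down to a regex scan for <tool_call> blocks followed by JSON decoding. The sketch below is illustrative only (the model output string is made up and nothing here ships in the patch); it reproduces that extraction step outside of vLLM:

import json
import re

# the same pattern the parser compiles: a closed <tool_call> block, or an unterminated one
TOOL_CALL_REGEX = re.compile(r"<tool_call>(.*?)</tool_call>|<tool_call>(.*)", re.DOTALL)

model_output = ('Let me look that up. '
                '<tool_call>{"name": "get_weather", "arguments": {"city": "Berlin"}}</tool_call>')

# findall yields (closed, unterminated) tuples; exactly one element of each tuple is non-empty
raw_calls = [json.loads(closed or open_ended)
             for closed, open_ended in TOOL_CALL_REGEX.findall(model_output)]

# the parser then wraps each entry in a ToolCall, re-serializing the arguments dict to a string
for call in raw_calls:
    print(call["name"], json.dumps(call["arguments"]))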
diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 0000000..905ab7d --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,208 @@ +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def get_argments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + # if the tool call is sended, return a empty delta message + # to make sure the finish_reason will be send correctly. + if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() + action = action.split('<|action_end|>'.strip())[0] + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. 
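+        # (illustrative note, not part of the patch: with partial_json_parser,
+        # Allow.ALL & ~Allow.STR refuses to complete an unterminated string, so a
+        # half-generated '{"name": "get_we' should parse to an empty object and the
+        # name is only streamed once its closing quote has been generated; after
+        # that, Allow.ALL lets partially generated argument values through.)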
+ flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in inernlm2 + # it's not support parallel tool calls + try: + tool_call_arr: Dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.get_argments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_argments(tool_call_arr) + + # not arguments generated + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time to get parameters + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters, send the increase parameters + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # check to see if the name is defined and has been sent. 
if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + tool_call_arr["arguments"] = self.get_argments(tool_call_arr) + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {}))) + + if not tools or name not in [t.function.name for t in tools]: + ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py new file mode 100644 index 0000000..3cf34bc --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -0,0 +1,277 @@ +import json +import re +from json import JSONDecodeError, JSONDecoder +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow +from transformers import PreTrainedTokenizerBase + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix +from vllm.logger import init_logger +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +# partial_json_parser doesn't support extra data and +# JSONDecorder.raw_decode doesn't support partial JSON +def partial_json_loads(input_str, flags): + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + else: + raise + + +def is_complete_json(input_str): + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +@ToolParserManager.register_module("llama3_json") +class Llama3JsonToolParser(ToolParser): + """ + Tool call parser for Llama 3.1 models intended for use with the + examples/tool_chat_template_llama.jinja template. 
+ + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: PreTrainedTokenizerBase): + super().__init__(tokenizer) + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "<|python_tag|>" + self.bot_token_id = tokenizer.encode(self.bot_token, + add_special_tokens=False)[0] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. + """ + # case -- if a tool call token is not present, return a text response + if not (model_output.startswith(self.bot_token) + or model_output.startswith('{')): + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + try: + # load the JSON, and then use it to build the Function and + # Tool Call + dec = JSONDecoder() + function_call_arr = [] + + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if model_output.startswith( + self.bot_token) else 0 + while start_idx < len(model_output): + (obj, end_idx) = dec.raw_decode(model_output[start_idx:]) + start_idx += end_idx + len('; ') + function_call_arr.append(obj) + + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"] \ + if "arguments" in raw_function_call \ + else raw_function_call["parameters"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + ret = ExtractedToolCallInformation(tools_called=True, + tool_calls=tool_calls, + content=None) + return ret + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", e) + print("ERROR", e) + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + if not (current_text.startswith(self.bot_token) + or current_text.startswith('{')): + return DeltaMessage(content=delta_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. 
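+        # (illustrative note, not part of the patch: the loop below assumes the model
+        # may emit several JSON objects separated by '; ', optionally prefixed with
+        # <|python_tag|>; partial_json_loads decodes one object at a time,
+        # is_complete_json records whether it is already well-formed, and the cursor
+        # advances past each separator.)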
+ flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = len(self.bot_token) if current_text.startswith( + self.bot_token) else 0 + while start_idx < len(current_text): + (obj, + end_idx) = partial_json_loads(current_text[start_idx:], + flags) + is_complete.append( + is_complete_json(current_text[start_idx:start_idx + + end_idx])) + start_idx += end_idx + len('; ') + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert "arguments" not in obj, \ + "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + delta = None + + if cur_arguments: + sent = len( + self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = find_common_prefix( + prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000..c6dc068 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,306 @@ +import json +import re +from random import choices +from string import ascii_letters, digits +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow +from pydantic import Field + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + +ALPHANUMERIC = ascii_letters + digits + + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. 
+ # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + + +@ToolParserManager.register_module("mistral") +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if not isinstance(self.model_tokenizer, MistralTokenizer): + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.vocab.get(self.bot_token) + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + if not self.bot_token_id: + raise RuntimeError( + "Mistral Tool Parser could not locate the tool call token in " + "the tokenizer!") + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + + # use a regex to find the tool call. 
remove the BOT token + # and make sure to replace single quotes with double quotes + raw_tool_call = self.tool_call_regex.findall( + model_output.replace(self.bot_token, ""))[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[MistralToolCall] = [ + MistralToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", e) + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[-1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. 
+ # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000..db7fc52 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,87 @@ +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. 
extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string, substring): + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices diff --git a/vllm/envs.py b/vllm/envs.py new file mode 100644 index 0000000..4c9b4ae --- /dev/null +++ b/vllm/envs.py @@ -0,0 +1,446 @@ +import os +import tempfile +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional + +if TYPE_CHECKING: + VLLM_HOST_IP: str = "" + VLLM_PORT: Optional[int] = None + VLLM_RPC_BASE_PATH: str = tempfile.gettempdir() + VLLM_USE_MODELSCOPE: bool = False + VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 + VLLM_INSTANCE_ID: Optional[str] = None + VLLM_NCCL_SO_PATH: Optional[str] = None + LD_LIBRARY_PATH: Optional[str] = None + VLLM_USE_TRITON_FLASH_ATTN: bool = False + LOCAL_RANK: int = 0 + CUDA_VISIBLE_DEVICES: Optional[str] = None + VLLM_ENGINE_ITERATION_TIMEOUT_S: int = 60 + VLLM_API_KEY: Optional[str] = None + S3_ACCESS_KEY_ID: Optional[str] = None + S3_SECRET_ACCESS_KEY: Optional[str] = None + S3_ENDPOINT_URL: Optional[str] = None + VLLM_CACHE_ROOT: str = os.path.expanduser("~/.cache/vllm") + VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm") + VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai" + VLLM_NO_USAGE_STATS: bool = False + VLLM_DO_NOT_TRACK: bool = False + VLLM_USAGE_SOURCE: str = "" + VLLM_CONFIGURE_LOGGING: int = 1 + VLLM_LOGGING_LEVEL: str = "INFO" + VLLM_LOGGING_CONFIG_PATH: Optional[str] = None + VLLM_TRACE_FUNCTION: int = 0 + VLLM_ATTENTION_BACKEND: Optional[str] = None + VLLM_USE_FLASHINFER_SAMPLER: bool = False + VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False + VLLM_PP_LAYER_PARTITION: Optional[str] = None + VLLM_CPU_KVCACHE_SPACE: int = 0 + VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_OPENVINO_DEVICE: str = "CPU" + VLLM_OPENVINO_KVCACHE_SPACE: int = 0 + VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None + VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False + VLLM_XLA_CACHE_PATH: str = os.path.join(VLLM_CACHE_ROOT, "xla_cache") + VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024 + VLLM_USE_RAY_SPMD_WORKER: bool = False + VLLM_USE_RAY_COMPILED_DAG: bool = False + VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True + VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" + VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") + VLLM_IMAGE_FETCH_TIMEOUT: int = 5 + VLLM_AUDIO_FETCH_TIMEOUT: int = 5 + VLLM_TARGET_DEVICE: str = "cuda" + MAX_JOBS: Optional[str] = None + NVCC_THREADS: Optional[str] = None + VLLM_USE_PRECOMPILED: bool = False + VLLM_NO_DEPRECATION_WARNING: bool = False + VLLM_KEEP_ALIVE_ON_ENGINE_DEATH: bool = False + CMAKE_BUILD_TYPE: Optional[str] = None + VERBOSE: bool = False + VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False + VLLM_TEST_FORCE_FP8_MARLIN: bool = False + VLLM_RPC_TIMEOUT: int = 10000 # ms + VLLM_PLUGINS: Optional[List[str]] = None + VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_USE_TRITON_AWQ: bool = False + VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = 
False + VLLM_SKIP_P2P_CHECK: bool = False + VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False + VLLM_TORCH_COMPILE_LEVEL: int = 0 + + +def get_default_cache_root(): + return os.getenv( + "XDG_CACHE_HOME", + os.path.join(os.path.expanduser("~"), ".cache"), + ) + + +def get_default_config_root(): + return os.getenv( + "XDG_CONFIG_HOME", + os.path.join(os.path.expanduser("~"), ".config"), + ) + + +# The begin-* and end* here are used by the documentation generator +# to extract the used env vars. + +# begin-env-vars-definition + +environment_variables: Dict[str, Callable[[], Any]] = { + + # ================== Installation Time Env Vars ================== + + # Target device of vLLM, supporting [cuda (by default), + # rocm, neuron, cpu, openvino] + "VLLM_TARGET_DEVICE": + lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda"), + + # Maximum number of compilation jobs to run in parallel. + # By default this is the number of CPUs + "MAX_JOBS": + lambda: os.getenv("MAX_JOBS", None), + + # Number of threads to use for nvcc + # By default this is 1. + # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. + "NVCC_THREADS": + lambda: os.getenv("NVCC_THREADS", None), + + # If set, vllm will use precompiled binaries (*.so) + "VLLM_USE_PRECOMPILED": + lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")), + + # CMake build type + # If not set, defaults to "Debug" or "RelWithDebInfo" + # Available options: "Debug", "Release", "RelWithDebInfo" + "CMAKE_BUILD_TYPE": + lambda: os.getenv("CMAKE_BUILD_TYPE"), + + # If set, vllm will print verbose logs during installation + "VERBOSE": + lambda: bool(int(os.getenv('VERBOSE', '0'))), + + # Root directory for VLLM configuration files + # Defaults to `~/.config/vllm` unless `XDG_CONFIG_HOME` is set + # Note that this not only affects how vllm finds its configuration files + # during runtime, but also affects how vllm installs its configuration + # files during **installation**. + "VLLM_CONFIG_ROOT": + lambda: os.path.expanduser( + os.getenv( + "VLLM_CONFIG_ROOT", + os.path.join(get_default_config_root(), "vllm"), + )), + + # ================== Runtime Env Vars ================== + + # Root directory for VLLM cache files + # Defaults to `~/.cache/vllm` unless `XDG_CACHE_HOME` is set + "VLLM_CACHE_ROOT": + lambda: os.path.expanduser( + os.getenv( + "VLLM_CACHE_ROOT", + os.path.join(get_default_cache_root(), "vllm"), + )), + + # used in distributed environment to determine the ip address + # of the current node, when the node has multiple network interfaces. + # If you are using multi-node inference, you should set this differently + # on each node. + 'VLLM_HOST_IP': + lambda: os.getenv('VLLM_HOST_IP', "") or os.getenv("HOST_IP", ""), + + # used in distributed environment to manually set the communication port + # Note: if VLLM_PORT is set, and some code asks for multiple ports, the + # VLLM_PORT will be used as the first port, and the rest will be generated + # by incrementing the VLLM_PORT value. + # '0' is used to make mypy happy + 'VLLM_PORT': + lambda: int(os.getenv('VLLM_PORT', '0')) + if 'VLLM_PORT' in os.environ else None, + + # path used for ipc when the frontend api server is running in + # multi-processing mode to communicate with the backend engine process. + 'VLLM_RPC_BASE_PATH': + lambda: os.getenv('VLLM_RPC_BASE_PATH', tempfile.gettempdir()), + + # If true, will load models from ModelScope instead of Hugging Face Hub. 
+ # note that the value is true or false, not numbers + "VLLM_USE_MODELSCOPE": + lambda: os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true", + + # Instance id represents an instance of the VLLM. All processes in the same + # instance should have the same instance id. + "VLLM_INSTANCE_ID": + lambda: os.environ.get("VLLM_INSTANCE_ID", None), + + # Interval in seconds to log a warning message when the ring buffer is full + "VLLM_RINGBUFFER_WARNING_INTERVAL": + lambda: int(os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")), + + # path to cudatoolkit home directory, under which should be bin, include, + # and lib directories. + "CUDA_HOME": + lambda: os.environ.get("CUDA_HOME", None), + + # Path to the NCCL library file. It is needed because nccl>=2.19 brought + # by PyTorch contains a bug: https://github.com/NVIDIA/nccl/issues/1234 + "VLLM_NCCL_SO_PATH": + lambda: os.environ.get("VLLM_NCCL_SO_PATH", None), + + # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl + # library file in the locations specified by `LD_LIBRARY_PATH` + "LD_LIBRARY_PATH": + lambda: os.environ.get("LD_LIBRARY_PATH", None), + + # flag to control if vllm should use triton flash attention + "VLLM_USE_TRITON_FLASH_ATTN": + lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in + ("true", "1")), + + # Internal flag to enable Dynamo fullgraph capture + "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE": + lambda: bool( + os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + "VLLM_TORCH_COMPILE_LEVEL": + lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")), + + # local rank of the process in the distributed setting, used to determine + # the GPU device id + "LOCAL_RANK": + lambda: int(os.environ.get("LOCAL_RANK", "0")), + + # used to control the visible devices in the distributed setting + "CUDA_VISIBLE_DEVICES": + lambda: os.environ.get("CUDA_VISIBLE_DEVICES", None), + + # timeout for each iteration in the engine + "VLLM_ENGINE_ITERATION_TIMEOUT_S": + lambda: int(os.environ.get("VLLM_ENGINE_ITERATION_TIMEOUT_S", "60")), + + # API key for VLLM API server + "VLLM_API_KEY": + lambda: os.environ.get("VLLM_API_KEY", None), + + # S3 access information, used for tensorizer to load model from S3 + "S3_ACCESS_KEY_ID": + lambda: os.environ.get("S3_ACCESS_KEY_ID", None), + "S3_SECRET_ACCESS_KEY": + lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), + "S3_ENDPOINT_URL": + lambda: os.environ.get("S3_ENDPOINT_URL", None), + + # Usage stats collection + "VLLM_USAGE_STATS_SERVER": + lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"), + "VLLM_NO_USAGE_STATS": + lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", + "VLLM_DO_NOT_TRACK": + lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get( + "DO_NOT_TRACK", None) or "0") == "1", + "VLLM_USAGE_SOURCE": + lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), + + # Logging configuration + # If set to 0, vllm will not configure logging + # If set to 1, vllm will configure logging using the default configuration + # or the configuration file specified by VLLM_LOGGING_CONFIG_PATH + "VLLM_CONFIGURE_LOGGING": + lambda: int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")), + "VLLM_LOGGING_CONFIG_PATH": + lambda: os.getenv("VLLM_LOGGING_CONFIG_PATH"), + + # this is used for configuring the default logging level + "VLLM_LOGGING_LEVEL": + lambda: os.getenv("VLLM_LOGGING_LEVEL", "INFO"), + + # Trace function calls + # If set to 1, vllm will trace function calls + # Useful for debugging + 
"VLLM_TRACE_FUNCTION": + lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), + + # Backend for attention computation + # Available options: + # - "TORCH_SDPA": use torch.nn.MultiheadAttention + # - "FLASH_ATTN": use FlashAttention + # - "XFORMERS": use XFormers + # - "ROCM_FLASH": use ROCmFlashAttention + # - "FLASHINFER": use flashinfer + "VLLM_ATTENTION_BACKEND": + lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), + + # If set, vllm will use flashinfer sampler + "VLLM_USE_FLASHINFER_SAMPLER": + lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))), + + # Pipeline stage partition strategy + "VLLM_PP_LAYER_PARTITION": + lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), + + # (CPU backend only) CPU key-value cache space. + # default is 4GB + "VLLM_CPU_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")), + + # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", + # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. + "VLLM_CPU_OMP_THREADS_BIND": + lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), + + # OpenVINO device selection + # default is CPU + "VLLM_OPENVINO_DEVICE": + lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(), + + # OpenVINO key-value cache space + # default is 4GB + "VLLM_OPENVINO_KVCACHE_SPACE": + lambda: int(os.getenv("VLLM_OPENVINO_KVCACHE_SPACE", "0")), + + # OpenVINO KV cache precision + # default is bf16 if natively supported by platform, otherwise f16 + # To enable KV cache compression, please, explicitly specify u8 + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION": + lambda: os.getenv("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION", None), + + # Enables weights compression during model export via HF Optimum + # default is False + "VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS": + lambda: bool(os.getenv("VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS", False)), + + # If the env var is set, then all workers will execute as separate + # processes from the engine, and we use the same mechanism to trigger + # execution on all workers. + # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it. + "VLLM_USE_RAY_SPMD_WORKER": + lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))), + + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. + "VLLM_USE_RAY_COMPILED_DAG": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))), + + # If the env var is set, it uses NCCL for communication in + # Ray's compiled DAG. This flag is ignored if + # VLLM_USE_RAY_COMPILED_DAG is not set. + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": + lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1")) + ), + + # Use dedicated multiprocess context for workers. + # Both spawn and fork work + "VLLM_WORKER_MULTIPROC_METHOD": + lambda: os.getenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn"), + + # Path to the cache for storing downloaded assets + "VLLM_ASSETS_CACHE": + lambda: os.path.expanduser( + os.getenv( + "VLLM_ASSETS_CACHE", + os.path.join(get_default_cache_root(), "vllm", "assets"), + )), + + # Timeout for fetching images when serving multimodal models + # Default is 5 seconds + "VLLM_IMAGE_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), + + # Timeout for fetching audio when serving multimodal models + # Default is 5 seconds + "VLLM_AUDIO_FETCH_TIMEOUT": + lambda: int(os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "5")), + + # Path to the XLA persistent cache directory. 
+ # Only used for XLA devices such as TPUs. + "VLLM_XLA_CACHE_PATH": + lambda: os.path.expanduser( + os.getenv( + "VLLM_XLA_CACHE_PATH", + os.path.join(get_default_cache_root(), "vllm", "xla_cache"), + )), + "VLLM_FUSED_MOE_CHUNK_SIZE": + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + + # If set, vllm will skip the deprecation warnings. + "VLLM_NO_DEPRECATION_WARNING": + lambda: bool(int(os.getenv("VLLM_NO_DEPRECATION_WARNING", "0"))), + + # If set, the OpenAI API server will stay alive even after the underlying + # AsyncLLMEngine errors and stops serving requests + "VLLM_KEEP_ALIVE_ON_ENGINE_DEATH": + lambda: bool(os.getenv("VLLM_KEEP_ALIVE_ON_ENGINE_DEATH", 0)), + + # If the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN is set, it allows + # the user to specify a max sequence length greater than + # the max length derived from the model's config.json. + # To enable this, set VLLM_ALLOW_LONG_MAX_MODEL_LEN=1. + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": + lambda: + (os.environ.get("VLLM_ALLOW_LONG_MAX_MODEL_LEN", "0").strip().lower() in + ("1", "true")), + + # If set, forces FP8 Marlin to be used for FP8 quantization regardless + # of the hardware support for FP8 compute. + "VLLM_TEST_FORCE_FP8_MARLIN": + lambda: + (os.environ.get("VLLM_TEST_FORCE_FP8_MARLIN", "0").strip().lower() in + ("1", "true")), + "VLLM_TEST_FORCE_LOAD_FORMAT": + lambda: os.getenv("VLLM_TEST_FORCE_LOAD_FORMAT", "dummy"), + + # Time in ms for the zmq client to wait for a response from the backend + # server for simple data operations + "VLLM_RPC_TIMEOUT": + lambda: int(os.getenv("VLLM_RPC_TIMEOUT", "10000")), + + # a list of plugin names to load, separated by commas. + # if this is not set, it means all plugins will be loaded + # if this is set to an empty string, no plugins will be loaded + "VLLM_PLUGINS": + lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[ + "VLLM_PLUGINS"].split(","), + + # Enables torch profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. + "VLLM_TORCH_PROFILER_DIR": + lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os + .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + + # If set, vLLM will use Triton implementations of AWQ. + "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), + + # If set, allow loading or unloading lora adapters in runtime, + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": + lambda: + (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in + ("1", "true")), + + # By default, vLLM will check the peer-to-peer capability itself, + # in case of broken drivers. See https://github.com/vllm-project/vllm/blob/a9b15c606fea67a072416ea0ea115261a2756058/vllm/distributed/device_communicators/custom_all_reduce_utils.py#L101-L108 for details. # noqa + # If this env var is set to 1, vLLM will skip the peer-to-peer check, + # and trust the driver's peer-to-peer capability report. 
+ "VLLM_SKIP_P2P_CHECK": + lambda: os.getenv("VLLM_SKIP_P2P_CHECK", "0") == "1", + + # If set, allowing the use of deprecated block manager V1 + "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1": + lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0" + ) == "1", +} + +# end-env-vars-definition + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in environment_variables: + return environment_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(environment_variables.keys()) diff --git a/vllm/executor/__init__.py b/vllm/executor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/executor/__pycache__/__init__.cpython-310.pyc b/vllm/executor/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..06de5ea Binary files /dev/null and b/vllm/executor/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/cpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/cpu_executor.cpython-310.pyc new file mode 100644 index 0000000..0032ddf Binary files /dev/null and b/vllm/executor/__pycache__/cpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/distributed_gpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/distributed_gpu_executor.cpython-310.pyc new file mode 100644 index 0000000..ae7c76d Binary files /dev/null and b/vllm/executor/__pycache__/distributed_gpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/executor_base.cpython-310.pyc b/vllm/executor/__pycache__/executor_base.cpython-310.pyc new file mode 100644 index 0000000..ccef798 Binary files /dev/null and b/vllm/executor/__pycache__/executor_base.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/gpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/gpu_executor.cpython-310.pyc new file mode 100644 index 0000000..42aa355 Binary files /dev/null and b/vllm/executor/__pycache__/gpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/msgspec_utils.cpython-310.pyc b/vllm/executor/__pycache__/msgspec_utils.cpython-310.pyc new file mode 100644 index 0000000..d824251 Binary files /dev/null and b/vllm/executor/__pycache__/msgspec_utils.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/multiproc_gpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/multiproc_gpu_executor.cpython-310.pyc new file mode 100644 index 0000000..57a1ff8 Binary files /dev/null and b/vllm/executor/__pycache__/multiproc_gpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/multiproc_worker_utils.cpython-310.pyc b/vllm/executor/__pycache__/multiproc_worker_utils.cpython-310.pyc new file mode 100644 index 0000000..b092866 Binary files /dev/null and b/vllm/executor/__pycache__/multiproc_worker_utils.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/multiproc_xpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/multiproc_xpu_executor.cpython-310.pyc new file mode 100644 index 0000000..4b31776 Binary files /dev/null and b/vllm/executor/__pycache__/multiproc_xpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/neuron_executor.cpython-310.pyc b/vllm/executor/__pycache__/neuron_executor.cpython-310.pyc new file mode 100644 index 0000000..2b14ed3 Binary files /dev/null and b/vllm/executor/__pycache__/neuron_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/openvino_executor.cpython-310.pyc 
b/vllm/executor/__pycache__/openvino_executor.cpython-310.pyc new file mode 100644 index 0000000..79fbeb9 Binary files /dev/null and b/vllm/executor/__pycache__/openvino_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/ray_gpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/ray_gpu_executor.cpython-310.pyc new file mode 100644 index 0000000..2aaa6fe Binary files /dev/null and b/vllm/executor/__pycache__/ray_gpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/ray_tpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/ray_tpu_executor.cpython-310.pyc new file mode 100644 index 0000000..7ac3d79 Binary files /dev/null and b/vllm/executor/__pycache__/ray_tpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/ray_utils.cpython-310.pyc b/vllm/executor/__pycache__/ray_utils.cpython-310.pyc new file mode 100644 index 0000000..a509882 Binary files /dev/null and b/vllm/executor/__pycache__/ray_utils.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/ray_xpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/ray_xpu_executor.cpython-310.pyc new file mode 100644 index 0000000..b78487c Binary files /dev/null and b/vllm/executor/__pycache__/ray_xpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/tpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/tpu_executor.cpython-310.pyc new file mode 100644 index 0000000..02db6b6 Binary files /dev/null and b/vllm/executor/__pycache__/tpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/__pycache__/xpu_executor.cpython-310.pyc b/vllm/executor/__pycache__/xpu_executor.cpython-310.pyc new file mode 100644 index 0000000..651acb1 Binary files /dev/null and b/vllm/executor/__pycache__/xpu_executor.cpython-310.pyc differ diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py new file mode 100644 index 0000000..e32993e --- /dev/null +++ b/vllm/executor/cpu_executor.py @@ -0,0 +1,389 @@ +import os +from functools import partial +from typing import Any, Awaitable, List, Optional, Set, Tuple, Union + +import torch + +import vllm.envs as envs +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, + get_vllm_instance_id, make_async) +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) + + +class CPUExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert self.device_config.device_type == "cpu" + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + assert self.lora_config is None, "cpu backend doesn't support LoRA" + + # + # Environment variables for CPU executor + # + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Intel OpenMP setting + ld_prealod_str = 
os.getenv("LD_PRELOAD", "") + if "libiomp5.so" in ld_prealod_str: + # The time(milliseconds) that a thread should wait after + # completing the execution of a parallel region, before sleeping. + os.environ['KMP_BLOCKTIME'] = "1" + # Prevents the CPU to run into low performance state + os.environ['KMP_TPAUSE'] = "0" + # Provides fine granularity parallelism + os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist" + os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist" + os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist" + + # To hint IPEX uses shared memory based AllReduce + os.environ["LOCAL_WORLD_SIZE"] = str( + self.parallel_config.tensor_parallel_size) + + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config(self.cache_config) + self.scheduler_config = _verify_and_get_scheduler_config( + self.scheduler_config) + self.parallel_config = _verify_and_get_parallel_config( + self.parallel_config) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + ip = "127.0.0.1" + port = get_open_port() + self.distributed_init_method = get_distributed_init_method(ip, port) + + is_async = isinstance(self, CPUExecutorAsync) + + world_size = self.parallel_config.tensor_parallel_size + result_handler = ResultHandler() + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + self.workers = [] + + if is_async: + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + )) for rank in range(0, world_size) + ] + self.driver_worker = self.workers[0] + self.workers = self.workers[1:] + self.driver_method_invoker = _async_driver_method_invoker + else: + self.driver_worker = self._create_worker() + self.driver_method_invoker = _driver_method_invoker + + if world_size != 1: + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = None + if world_size != 1 or is_async: + if is_async: + async_worker_list = self.workers + [self.driver_worker] + else: + async_worker_list = self.workers + self.worker_monitor = WorkerMonitor(async_worker_list, + result_handler) + result_handler.start() + self.worker_monitor.start() + + self._run_workers("init_device") + self._run_workers("load_model") + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + ): + worker_module_name = "vllm.worker.cpu_worker" + worker_class_name = "CPUWorker" + + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + ) + + assert self.distributed_init_method is not None + + kwargs = dict( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=self.distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, + is_driver_worker=rank == 0, + ) + wrapper.init_worker(**kwargs) + + return wrapper.worker + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + 
**kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if async_run_remote_workers_only: + # Just return futures + return worker_outputs + + driver_worker_output = self.driver_method_invoker( + self.driver_worker, method, *args, **kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_method_invoker(self.driver_worker, + "determine_num_available_blocks") + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + # NOTE: `cpu block` for CPU backend is located on CPU memory but is + # referred as `gpu block`. Because we want to reuse the existing block + # management procedure. + logger.info("# CPU blocks: %d", num_gpu_blocks) + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if (self.parallel_config.tensor_parallel_size > 1 + and self.parallel_worker_tasks is None): + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + ) + output = self.driver_method_invoker(self.driver_worker, + "execute_model", execute_model_req) + return output + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + """ + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + self.driver_method_invoker(self.driver_worker, "execute_model", None) + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return all(self._run_workers("add_lora", lora_request)) + + def remove_lora(self, lora_id: int) -> bool: + return all(self._run_workers("remove_lora", lora_id)) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." 
+ return all(self._run_workers( + "pin_lora", + lora_id=lora_id, + )) + + def list_loras(self) -> Set[int]: + return self.driver_method_invoker(self.driver_worker, "list_loras") + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return all( + self._run_workers( + "add_prompt_adapter", + prompt_adapter_request, + )) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return all( + self._run_workers( + "remove_prompt_adapter", + prompt_adapter_id, + )) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_method_invoker(self.driver_worker, + "list_prompt_adapters") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return all(self._run_workers( + "pin_prompt_adapter", + prompt_adapter_id, + )) + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + def start_profile(self) -> None: + self.driver_method_invoker(self.driver_worker, "start_profile") + + def stop_profile(self) -> None: + self.driver_method_invoker(self.driver_worker, "stop_profile") + + +class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + self.check_health() + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype == torch.float16: + logger.warning("float16 is not supported on CPU, casting to bfloat16.") + config.dtype = torch.bfloat16 + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on CPU, fallback to the eager " + "mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_scheduler_config( + config: SchedulerConfig) -> SchedulerConfig: + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if config.chunked_prefill_enabled: + logger.warning("Chunked prefill is not supported on CPU, disable it.") + config.chunked_prefill_enabled = False + + return config + + +def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if config.enable_prefix_caching: + logger.warning("Prefix caching is not supported on CPU, disable it.") + config.enable_prefix_caching = False + + kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE + + if kv_cache_space >= 0: + if kv_cache_space == 0: + config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) " + "for CPU backend is not set, using 4 by default.") + else: + config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore + else: + raise RuntimeError( + "Invalid environment 
variable VLLM_CPU_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config + + +def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig: + if (config.distributed_executor_backend is not None + and config.distributed_executor_backend != "mp"): + logger.warning( + "%s is not supported on CPU, fallback to mp distributed executor " + "backend.", config.distributed_executor_backend) + config.distributed_executor_backend = "mp" + return config + + +def _driver_method_invoker(driver, method: str, *args, **kwargs): + return getattr(driver, method)(*args, **kwargs) + + +def _async_driver_method_invoker(driver, method: str, *args, **kwargs): + return driver.execute_method(method, *args, **kwargs).get() diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py new file mode 100644 index 0000000..deb7cb1 --- /dev/null +++ b/vllm/executor/distributed_gpu_executor.py @@ -0,0 +1,212 @@ +import asyncio +from abc import abstractmethod +from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union + +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest + +logger = init_logger(__name__) + + +class DistributedGPUExecutor(GPUExecutor): + """Abstract superclass of multi-GPU executor implementations.""" + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. 
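        # A rough worked sketch of the sizing above (illustrative numbers, not
        # taken from this diff): each worker reports a (num_gpu_blocks,
        # num_cpu_blocks) pair and determine_num_available_blocks keeps the
        # element-wise minimum so the chosen cache fits on every worker:
        #
        #     per_worker_blocks = [(8192, 2048), (8000, 2048), (8100, 1900)]
        #     min(b[0] for b in per_worker_blocks)   # -> 8000 GPU blocks
        #     min(b[1] for b in per_worker_blocks)   # -> 1900 CPU blocks
        #
        # With block_size=16 and max_model_len=4096, the concurrency figure
        # logged below would be 8000 * 16 / 4096 = 31.25x.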
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + max_concurrency = (num_gpu_blocks * self.cache_config.block_size / + self.model_config.max_model_len) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, max_concurrency) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_tensor_parallel_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) + + # Only the driver worker returns the sampling results. + driver_outputs = self._driver_execute_model(execute_model_req) + assert driver_outputs is not None + return driver_outputs + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model(execute_model_req=None) + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "pin_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self._run_workers("save_sharded_state", + path=path, + pattern=pattern, + max_size=max_size) + + @abstractmethod + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution loop + running in each of the remote workers. In this case, this method + returns None. Otherwise, this method returns the model output. + """ + raise NotImplementedError + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. 
+ """ + raise NotImplementedError + + @abstractmethod + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + raise NotImplementedError + + +class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. + return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + @abstractmethod + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + """Execute the model asynchronously in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + raise NotImplementedError + + @abstractmethod + async def _start_worker_execution_loop(self): + """Run execution loop on all workers. It guarantees all workers run + the loop or None of them is running the loop. Loop can be stopped by + `stop_remote_worker_execution_loop`. + The API is idempotent (guarantee only 1 loop run at any moment).""" + raise NotImplementedError diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py new file mode 100644 index 0000000..c96cb0f --- /dev/null +++ b/vllm/executor/executor_base.py @@ -0,0 +1,150 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Set, Tuple + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest + + +class ExecutorBase(ABC): + """Base class for all executors. + + An executor is responsible for executing the model on a specific device + type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor + that can execute the model on multiple devices. + """ + + uses_ray: bool # whether the executor uses Ray for orchestration. 
+ + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + observability_config: Optional[ObservabilityConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + self._init_executor() + + @abstractmethod + def _init_executor(self) -> None: + pass + + @abstractmethod + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + Normally, this should simply delegate to the underlying Worker. Some + ExecutorBase may require modification of the result, e.g. to ensure the + selected cache sizes are compatible with all workers. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + + @abstractmethod + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences.""" + raise NotImplementedError + + def stop_remote_worker_execution_loop(self) -> None: + """Releases parallel workers from model loop.""" + return + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError # type: ignore + + @abstractmethod + def list_loras(self) -> Set[int]: + raise NotImplementedError + + @abstractmethod + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError # type: ignore + + @abstractmethod + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError + + @abstractmethod + def check_health(self) -> None: + """Checks if the executor is healthy. 
If not, it should raise an + exception.""" + raise NotImplementedError + + def shutdown(self) -> None: + """Shutdown the executor.""" + return + + def __del__(self): + self.shutdown() + + +class ExecutorAsyncBase(ExecutorBase): + + @abstractmethod + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + async def stop_remote_worker_execution_loop_async(self) -> None: + """Releases parallel workers from model loop.""" + return + + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + self.check_health() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py new file mode 100644 index 0000000..ed30d31 --- /dev/null +++ b/vllm/executor/gpu_executor.py @@ -0,0 +1,191 @@ +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import ExecuteModelRequest, PoolerOutput +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) +from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase + +logger = init_logger(__name__) + + +def create_worker(worker_module_name: str, worker_class_name: str, + worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], + **kwargs): + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + ) + wrapper.init_worker(**kwargs) + return wrapper.worker + + +class GPUExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + """Initialize the worker and load the model. 
+ """ + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + speculative_config=self.speculative_config, + prompt_adapter_config=self.prompt_adapter_config, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + observability_config=self.observability_config, + ) + + def _get_worker_module_and_class( + self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: + worker_class_fn = None + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" + elif self.speculative_config: + worker_module_name = "vllm.spec_decode.spec_decode_worker" + worker_class_name = "create_spec_worker" + else: + worker_module_name = "vllm.worker.worker" + worker_class_name = "Worker" + return (worker_module_name, worker_class_name, worker_class_fn) + + def _get_create_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict: + worker_kwargs = self._get_worker_kwargs(local_rank, rank, + distributed_init_method) + + (worker_module_name, worker_class_name, + worker_class_fn) = self._get_worker_module_and_class() + worker_kwargs.update( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + ) + + return worker_kwargs + + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + return create_worker(**self._get_create_worker_kwargs( + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method)) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. 
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + max_concurrency = (num_gpu_blocks * self.cache_config.block_size / + self.model_config.max_model_len) + logger.info("Maximum concurrency for %s tokens per request: %.2fx", + self.model_config.max_model_len, max_concurrency) + + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + + def check_health(self) -> None: + # GPUExecutor will always be healthy as long as + # it's running. + return + + def start_profile(self) -> None: + self.driver_worker.start_profile() + + def stop_profile(self) -> None: + self.driver_worker.stop_profile() + + +class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[Union[SamplerOutput, PoolerOutput]]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req) + return output diff --git a/vllm/executor/msgspec_utils.py b/vllm/executor/msgspec_utils.py new file mode 100644 index 0000000..c467115 --- /dev/null +++ b/vllm/executor/msgspec_utils.py @@ -0,0 +1,27 @@ +from array import array +from typing import Any, Type + +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE + + +def encode_hook(obj: Any) -> Any: + """Custom msgspec enc hook that supports array types. + + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if isinstance(obj, array): + assert obj.typecode == VLLM_TOKEN_ID_ARRAY_TYPE, ( + f"vLLM array type should use '{VLLM_TOKEN_ID_ARRAY_TYPE}' type. " + f"Given array has a type code of {obj.typecode}.") + return obj.tobytes() + + +def decode_hook(type: Type, obj: Any) -> Any: + """Custom msgspec dec hook that supports array types. 
+ + See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder + """ + if type is array: + deserialized = array(VLLM_TOKEN_ID_ARRAY_TYPE) + deserialized.frombytes(obj) + return deserialized diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py new file mode 100644 index 0000000..2dbde77 --- /dev/null +++ b/vllm/executor/multiproc_gpu_executor.py @@ -0,0 +1,258 @@ +import asyncio +import os +from functools import partial +from typing import Any, List, Optional + +import torch + +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.gpu_executor import create_worker +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.triton_utils import maybe_set_triton_cache_manager +from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, + cuda_is_initialized, get_distributed_init_method, + get_open_port, get_vllm_instance_id, make_async, + update_environment_variables) + +logger = init_logger(__name__) + + +class MultiprocessingGPUExecutor(DistributedGPUExecutor): + """Python multiprocessing-based multi-GPU executor""" + + uses_ray: bool = False + + def _init_executor(self) -> None: + self._check_executor_parameters() + + # Create the parallel GPU workers. + world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + # Disable torch async compiling which won't work with daemonic processes + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) + + # workaround for https://github.com/vllm-project/vllm/issues/6103 + if world_size > 1: + maybe_set_triton_cache_manager() + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + self.workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. 
These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[ProcessWorkerWrapper] = [] + + if world_size == 1: + self.worker_monitor = None + else: + result_handler = ResultHandler() + for rank in range(1, world_size): + worker = ProcessWorkerWrapper( + result_handler, + partial( + create_worker, + **self._get_create_worker_kwargs( + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + ))) + self.workers.append(worker) + if rank % tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + # Set up signal handlers to shutdown the executor cleanly + # sometimes gc does not work well + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def _check_executor_parameters(self): + world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + + if (cuda_is_initialized() + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"): + logger.warning("CUDA was previously initialized. We must use " + "the `spawn` multiprocessing start method. Setting " + "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + cuda_device_count = cuda_device_count_stateless() + # Use confusing message for more common TP-only case. + assert tensor_parallel_size <= cuda_device_count, ( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({cuda_device_count})") + + assert world_size <= cuda_device_count, ( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({cuda_device_count})") + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_model(execute_model_req) + + def _run_workers( + self, + method: str, + *args, + async_run_tensor_parallel_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if async_run_tensor_parallel_workers_only: + # Run only non-driver workers and just return futures. + return [ + worker.execute_method(method, *args, **kwargs) + for worker in self.non_driver_workers + ] + + # Start all remote workers first. 
+ worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*args, **kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + for result in parallel_worker_tasks: + result.get() + + +class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor, + DistributedGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) + self.pp_locks: Optional[List[asyncio.Lock]] = None + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if not self.tp_driver_workers: + return await self.driver_exec_model(execute_model_req) + + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_model, self.pp_locks[0], + execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method_async, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. 
+ return results[-1] + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method_async("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/executor/multiproc_worker_utils.py b/vllm/executor/multiproc_worker_utils.py new file mode 100644 index 0000000..e14ecc1 --- /dev/null +++ b/vllm/executor/multiproc_worker_utils.py @@ -0,0 +1,274 @@ +import asyncio +import multiprocessing +import os +import sys +import threading +import traceback +import uuid +from dataclasses import dataclass +from multiprocessing import Queue +from multiprocessing.connection import wait +from multiprocessing.process import BaseProcess +from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO, + TypeVar, Union) + +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + +T = TypeVar('T') + +_TERMINATE = "TERMINATE" # sentinel + +# ANSI color codes +CYAN = '\033[1;36m' +RESET = '\033[0;0m' + +JOIN_TIMEOUT_S = 2 + + +@dataclass +class Result(Generic[T]): + """Result of task dispatched to worker""" + + task_id: uuid.UUID + value: Optional[T] = None + exception: Optional[BaseException] = None + + +class ResultFuture(threading.Event, Generic[T]): + """Synchronous future for non-async case""" + + def __init__(self): + super().__init__() + self.result: Optional[Result[T]] = None + + def set_result(self, result: Result[T]): + self.result = result + self.set() + + def get(self) -> T: + self.wait() + assert self.result is not None + if self.result.exception is not None: + raise self.result.exception + return self.result.value # type: ignore[return-value] + + +def _set_future_result(future: Union[ResultFuture, asyncio.Future], + result: Result): + if isinstance(future, ResultFuture): + future.set_result(result) + return + loop = future.get_loop() + if not loop.is_closed(): + if result.exception is not None: + loop.call_soon_threadsafe(future.set_exception, result.exception) + else: + loop.call_soon_threadsafe(future.set_result, result.value) + + +class ResultHandler(threading.Thread): + """Handle results from all workers (in background thread)""" + + def __init__(self) -> None: + super().__init__(daemon=True) + self.result_queue = get_mp_context().Queue() + self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {} + + def run(self): + for result in iter(self.result_queue.get, _TERMINATE): + future = self.tasks.pop(result.task_id) + _set_future_result(future, result) + # Ensure that all waiters will receive an exception + for task_id, future in self.tasks.items(): + _set_future_result( + future, + Result(task_id=task_id, + exception=ChildProcessError("worker died"))) + + def close(self): + self.result_queue.put(_TERMINATE) + + +class WorkerMonitor(threading.Thread): + """Monitor worker status (in background thread)""" + + def __init__(self, workers: List['ProcessWorkerWrapper'], + result_handler: ResultHandler): + super().__init__(daemon=True) + self.workers = workers + self.result_handler = result_handler + self._close = False + + def run(self) -> None: + # Blocks until any worker exits + dead_sentinels = wait([w.process.sentinel for w in self.workers]) + if not self._close: + self._close = True + + # Kill / cleanup all workers + for worker in self.workers: + process = worker.process + if process.sentinel in dead_sentinels: + process.join(JOIN_TIMEOUT_S) + if process.exitcode is not None and process.exitcode != 0: + logger.error("Worker %s pid %s died, exit 
code: %s", + process.name, process.pid, process.exitcode) + # Cleanup any remaining workers + if logger: + logger.info("Killing local vLLM worker processes") + for worker in self.workers: + worker.kill_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + for worker in self.workers: + worker.process.join(JOIN_TIMEOUT_S) + + def close(self): + if self._close: + return + self._close = True + logger.info("Terminating local vLLM worker processes") + for worker in self.workers: + worker.terminate_worker() + # Must be done after worker task queues are all closed + self.result_handler.close() + + +class ProcessWorkerWrapper: + """Local process wrapper for vllm.worker.Worker, + for handling single-node multi-GPU tensor parallel.""" + + def __init__(self, result_handler: ResultHandler, + worker_factory: Callable[[], Any]) -> None: + self.mp = get_mp_context() + self._task_queue = self.mp.Queue() + self.result_queue = result_handler.result_queue + self.tasks = result_handler.tasks + self.process: BaseProcess = self.mp.Process( # type: ignore[attr-defined] + target=_run_worker_process, + name="VllmWorkerProcess", + kwargs=dict( + worker_factory=worker_factory, + task_queue=self._task_queue, + result_queue=self.result_queue, + ), + daemon=True) + + self.process.start() + + def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], + method: str, args, kwargs): + task_id = uuid.uuid4() + self.tasks[task_id] = future + try: + self._task_queue.put((task_id, method, args, kwargs)) + except SystemExit: + raise + except BaseException as e: + del self.tasks[task_id] + raise ChildProcessError("worker died") from e + + def execute_method(self, method: str, *args, **kwargs): + future: ResultFuture = ResultFuture() + self._enqueue_task(future, method, args, kwargs) + return future + + async def execute_method_async(self, method: str, *args, **kwargs): + future = asyncio.get_running_loop().create_future() + self._enqueue_task(future, method, args, kwargs) + return await future + + def terminate_worker(self): + try: + self._task_queue.put(_TERMINATE) + except ValueError: + self.process.kill() + self._task_queue.close() + + def kill_worker(self): + self._task_queue.close() + self.process.kill() + + +def _run_worker_process( + worker_factory: Callable[[], Any], + task_queue: Queue, + result_queue: Queue, +) -> None: + """Worker process event loop""" + + # Add process-specific prefix to stdout and stderr + process_name = get_mp_context().current_process().name + pid = os.getpid() + _add_prefix(sys.stdout, process_name, pid) + _add_prefix(sys.stderr, process_name, pid) + + # Initialize worker + worker = worker_factory() + del worker_factory + + # Accept tasks from the engine in task_queue + # and return task output in result_queue + logger.info("Worker ready; awaiting tasks") + try: + for items in iter(task_queue.get, _TERMINATE): + output = None + exception = None + task_id, method, args, kwargs = items + try: + executor = getattr(worker, method) + output = executor(*args, **kwargs) + except SystemExit: + raise + except KeyboardInterrupt: + break + except BaseException as e: + tb = traceback.format_exc() + logger.error( + "Exception in worker %s while processing method %s: %s, %s", + process_name, method, e, tb) + exception = e + result_queue.put( + Result(task_id=task_id, value=output, exception=exception)) + except KeyboardInterrupt: + pass + except Exception: + logger.exception("Worker failed") + + logger.info("Worker exiting") + + +def _add_prefix(file: 
TextIO, worker_name: str, pid: int) -> None: + """Prepend each output line with process-specific prefix""" + + prefix = f"{CYAN}({worker_name} pid={pid}){RESET} " + file_write = file.write + + def write_with_prefix(s: str): + if not s: + return + if file.start_new_line: # type: ignore[attr-defined] + file_write(prefix) + idx = 0 + while (next_idx := s.find('\n', idx)) != -1: + next_idx += 1 + file_write(s[idx:next_idx]) + if next_idx == len(s): + file.start_new_line = True # type: ignore[attr-defined] + return + file_write(prefix) + idx = next_idx + file_write(s[idx:]) + file.start_new_line = False # type: ignore[attr-defined] + + file.start_new_line = True # type: ignore[attr-defined] + file.write = write_with_prefix # type: ignore[method-assign] + + +def get_mp_context(): + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + return multiprocessing.get_context(mp_method) diff --git a/vllm/executor/multiproc_xpu_executor.py b/vllm/executor/multiproc_xpu_executor.py new file mode 100644 index 0000000..a66afbf --- /dev/null +++ b/vllm/executor/multiproc_xpu_executor.py @@ -0,0 +1,26 @@ +import vllm.envs as envs +from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync) +from vllm.executor.xpu_executor import XPUExecutor +from vllm.logger import init_logger +from vllm.utils import make_async + +logger = init_logger(__name__) + + +class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor): + """Python multiprocessing-based multi-XPU executor""" + + def _check_executor_parameters(self): + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + if mp_method != "spawn": + raise RuntimeError( + "XPU multiprocess executor only support spawn as mp method") + + +class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor, + MultiprocessingGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py new file mode 100644 index 0000000..f2fcfa5 --- /dev/null +++ b/vllm/executor/neuron_executor.py @@ -0,0 +1,115 @@ +from typing import List, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +logger = init_logger(__name__) + + +class NeuronExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Neuron backend." + assert (not self.speculative_config + ), "Speculative decoding not yet supported for Neuron backend." + + # Instantiate the worker and load the model to the device. 
+ self._init_worker() + + def _init_worker(self): + from vllm.worker.neuron_worker import NeuronWorker + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = NeuronWorker( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + assert (not execute_model_req.blocks_to_swap_in + and not execute_model_req.blocks_to_swap_out + and not execute_model_req.blocks_to_copy), ( + "Cache operations are not supported for Neuron backend.") + assert execute_model_req.num_lookahead_slots == 0, ( + "lookahead not supported for Neuron backend.") + + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the Neuron backend.") + + def check_health(self) -> None: + # NeuronExecutor will always be healthy as long as + # it's running. + return + + +class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + # NeuronExecutor will always be healthy as long as + # it's running. 
+ return diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py new file mode 100644 index 0000000..4a39839 --- /dev/null +++ b/vllm/executor/openvino_executor.py @@ -0,0 +1,213 @@ +from typing import List, Set, Tuple + +import openvino as ov +import openvino.properties.hint as hints +import torch + +import vllm.envs as envs +from vllm.config import CacheConfig, ModelConfig +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, + get_open_port, make_async) + +logger = init_logger(__name__) + + +def is_openvino_cpu() -> bool: + return "CPU" in envs.VLLM_OPENVINO_DEVICE + + +def is_openvino_gpu() -> bool: + return "GPU" in envs.VLLM_OPENVINO_DEVICE + + +class OpenVINOExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert self.device_config.device_type == "openvino" + assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" + assert is_openvino_cpu() or is_openvino_gpu(), \ + "OpenVINO backend supports only CPU and GPU devices" + + self.ov_core = ov.Core() + self.model_config = _verify_and_get_model_config(self.model_config) + self.cache_config = _verify_and_get_cache_config( + self.ov_core, self.cache_config) + + # Instantiate the worker and load the model to CPU. + self._init_worker() + + def _init_worker(self): + from vllm.worker.openvino_worker import OpenVINOWorker + + assert ( + self.parallel_config.world_size == 1 + ), "OpenVINOExecutor only supports single CPU socket currently." + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = OpenVINOWorker( + ov_core=self.ov_core, + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker.""" + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend + # is located on CPU memory but is referred as `gpu block`. + # Because we want to reuse the existing block management procedure. 
+ device_blocks = num_gpu_blocks + swap_blocks = num_cpu_blocks + logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d", + envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the OPENVINO backend.") + + def check_health(self) -> None: + # OpenVINOExecutor will always be healthy as long as + # it's running. + return + + +class OpenVINOExecutorAsync(OpenVINOExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req, ) + return output + + async def check_health_async(self) -> None: + # OpenVINOExecutor will always be healthy as long as + # it's running. + return + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype != torch.float32: + logger.warning( + f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." 
# noqa: G004, E501 + ) + config.dtype = torch.float32 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on OpenVINO backend, fallback to the " + "eager mode.") + config.enforce_eager = True + return config + + +def _verify_and_get_cache_config(ov_core: ov.Core, + config: CacheConfig) -> CacheConfig: + if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": + if not is_openvino_cpu(): + logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" + "ignored for GPU, f16 data type will be used.") + config.cache_dtype = ov.Type.f16 + else: + logger.info("KV cache type is overridden to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + config.cache_dtype = ov.Type.u8 + else: + if is_openvino_cpu(): + ov_device = envs.VLLM_OPENVINO_DEVICE + inference_precision = ov_core.get_property( + ov_device, hints.inference_precision) + if inference_precision == ov.Type.bf16: + config.cache_dtype = ov.Type.bf16 + else: + config.cache_dtype = ov.Type.f16 + else: + config.cache_dtype = ov.Type.f16 + + if is_openvino_cpu(): + if config.block_size != 32: + logger.info( + f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 + ) + config.block_size = 32 + else: + if config.block_size != 16: + logger.info( + f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501 + ) + config.block_size = 16 + + kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE + if kv_cache_space >= 0: + if kv_cache_space == 0 and is_openvino_cpu(): + config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning( + "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " + "for OpenVINO backend is not set, using 4 by default.") + else: + config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore + else: + raise RuntimeError( + "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") + + return config diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py new file mode 100644 index 0000000..be2948a --- /dev/null +++ b/vllm/executor/ray_gpu_executor.py @@ -0,0 +1,586 @@ +import asyncio +import os +from collections import defaultdict +from itertools import islice, repeat +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import msgspec + +import vllm.envs as envs +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) +from vllm.executor.msgspec_utils import encode_hook +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (_run_task_with_lock, get_distributed_init_method, + get_ip, get_open_port, get_vllm_instance_id, + make_async) + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +class RayGPUExecutor(DistributedGPUExecutor): + + uses_ray: bool = True + + def _init_executor(self) -> None: + self.forward_dag: Optional["ray.dag.CompiledDAG"] = None + # If the env var is set, it uses the Ray's compiled DAG API + # which optimizes the control plane overhead. + # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. 
+ # Currently, this requires USE_RAY_SPMD_WORKER=True. + self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG + # If the env var is set, then we do not distinguish between the + # "driver worker" vs other workers. Also, the rank 0 worker will + # be executed in a remote Ray worker. Currently this requires + # USE_RAY_COMPILED_DAG=True. + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if self.use_ray_compiled_dag: + assert self.use_ray_spmd_worker, ( + "VLLM_USE_RAY_COMPILED_DAG=1 requires " + "VLLM_USE_RAY_SPMD_WORKER=1") + if self.use_ray_spmd_worker: + # TODO: Support SPMD worker for non-DAG Ray executor. + assert self.use_ray_compiled_dag, ( + "VLLM_USE_RAY_SPMD_WORKER=1 requires " + "VLLM_USE_RAY_COMPILED_DAG=1") + + assert self.uses_ray + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers. + self._init_workers_ray(placement_group) + + self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + self.output_decoder = msgspec.msgpack.Decoder( + Optional[List[SamplerOutput]]) + + def shutdown(self) -> None: + if hasattr(self, "forward_dag") and self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) + self.forward_dag = None + + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + + def _get_worker_wrapper_args(self) -> Dict[str, Any]: + (worker_module_name, worker_class_name, + worker_class_fn) = self._get_worker_module_and_class() + + return dict( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + trust_remote_code=self.model_config.trust_remote_code, + ) + + # child class could overwrite this to return actual env vars. + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): + # For single GPU case, we use a ray worker with constrained memory. + num_gpus = self.cache_config.gpu_memory_utilization + else: + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. 
+ driver_ip = get_ip() + worker_wrapper_kwargs = self._get_worker_wrapper_args() + # vllm_multi_node_nccl_id = os.environ.get("VLLM_MULTI_NODE_NCCL_COMM_ID",None) + nccl_socket_name = os.environ.get("NCCL_SOCKET_IFNAME",None) + runtime_env={} + if nccl_socket_name is not None: + runtime_env["env_vars"] = {"NCCL_SOCKET_IFNAME" :f"{nccl_socket_name}"} + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + runtime_env=runtime_env, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + + if self.use_ray_spmd_worker: + self.workers.append(worker) + else: + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + **worker_wrapper_kwargs) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. 
+ # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP` or " + "`HOST_IP` environment variable, make sure it is unique for" + " each node.") + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + all_args=self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. 
+ rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + def _driver_execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] + ) -> Optional[List[SamplerOutput]]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return super().execute_model(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=False) + + serialized_data = self.input_encoder.encode(execute_model_req) + outputs = ray.get(self.forward_dag.execute(serialized_data)) + output = self.output_decoder.decode(outputs[0]) + return output + + def _run_workers( + self, + method: str, + *args, + async_run_tensor_parallel_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + Args: + - async_run_tensor_parallel_workers_only: If True the method will be + run only in the remote TP workers, not the driver worker. + It will also be run asynchronously and return a list of futures + rather than blocking on the results. + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + if self.use_ray_spmd_worker: + assert not async_run_tensor_parallel_workers_only, ( + "async_run_tensor_parallel_workers_only is not supported for " + "spmd mode.") + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + count = len(self.workers) if not \ + async_run_tensor_parallel_workers_only \ + else len(self.non_driver_workers) + # If using SPMD worker, all workers are the same, so we should execute + # the args on all workers. Otherwise, we skip the first worker's args + # because those args will go to the driver worker. + first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1 + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, first_worker_args_index, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, first_worker_args_index, None) + + # Start the ray workers first. + ray_workers = self.workers + if async_run_tensor_parallel_workers_only: + ray_workers = self.non_driver_workers + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(ray_workers, all_worker_args, all_worker_kwargs) + ] + + if async_run_tensor_parallel_workers_only: + # Just return futures + return ray_worker_outputs + + driver_worker_output = [] + # In SPMD mode, the driver worker is the same as any other worker, + # so we only explicitly execute on the driver worker if using a + # non-SPMD worker class. 
+ if not self.use_ray_spmd_worker: + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = [ + self.driver_worker.execute_method(method, *driver_args, + **driver_kwargs) + ] + else: + assert self.driver_dummy_worker is not None + driver_worker_output = [ + ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + ] + + # Get the results of the ray workers. + if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return driver_worker_output + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def _check_ray_adag_installation(self): + import pkg_resources + from packaging import version + + required_version = version.parse("2.35") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) + # TODO: update the constraint once we adapt to the backward + # incompatible API change from ray 2.36 + if current_version != required_version: + raise ValueError(f"Ray version {required_version} is " + f"required, but found {current_version}") + + import importlib.util + adag_spec = importlib.util.find_spec( + "ray.experimental.compiled_dag_ref") + if adag_spec is None: + raise ValueError("Ray accelerated DAG is not installed. " + "Run `pip install ray[adag]` to install it.") + + cupy_spec = importlib.util.find_spec("cupy") + if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: + raise ValueError( + "cupy is not installed but required since " + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set." + "Run `pip install ray[adag]` and check cupy installation.") + + def _compiled_ray_dag(self, enable_asyncio: bool): + assert self.parallel_config.use_ray + self._check_ray_adag_installation() + from ray.dag import InputNode, MultiOutputNode + from ray.experimental.channel.torch_tensor_type import TorchTensorType + + logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", + envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) + with InputNode() as input_data: + # Example DAG: PP=2, TP=4 + # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput # noqa: E501 + # -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput # noqa: E501 + # -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput # noqa: E501 + # -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput # noqa: E501 + + # All workers in the first TP group will take in the + # ExecuteModelRequest as input. + outputs = [input_data for _ in self.pp_tp_workers[0]] + for pp_rank, tp_group in enumerate(self.pp_tp_workers): + # Each PP worker takes in the output of the previous PP worker, + # and the TP group executes in SPMD fashion. + outputs = [ + worker.execute_model_spmd. + bind( # type: ignore[attr-defined] + outputs[i]) for i, worker in enumerate(tp_group) + ] + + last_pp_rank = len(self.pp_tp_workers) - 1 + if pp_rank < last_pp_rank: + # Specify how intermediate tensors should be passed + # between pp stages, no need to specify for the last + # pp stage. 
+ transport = "nccl" \ + if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \ + else "auto" + outputs = [ + output.with_type_hint( + TorchTensorType(transport=transport)) + for output in outputs + ] + + forward_dag = MultiOutputNode(outputs) + + return forward_dag.experimental_compile(enable_asyncio=enable_asyncio) + + def __del__(self): + self.shutdown() + + +class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pp_locks: Optional[List[asyncio.Lock]] = None + self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER + if not self.use_ray_compiled_dag: + self.driver_exec_method = make_async( + self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if not self.use_ray_spmd_worker: + return await super().execute_model_async(execute_model_req) + + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag(enable_asyncio=True) + + serialized_data = self.input_encoder.encode(execute_model_req) + dag_future = await self.forward_dag.execute_async(serialized_data) + outputs = await dag_future + return self.output_decoder.decode(outputs[0]) + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + assert not self.use_ray_spmd_worker, ( + "driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1") + if not self.tp_driver_workers: + return await self.driver_exec_method("execute_model", + execute_model_req) + if self.pp_locks is None: + # This locks each pipeline parallel stage so multiple virtual + # engines can't execute on the same stage at the same time + # We create the locks here to avoid creating them in the constructor + # which uses a different asyncio loop. + self.pp_locks = [ + asyncio.Lock() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + tasks = [ + asyncio.create_task( + _run_task_with_lock(self.driver_exec_method, self.pp_locks[0], + "execute_model", execute_model_req)) + ] + for pp_rank, driver_worker in enumerate(self.tp_driver_workers, + start=1): + tasks.append( + asyncio.create_task( + _run_task_with_lock(driver_worker.execute_method.remote, + self.pp_locks[pp_rank], + "execute_model", execute_model_req))) + + results = await asyncio.gather(*tasks) + + # Only the last PP stage has the final results. 
+ return results[-1] + + async def _start_worker_execution_loop(self): + assert not self.use_ray_spmd_worker, ( + "worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1") + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.non_driver_workers + ] + return await asyncio.gather(*coros) + + def __del__(self): + self.shutdown() diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py new file mode 100644 index 0000000..d02fecb --- /dev/null +++ b/vllm/executor/ray_tpu_executor.py @@ -0,0 +1,363 @@ +import asyncio +import os +from collections import defaultdict +from itertools import islice, repeat +from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple, + Union) + +import vllm.envs as envs +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.executor.tpu_executor import TPUExecutor +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + get_vllm_instance_id, make_async) + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +class RayTPUExecutor(TPUExecutor): + + uses_ray: bool = True + + def __init__(self, *args, **kwargs): + # This is non-None when the execute model loop is running + # in the parallel workers. It's a coroutine in the AsyncLLMEngine case. + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + # Updated by implementations that require additional args to be passed + # to the _run_workers execute_model call + self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {} + + super().__init__(*args, **kwargs) + + def _init_executor(self) -> None: + assert self.parallel_config.distributed_executor_backend == "ray" + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel TPU workers. + self._init_workers_ray(placement_group) + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Create the workers. + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("TPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + assert self.speculative_config is None + if self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_tpu_worker" + worker_class_name = "MultiStepTPUWorker" + else: + worker_module_name = "vllm.worker.tpu_worker" + worker_class_name = "TPUWorker" + + # GKE does not fetch environment information from metadata server + # and instead sets these from within the Ray process. 
Therefore we + # need to override the Ray environment variables manually. + override_env = {} + if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: + override_env.update({ + "TPU_CHIPS_PER_HOST_BOUNDS": + os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] + }) + if "TPU_HOST_BOUNDS" in os.environ: + override_env.update( + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + + worker = ray.remote( + num_cpus=0, + resources={"TPU": 1}, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + if override_env: + worker.override_env_vars.remote(override_env) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + trust_remote_code=self.model_config.trust_remote_code, + ) + else: + # Else, added to the list of workers. + self.workers.append(worker) + + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any TPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "TPU node.") + + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of TPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + node_workers = defaultdict(list) + for i, (node_id, _) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for _ in worker_node_and_gpu_ids] + self._run_workers("update_environment_variables", + all_args=all_args_to_update_environment_variables) + + if len(node_workers) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. 
Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + def _driver_execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Run execute_model in the driver worker. + + Passing None will cause the driver to stop the model execution + loop running in each of the remote workers. + """ + return self.driver_worker.execute_method("execute_model", + execute_model_req) + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + use_dummy_driver: bool = False, + max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers. Can be used in the following + ways: + + - async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than blocking + on the results. + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 1, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 1, None) + + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + + if async_run_remote_workers_only: + # Just return futures + return ray_worker_outputs + + driver_args = args if all_args is None else all_args[0] + driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] + + # Start the driver worker after all the ray workers. + if not use_dummy_driver: + driver_worker_output = self.driver_worker.execute_method( + method, *driver_args, **driver_kwargs) + else: + assert self.driver_dummy_worker is not None + driver_worker_output = ray.get( + self.driver_dummy_worker.execute_method.remote( + method, *driver_args, **driver_kwargs)) + # Get the results of the ray workers. 
+ if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs + + def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: + """Wait for futures returned from _run_workers() with + async_run_remote_workers_only to complete.""" + ray.get(parallel_worker_tasks) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_blocks = self._run_workers("determine_num_available_blocks", ) + num_tpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + return num_tpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + self.parallel_worker_tasks = self._run_workers( + "start_worker_execution_loop", + async_run_remote_workers_only=True, + **self.extra_execute_model_run_workers_kwargs) + + # Only the driver worker returns the sampling results. + return self._driver_execute_model(execute_model_req) + + def stop_remote_worker_execution_loop(self) -> None: + if self.parallel_worker_tasks is None: + return + + self._driver_execute_model() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + self._wait_for_tasks_completion(parallel_worker_tasks) + + +class RayTPUExecutorAsync(RayTPUExecutor, ExecutorAsyncBase): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_method = make_async(self.driver_worker.execute_method) + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + if self.parallel_worker_tasks is None: + # Start model execution loop running in the parallel workers + self.parallel_worker_tasks = asyncio.create_task( + self._start_worker_execution_loop()) + + # Only the driver worker returns the sampling results. 
+ return await self._driver_execute_model_async(execute_model_req) + + async def stop_remote_worker_execution_loop_async(self) -> None: + if self.parallel_worker_tasks is None: + return + + await self._driver_execute_model_async() + parallel_worker_tasks = self.parallel_worker_tasks + self.parallel_worker_tasks = None + # Ensure that workers exit model loop cleanly + # (this will raise otherwise) + await parallel_worker_tasks + + async def _driver_execute_model_async( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + return await self.driver_exec_method("execute_model", + execute_model_req) + + async def _start_worker_execution_loop(self): + coros = [ + worker.execute_method.remote("start_worker_execution_loop") + for worker in self.workers + ] + return await asyncio.gather(*coros) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py new file mode 100644 index 0000000..37a0afe --- /dev/null +++ b/vllm/executor/ray_utils.py @@ -0,0 +1,336 @@ +import os +import time +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import msgspec + +from vllm.config import ParallelConfig +from vllm.executor.msgspec_utils import decode_hook, encode_hook +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.utils import get_ip, is_hip, is_xpu +from vllm.worker.worker_base import WorkerWrapperBase + +logger = init_logger(__name__) +PG_WAIT_TIMEOUT = 1800 + +try: + import ray + from ray.util import placement_group_table + from ray.util.placement_group import PlacementGroup + try: + from ray._private.state import available_resources_per_node + except ImportError: + # Ray 2.9.x doesn't expose `available_resources_per_node` + from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node + + class RayWorkerWrapper(WorkerWrapperBase): + """Ray wrapper for vllm.worker.Worker, allowing Worker to be + lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES.""" + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Since the compiled DAG runs a main execution + # in a different thread that calls cuda.set_device. + # The flag indicates is set_device is called on + # that thread. + self.compiled_dag_cuda_device_set = False + + self.input_decoder = msgspec.msgpack.Decoder(ExecuteModelRequest, + dec_hook=decode_hook) + self.output_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook) + + def get_node_ip(self) -> str: + return get_ip() + + def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + node_id = ray.get_runtime_context().get_node_id() + gpu_ids = ray.get_gpu_ids() + return node_id, gpu_ids + + def execute_model_spmd( + self, req_or_tuple: Union[bytes, + Tuple[bytes, + Optional[IntermediateTensors]]] + ) -> bytes: + """Execute model in SPMD fashion: used only when SPMD worker and + compiled DAG are both enabled. + + Args: + req_or_tuple: A request or a tuple containing the + request and intermediate tensors. Intermediate tensors are + None unless if it is provided because it is > 0 pipeline + stage. The request is serialized by msgspec. 
+ """ + if isinstance(req_or_tuple, bytes): + serialized_req, intermediate_tensors = req_or_tuple, None + else: + serialized_req, intermediate_tensors = req_or_tuple + + execute_model_req = self.input_decoder.decode(serialized_req) + + # TODO(swang): This is needed right now because Ray aDAG executes + # on a background thread, so we need to reset torch's current + # device. + import torch + if not self.compiled_dag_cuda_device_set: + torch.cuda.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True + + output = self.worker._execute_model_spmd(execute_model_req, + intermediate_tensors) + # Pipeline model request and output to the next pipeline stage. + if isinstance(output, IntermediateTensors): + output = serialized_req, output + else: + output = self.output_encoder.encode(output) + + return output + + def override_env_vars(self, vars: Dict[str, str]): + os.environ.update(vars) + + ray_import_err = None + +except ImportError as e: + ray = None # type: ignore + ray_import_err = e + RayWorkerWrapper = None # type: ignore + + +def ray_is_available() -> bool: + """Returns True if Ray is available.""" + return ray is not None + + +def assert_ray_available(): + """Raise an exception if Ray is not available.""" + if ray is None: + raise ValueError("Failed to import Ray, please install Ray with " + "`pip install ray`.") from ray_import_err + + +def _verify_bundles(placement_group: "PlacementGroup", + parallel_config: ParallelConfig, device_str: str): + """Verify a given placement group has bundles located in the right place. + + There are 2 rules. + - Warn if all tensor parallel workers cannot fit in a single node. + - Fail if driver node is not included in a placement group. + """ + assert ray.is_initialized(), ( + "Ray is not initialized although distributed-executor-backend is ray.") + pg_data = placement_group_table(placement_group) + # bundle_idx -> node_id + bundle_to_node_ids = pg_data["bundles_to_node_id"] + # bundle_idx -> bundle (e.g., {"GPU": 1}) + bundles = pg_data["bundles"] + # node_id -> List of bundle (e.g., {"GPU": 1}) + node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + + for bundle_idx, node_id in bundle_to_node_ids.items(): + node_id_to_bundle[node_id].append(bundles[bundle_idx]) + driver_node_id = ray.get_runtime_context().get_node_id() + + if driver_node_id not in node_id_to_bundle: + raise RuntimeError( + f"driver node id {driver_node_id} is not included in a placement " + f"group {placement_group.id}. Node id -> bundles " + f"{node_id_to_bundle}. " + "You don't have enough GPUs available in a current node. Check " + "`ray status` to see if you have available GPUs in a node " + f"{driver_node_id} before starting an vLLM engine.") + + for node_id, bundles in node_id_to_bundle.items(): + if len(bundles) < parallel_config.tensor_parallel_size: + logger.warning( + "tensor_parallel_size=%d " + "is bigger than a reserved number of %ss (%d " + "%ss) in a node %s. Tensor parallel workers can be " + "spread out to 2+ nodes which can degrade the performance " + "unless you have fast interconnect across nodes, like " + "Infiniband. To resolve this issue, make sure you have more " + "than %d GPUs available at each node.", + parallel_config.tensor_parallel_size, device_str, len(bundles), + device_str, node_id, parallel_config.tensor_parallel_size) + + +def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): + """Wait until a placement group is ready. 
+ + It prints the informative log messages if the placement group is + not created within time. + + """ + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + placement_group_specs = current_placement_group.bundle_specs + + s = time.time() + pg_ready_ref = current_placement_group.ready() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval) + if len(ready) > 0: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for creating a placement group of specs for " + "%d seconds. specs=%s. Check " + "`ray status` to see if you have enough resources.", + int(time.time() - s), placement_group_specs) + + try: + ray.get(pg_ready_ref, timeout=0) + except ray.exceptions.GetTimeoutError: + raise ValueError( + "Cannot provide a placement group of " + f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " + "`ray status` to make sure the cluster has enough resources." + ) from None + + +def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): + ray.util.remove_placement_group(current_placement_group) + s = time.time() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + pg = ray.util.get_current_placement_group() + if pg is None: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for removing a placement group of specs for " + "%d seconds.", int(time.time() - s)) + time.sleep(wait_interval) + + +def initialize_ray_cluster( + parallel_config: ParallelConfig, + ray_address: Optional[str] = None, +): + """Initialize the distributed cluster with Ray. + + it will connect to the Ray cluster and create a placement group + for the workers, which includes the specification of the resources + for each distributed worker. + + Args: + parallel_config: The configurations for parallel execution. + ray_address: The address of the Ray cluster. If None, uses + the default Ray cluster address. + """ + assert_ray_available() + + # Connect to a ray cluster. + if is_hip() or is_xpu(): + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + import torch + device_count = torch.cuda.device_count() + if device_count >= parallel_config.world_size: + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + # For multi-node case + ray.init(address=ray_address, + ignore_reinit_error=True) + + if parallel_config.placement_group: + # Placement group is already set. + return + + device_str = "GPU" if not current_platform.is_tpu() else "TPU" + # Create placement group for worker processes + current_placement_group = ray.util.get_current_placement_group() + if current_placement_group: + # We are in a placement group + bundles = current_placement_group.bundle_specs + # Verify that we can use the placement group. + device_bundles = 0 + for bundle in bundles: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: + raise ValueError( + "Placement group bundle cannot have more than 1 " + f"{device_str}.") + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group." + f"Required number of devices: {parallel_config.world_size}. 
" + f"Total number of devices: {device_bundles}.") + else: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if parallel_config.world_size > num_devices_in_cluster: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") + # Create a new placement group + placement_group_specs: List[Dict[str, float]] = ([{ + device_str: 1.0 + } for _ in range(parallel_config.world_size)]) + + # vLLM engine is also a worker to execute model with an accelerator, + # so it requires to have the device in a current node. Check if + # the current node has at least one device. + current_ip = get_ip() + current_node_id = ray.get_runtime_context().get_node_id() + current_node_resource = available_resources_per_node()[current_node_id] + if current_node_resource.get(device_str, 0) < 1: + raise ValueError( + f"Current node has no {device_str} available. " + f"{current_node_resource=}. vLLM engine cannot start without " + f"{device_str}. Make sure you have at least 1 {device_str} " + f"available in a node {current_node_id=} {current_ip=}.") + # This way, at least bundle is required to be created in a current + # node. + placement_group_specs[0][f"node:{current_ip}"] = 0.001 + + # By default, Ray packs resources as much as possible. + current_placement_group = ray.util.placement_group( + placement_group_specs, strategy="PACK") + _wait_until_pg_ready(current_placement_group) + + assert current_placement_group is not None + _verify_bundles(current_placement_group, parallel_config, device_str) + # Set the placement group in the parallel config + parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py new file mode 100644 index 0000000..2b1cdc0 --- /dev/null +++ b/vllm/executor/ray_xpu_executor.py @@ -0,0 +1,37 @@ +import asyncio +from typing import List, Optional + +import vllm.envs as envs +from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync +from vllm.executor.xpu_executor import XPUExecutor +from vllm.logger import init_logger +from vllm.utils import get_vllm_instance_id, make_async + +logger = init_logger(__name__) + + +class RayXPUExecutor(RayGPUExecutor, XPUExecutor): + + def _get_env_vars_to_be_updated(self): + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. 
+ all_args_to_update_environment_variables = [({ + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + }, ) for (_, _) in worker_node_and_gpu_ids] + return all_args_to_update_environment_variables + + +class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_method = make_async(self.driver_worker.execute_method) + self.pp_locks: Optional[List[asyncio.Lock]] = None diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py new file mode 100644 index 0000000..972649d --- /dev/null +++ b/vllm/executor/tpu_executor.py @@ -0,0 +1,147 @@ +from typing import Any, Dict, List, Optional, Set, Tuple + +import torch + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) + +logger = init_logger(__name__) + + +class TPUExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert not self.scheduler_config.chunked_prefill_enabled, ( + "Chunked prefill is not yet supported for TPU backend") + assert not self.speculative_config, ( + "Speculative decoding is not yet supported for TPU backend") + if self.model_config.dtype in (torch.float16, torch.float32): + logger.warning( + "The TPU backend currently does not support %s. " + "Using bfloat16 instead.", self.model_config.dtype) + self.model_config.dtype = torch.bfloat16 + + # Instantiate the worker and load the model to the device. + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None, + ) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + load_config=self.load_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=rank == 0, + ) + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None, + ): + if self.scheduler_config.is_multi_step: + from vllm.worker.multi_step_tpu_worker import MultiStepTPUWorker + worker = MultiStepTPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker + else: + from vllm.worker.tpu_worker import TPUWorker + + worker = TPUWorker(**self._get_worker_kwargs( + local_rank, rank, distributed_init_method)) + return worker + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int, + ) -> None: + """Initialize the KV cache by invoking the underlying worker.""" + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. 
+ logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker.""" + return self.driver_worker.determine_num_available_blocks() + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") + + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") + + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") + + def list_loras(self) -> Set[int]: + raise NotImplementedError( + "LoRA is currently not supported by the TPU backend.") + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the TPU backend.") + + def check_health(self) -> None: + # TPUExecutor will always be healthy as long as it's running. 
+ return + + +class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + sexecute_model_req: ExecuteModelRequest, + ) -> SamplerOutput: + output = await make_async(self.driver_worker.execute_model + )(sexecute_model_req) + return output diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py new file mode 100644 index 0000000..bada560 --- /dev/null +++ b/vllm/executor/xpu_executor.py @@ -0,0 +1,96 @@ +from typing import Callable, List, Optional, Tuple, Type, Union + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput +from vllm.utils import make_async +from vllm.worker.worker_base import WorkerBase + +logger = init_logger(__name__) + + +class XPUExecutor(GPUExecutor): + + uses_ray: bool = False + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], + speculative_config: Optional[SpeculativeConfig], + observability_config: Optional[ObservabilityConfig], + ) -> None: + assert device_config.device_type == "xpu" + assert (not speculative_config + ), "Speculative decoding not yet supported for XPU backend" + + model_config = _verify_and_get_model_config(model_config) + + self.model_config = model_config + self.cache_config = cache_config + self.load_config = load_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.prompt_adapter_config = prompt_adapter_config + self.speculative_config = None + self.observability_config = observability_config + + # Instantiate the worker and load the model to GPU. 
+ self._init_executor() + + def _get_worker_module_and_class( + self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: + worker_class_fn = None + if self.speculative_config is not None: + raise NotImplementedError( + "XPU does not support speculative decoding") + else: + worker_module_name = "vllm.worker.xpu_worker" + worker_class_name = "XPUWorker" + return (worker_module_name, worker_class_name, worker_class_fn) + + def execute_model( + self, execute_model_req: ExecuteModelRequest + ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: + output = self.driver_worker.execute_model(execute_model_req) + return output + + +class XPUExecutorAsync(XPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async(self.driver_worker.execute_model + )(execute_model_req=execute_model_req) + return output + + +def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: + if config.dtype == torch.bfloat16: + logger.warning( + "bfloat16 is not fully supported on XPU, casting to float16.") + config.dtype = torch.float16 + if not config.enforce_eager: + logger.warning( + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + config.enforce_eager = True + return config diff --git a/vllm/forward_context.py b/vllm/forward_context.py new file mode 100644 index 0000000..7777475 --- /dev/null +++ b/vllm/forward_context.py @@ -0,0 +1,22 @@ +from contextlib import contextmanager +from typing import Any + +_forward_context: Any = None + + +def get_forward_context() -> Any: + """Get the current forward context.""" + return _forward_context + + +@contextmanager +def set_forward_context(context: Any): + """A context manager that stores the current forward context, + can be attention metadata, etc.""" + global _forward_context + prev_context = _forward_context + _forward_context = context + try: + yield + finally: + _forward_context = prev_context diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py new file mode 100644 index 0000000..a8c8672 --- /dev/null +++ b/vllm/inputs/__init__.py @@ -0,0 +1,44 @@ +from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt, build_explicit_enc_dec_prompt, + to_enc_dec_tuple_list, zip_enc_dec_prompts) +from .registry import InputContext, InputRegistry + +INPUT_REGISTRY = InputRegistry() +""" +The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine` +to dispatch data processing according to the target model. + +See also: + :ref:`input_processing_pipeline` +""" + +__all__ = [ + "TextPrompt", + "TokensPrompt", + "PromptType", + "SingletonPrompt", + "ExplicitEncoderDecoderPrompt", + "LLMInputs", + "EncoderDecoderLLMInputs", + "build_explicit_enc_dec_prompt", + "to_enc_dec_tuple_list", + "zip_enc_dec_prompts", + "INPUT_REGISTRY", + "InputContext", + "InputRegistry", +] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. 
" + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/__pycache__/__init__.cpython-310.pyc b/vllm/inputs/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..565b6e4 Binary files /dev/null and b/vllm/inputs/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/inputs/__pycache__/data.cpython-310.pyc b/vllm/inputs/__pycache__/data.cpython-310.pyc new file mode 100644 index 0000000..b69bc51 Binary files /dev/null and b/vllm/inputs/__pycache__/data.cpython-310.pyc differ diff --git a/vllm/inputs/__pycache__/parse.cpython-310.pyc b/vllm/inputs/__pycache__/parse.cpython-310.pyc new file mode 100644 index 0000000..f11cc04 Binary files /dev/null and b/vllm/inputs/__pycache__/parse.cpython-310.pyc differ diff --git a/vllm/inputs/__pycache__/preprocess.cpython-310.pyc b/vllm/inputs/__pycache__/preprocess.cpython-310.pyc new file mode 100644 index 0000000..1857735 Binary files /dev/null and b/vllm/inputs/__pycache__/preprocess.cpython-310.pyc differ diff --git a/vllm/inputs/__pycache__/registry.cpython-310.pyc b/vllm/inputs/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000..c03e8fd Binary files /dev/null and b/vllm/inputs/__pycache__/registry.cpython-310.pyc differ diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py new file mode 100644 index 0000000..724cdd2 --- /dev/null +++ b/vllm/inputs/data.py @@ -0,0 +1,242 @@ +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Tuple, Union) + +from typing_extensions import NotRequired, TypedDict, TypeVar + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + + +class TextPrompt(TypedDict): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[Dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + multi_modal_data: NotRequired["MultiModalDataDict"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[Dict[str, Any]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + +SingletonPrompt = Union[str, TextPrompt, TokensPrompt] +""" +Set of possible schemas for a single LLM input: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) + +Note that "singleton" is as opposed to a data structure +which encapsulates multiple prompts, i.e. of the sort +which may be utilized for encoder/decoder models when +the user desires to express both the encoder & decoder +prompts explicitly, i.e. 
:class:`ExplicitEncoderDecoderPrompt` + +A prompt of type :class:`SingletonPrompt` may be employed +as (1) input to a decoder-only model, (2) input to +the encoder of an encoder/decoder model, in the scenario +where the decoder-prompt is not specified explicitly, or +(3) as a member of a larger data structure encapsulating +more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt` +""" + +_T1_co = TypeVar("_T1_co", + bound=SingletonPrompt, + default=SingletonPrompt, + covariant=True) +_T2_co = TypeVar("_T2_co", + bound=SingletonPrompt, + default=SingletonPrompt, + covariant=True) + + +# TODO: Make fields ReadOnly once mypy supports it +class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): + """ + Represents an encoder/decoder model input prompt, + comprising an explicit encoder prompt and a decoder prompt. + + The encoder and decoder prompts, respectively, may be formatted + according to any of the :class:`SingletonPrompt` schemas, + and are not required to have the same schema. + + Only the encoder prompt may have multi-modal data. mm_processor_kwargs + should be at the top-level, and should not be set in the encoder/decoder + prompts, since they are agnostic to the encoder/decoder. + + Note that an :class:`ExplicitEncoderDecoderPrompt` may not + be used as an input to a decoder-only model, + and that the :code:`encoder_prompt` and :code:`decoder_prompt` + fields of this data structure themselves must be + :class:`SingletonPrompt` instances. + """ + + encoder_prompt: _T1_co + + decoder_prompt: Optional[_T2_co] + + mm_processor_kwargs: NotRequired[Dict[str, Any]] + + +PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt] +""" +Set of possible schemas for an LLM input, including +both decoder-only and encoder/decoder input types: + +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +- A single data structure containing both an encoder and a decoder prompt + (:class:`ExplicitEncoderDecoderPrompt`) +""" + + +class LLMInputs(TypedDict): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + + This specifies the data required for decoder-only models. + """ + prompt_token_ids: List[int] + """The token IDs of the prompt.""" + + prompt: NotRequired[Optional[str]] + """ + The original prompt text corresponding to the token IDs, if available. + """ + + multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + mm_processor_kwargs: NotRequired[Optional[Dict[str, Any]]] + """ + Optional multi-modal processor kwargs to be forwarded to the + multimodal input mapper & processor. Note that if multiple modalities + have registered mappers etc for the model being considered, we attempt + to pass the mm_processor_kwargs to each of them. + """ + + +class EncoderDecoderLLMInputs(LLMInputs): + """ + The inputs in :class:`~vllm.LLMEngine` before they are + passed to the model executor. + + This specifies the required data for encoder-decoder models. + """ + encoder_prompt_token_ids: List[int] + """The token IDs of the encoder prompt.""" + + encoder_prompt: NotRequired[Optional[str]] + """ + The original encoder prompt text corresponding to the token IDs, if + available. + """ + + encoder_multi_modal_data: NotRequired[Optional["MultiModalDataDict"]] + """ + Optional multi-modal data to pass to the encoder model, + if the model supports it. 
+ """ + + +_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt) +_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt) + + +def build_explicit_enc_dec_prompt( + encoder_prompt: _T1, + decoder_prompt: Optional[_T2], + mm_processor_kwargs: Optional[Dict[str, Any]] = None, +) -> ExplicitEncoderDecoderPrompt[_T1, _T2]: + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return ExplicitEncoderDecoderPrompt( + encoder_prompt=encoder_prompt, + decoder_prompt=decoder_prompt, + mm_processor_kwargs=mm_processor_kwargs) + + +def zip_enc_dec_prompts( + enc_prompts: Iterable[_T1], + dec_prompts: Iterable[Optional[_T2]], + mm_processor_kwargs: Optional[Union[Iterable[Dict[str, Any]], + Dict[str, Any]]] = None, +) -> List[ExplicitEncoderDecoderPrompt[_T1, _T2]]: + """ + Zip encoder and decoder prompts together into a list of + :class:`ExplicitEncoderDecoderPrompt` instances. mm_processor_kwargs + may also be provided; if a dict is passed, the same dictionary will be + used for every encoder/decoder prompt. If an iterable is provided, it will + be zipped with the encoder/decoder prompts. + """ + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + if isinstance(mm_processor_kwargs, Dict): + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_processor_kwargs) + for (encoder_prompt, + decoder_prompt) in zip(enc_prompts, dec_prompts) + ] + return [ + build_explicit_enc_dec_prompt(encoder_prompt, decoder_prompt, + mm_proc_kwargs) + for (encoder_prompt, decoder_prompt, mm_proc_kwargs + ) in zip(enc_prompts, dec_prompts, mm_processor_kwargs) + ] + + +def to_enc_dec_tuple_list( + enc_dec_prompts: Iterable[ExplicitEncoderDecoderPrompt[_T1, _T2]], +) -> List[Tuple[_T1, Optional[_T2]]]: + return [(enc_dec_prompt["encoder_prompt"], + enc_dec_prompt["decoder_prompt"]) + for enc_dec_prompt in enc_dec_prompts] + + +def __getattr__(name: str): + if name == "PromptInput": + import warnings + + msg = ("PromptInput has been renamed to PromptType. " + "The original name will be removed in an upcoming version.") + + warnings.warn(DeprecationWarning(msg), stacklevel=2) + + return PromptType + + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/parse.py b/vllm/inputs/parse.py new file mode 100644 index 0000000..e5fa1e4 --- /dev/null +++ b/vllm/inputs/parse.py @@ -0,0 +1,106 @@ +from typing import List, Literal, Sequence, TypedDict, Union, overload + +from typing_extensions import TypeIs + +from vllm.utils import is_list_of + +from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, + LLMInputs, PromptType, SingletonPrompt, TextPrompt, + TokensPrompt) + + +class ParsedText(TypedDict): + content: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + content: List[int] + is_tokens: Literal[True] + + +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: + ... + + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... 
+ + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedText(content=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt, str): + # case 2: array of strings + return [ + ParsedText(content=elem, is_tokens=False) for elem in prompt + ] + if is_list_of(prompt, int): + # case 3: array of tokens + return [ParsedTokens(content=prompt, is_tokens=True)] + if is_list_of(prompt, list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if is_list_of(prompt[0], int): + # case 4: array of token arrays + return [ + ParsedTokens(content=elem, is_tokens=True) + for elem in prompt + ] + + raise TypeError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + +class ParsedStrPrompt(TypedDict): + type: Literal["str"] + content: str + + +class ParsedTextPrompt(TypedDict): + type: Literal["text"] + content: TextPrompt + + +class ParsedTokensPrompt(TypedDict): + type: Literal["tokens"] + content: TokensPrompt + + +def parse_singleton_prompt( + prompt: SingletonPrompt, +) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]: + if isinstance(prompt, str): + return ParsedStrPrompt(type="str", content=prompt) + elif isinstance(prompt, dict): + if "prompt_token_ids" in prompt: + return ParsedTokensPrompt(type="tokens", + content=prompt) # type: ignore + elif "prompt" in prompt: + return ParsedTextPrompt(type="text", content=prompt) + + raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt") + + +def is_explicit_encoder_decoder_prompt( + prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]: + return isinstance(prompt, dict) and "encoder_prompt" in prompt + + +def is_valid_encoder_decoder_llm_inputs( + inputs: Union[LLMInputs, EncoderDecoderLLMInputs], +) -> TypeIs[EncoderDecoderLLMInputs]: + return "encoder_prompt_token_ids" in inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py new file mode 100644 index 0000000..64387fd --- /dev/null +++ b/vllm/inputs/preprocess.py @@ -0,0 +1,580 @@ +import asyncio +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from typing_extensions import assert_never + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup +from vllm.utils import print_warning_once + +from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType, + SingletonPrompt) +from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt + +if TYPE_CHECKING: + from vllm.multimodal import MultiModalDataDict + +logger = init_logger(__name__) + +PromptComponents = Tuple[Optional[str], List[int], + Optional["MultiModalDataDict"], Optional[Dict[str, + Any]]] +DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], + Optional["MultiModalDataDict"], + Optional[Dict[str, Any]]] + + +class InputPreprocessor: + + def __init__( + self, + model_config: ModelConfig, + tokenizer: Optional[BaseTokenizerGroup], + ) -> None: + super().__init__() + + self.model_config = model_config + self.tokenizer = tokenizer + + def get_tokenizer_group(self) -> BaseTokenizerGroup: + if 
self.tokenizer is None:
+            raise ValueError("You cannot pass text prompts when "
+                             "`skip_tokenizer_init` is True")
+
+        return self.tokenizer
+
+    def get_bos_token_id(self,
+                         lora_request: Optional[LoRARequest] = None
+                         ) -> Optional[int]:
+        if self.tokenizer is None:
+            logger.warning("Using None for BOS token id because tokenizer "
+                           "is not initialized")
+            return None
+
+        return self.tokenizer.get_lora_tokenizer(lora_request).bos_token_id
+
+    def get_eos_token_id(self,
+                         lora_request: Optional[LoRARequest] = None
+                         ) -> Optional[int]:
+        if self.tokenizer is None:
+            logger.warning("Using None for EOS token id because tokenizer "
+                           "is not initialized")
+            return None
+
+        return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
+
+    def get_decoder_start_token_id(self) -> Optional[int]:
+        '''
+        Obtain the decoder start token id employed by an encoder/decoder
+        model. Returns None for non-encoder/decoder models or if the
+        model config is unavailable.
+        '''
+
+        if not self.is_encoder_decoder_model():
+            print_warning_once("Using None for decoder start token id because "
+                               "this is not an encoder/decoder model.")
+            return None
+
+        if (self.model_config is None or self.model_config.hf_config is None):
+            print_warning_once("Using None for decoder start token id because "
+                               "model config is not available.")
+            return None
+
+        dec_start_token_id = getattr(self.model_config.hf_config,
+                                     'decoder_start_token_id', None)
+        if dec_start_token_id is None:
+            print_warning_once("Falling back on <BOS> for decoder start token "
+                               "id because decoder start token id is not "
+                               "available.")
+            dec_start_token_id = self.get_bos_token_id()
+
+        return dec_start_token_id
+
+    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
+        '''
+        Specifically for encoder/decoder models:
+        generate a default decoder prompt for when
+        the user specifies only the encoder prompt.
+
+        Encoder/decoder models utilize the decoder
+        prompt in different ways; as new models are
+        added, it is intended that this function
+        will be extended to produce differing
+        default decoder prompts, depending on the
+        model variety.
+
+        Absent a special case, the default behavior
+        of this method is to mirror the behavior of
+        the HuggingFace (HF) GenerationMixin for a None
+        decoder prompt, which is to employ a logit processor
+        setting to force the first decoded token to be <BOS>.
+        Here, this behavior is approximated by having the
+        "default" decoder prompt be <BOS>.
+
+        However, it is possible that in the future
+        other models may have different or more
+        complex logic for the default decoder prompt.
+        This motivates having a special helper method
+        for default decoder prompts.
+
+        Returns:
+
+        * prompt_token_ids
+        '''
+
+        bos_token_id = self.get_bos_token_id()
+        assert bos_token_id is not None
+        return [bos_token_id]
+
+    def _prepare_decoder_input_ids_for_generation(
+        self,
+        decoder_input_ids: Optional[List[int]],
+        force_bos: bool = True,
+    ) -> List[int]:
+        """
+        Prepares `decoder_input_ids` for generation with encoder-decoder models.
+ + Based on + + https://github.com/huggingface/transformers/blob/ + 4037a2b5b1278736e566aec12e169100275545ea/ + src/transformers/generation/utils.py + + specifically GenerationMixin._prepare_decoder_input_ids_for_generation() + + Arguments: + + * decoder_input_ids: input token ids to preprocess + + Returns: + + * Processed token list + """ + + decoder_start_token_id = self.get_decoder_start_token_id() + assert decoder_start_token_id is not None + + if decoder_input_ids is None: + # no decoder prompt input -> + # use decoder_start_token_id as decoder_input_ids + decoder_input_ids = self._get_default_enc_dec_decoder_prompt() + + if force_bos and (len(decoder_input_ids) == 0 + or decoder_input_ids[0] != decoder_start_token_id): + decoder_input_ids = [decoder_start_token_id] + decoder_input_ids + + return decoder_input_ids + + def _apply_prompt_adapter( + self, + prompt_token_ids: List[int], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> List[int]: + if prompt_adapter_request: + prompt_token_ids = ( + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + + prompt_token_ids) + + return prompt_token_ids + + def _tokenize_prompt( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """ + Apply the model's tokenizer to a text prompt, returning the + corresponding token IDs. + """ + tokenizer = self.get_tokenizer_group() + + return tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + async def _tokenize_prompt_async( + self, + prompt: str, + request_id: str, + lora_request: Optional[LoRARequest], + ) -> List[int]: + """Async version of :meth:`_tokenize_prompt`.""" + tokenizer = self.get_tokenizer_group() + + return await tokenizer.encode_async(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + + def _extract_prompt_components( + self, + prompt: SingletonPrompt, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + ''' + Extract the components of any single encoder or decoder input prompt. 
+ + Arguments: + + * request_id + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + * mm_processor_kwargs (request-level input processor/mapper overrides) + ''' + + parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "str": + prompt_text = parsed["content"] + prompt_token_ids = self._tokenize_prompt( + prompt_text, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + mm_processor_kwargs = None + elif parsed["type"] == "tokens": + prompt_text = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + elif parsed["type"] == "text": + prompt_text = parsed["content"]["prompt"] + prompt_token_ids = self._tokenize_prompt( + prompt_text, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + else: + assert_never(parsed) + + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) + + async def _extract_prompt_components_async( + self, + prompt: SingletonPrompt, + request_id: str, + lora_request: Optional[LoRARequest] = None, + ) -> PromptComponents: + """Async version of :meth:`_extract_prompt_components`.""" + parsed = parse_singleton_prompt(prompt) + + if parsed["type"] == "str": + prompt_text = parsed["content"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = None + mm_processor_kwargs = None + elif parsed["type"] == "tokens": + prompt_text = None + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + elif parsed["type"] == "text": + prompt_text = parsed["content"]["prompt"] + prompt_token_ids = await self._tokenize_prompt_async( + prompt_text, + request_id=request_id, + lora_request=lora_request, + ) + multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + else: + assert_never(parsed) + + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) + + def _build_enc_dec_llm_inputs( + self, + encoder_comps: PromptComponents, + decoder_comps: DecoderPromptComponents, + mm_processor_kwargs: Dict[str, Any], + ) -> EncoderDecoderLLMInputs: + encoder_prompt, encoder_prompt_ids, encoder_mm_data, _ = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data, _ = decoder_comps + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + if decoder_mm_data is not None: + raise ValueError( + "Multi-modality decoder inputs of encoder-decoder models are " + "not supported yet") + + # For Multi-Modal models (e.g., mllama), the text input can be + # <|image|><|begin_of_text|>hello world. And we should not add + # another <|begin_of_text|> to the beginning. 
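+        # force_bos is therefore disabled whenever either prompt carries
+        # multi-modal data, so such models keep their own leading tokens.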
+ decoder_prompt_ids = (self._prepare_decoder_input_ids_for_generation( + decoder_prompt_ids, + force_bos=(encoder_mm_data is None and decoder_mm_data is None))) + + return EncoderDecoderLLMInputs( + prompt_token_ids=decoder_prompt_ids, + prompt=decoder_prompt, + multi_modal_data=decoder_mm_data, + mm_processor_kwargs=mm_processor_kwargs, + encoder_prompt_token_ids=encoder_prompt_ids, + encoder_prompt=encoder_prompt, + encoder_multi_modal_data=encoder_mm_data, + ) + + def _process_encoder_decoder_prompt( + self, + prompt: PromptType, + request_id: str, + ) -> EncoderDecoderLLMInputs: + ''' + For encoder/decoder models only: + Process an input prompt into an + :class:`EncoderDecoderLLMInputs` instance. + + There are two types of input prompts: + singleton prompts which carry only the + encoder prompt, and explicit encoder/decoder + prompts which carry both the encoder and the + decoder prompts as member variables. + + This function handles the following scenarios: + * Singleton encoder prompt: extract encoder prompt + token ids & infer default decoder prompt token ids + * Explicit encoder/decoder prompt: extract encoder + and decoder prompt token ids + + Note that for Explicit encoder/decoder prompts, + each sub-prompt (encoder or decoder prompt) can + have any possible singleton type; thus this + method relies on helper functions to obtain + token ids for the sub-prompts. + + Arguments: + + * prompt: an input prompt + * request_id + + Returns: + + * :class:`EncoderDecoderLLMInputs` instance + ''' + + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(prompt): + encoder_comps = self._extract_prompt_components( + prompt["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := prompt["decoder_prompt"]) is None: + decoder_comps = None, None, None, None + else: + decoder_comps = self._extract_prompt_components( + decoder_input, + request_id=request_id, + ) + # Handle this carefully in case it was directly initialized by user + mm_processor_kwargs = prompt.get("mm_processor_kwargs", {}) + else: + encoder_comps = self._extract_prompt_components( + prompt, + request_id=request_id, + ) + # If there are no decoder components, we assume the + # mm_processor_kwargs are in the encoder prompt + mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ + -1] is not None else {} + decoder_comps = None, None, None, None + + return self._build_enc_dec_llm_inputs( + encoder_comps, + decoder_comps, + mm_processor_kwargs, + ) + + async def _process_encoder_decoder_prompt_async( + self, + prompt: PromptType, + request_id: str, + ) -> EncoderDecoderLLMInputs: + """Async version of :meth:`_process_encoder_decoder_prompt`.""" + encoder_comps: PromptComponents + decoder_comps: DecoderPromptComponents + + if is_explicit_encoder_decoder_prompt(prompt): + encoder_task = self._extract_prompt_components_async( + prompt["encoder_prompt"], + request_id=request_id, + ) + + if (decoder_input := prompt["decoder_prompt"]) is None: + encoder_comps = await encoder_task + decoder_comps = None, None, None, None + else: + decoder_task = self._extract_prompt_components_async( + decoder_input, + request_id=request_id, + ) + + encoder_comps, decoder_comps = await asyncio.gather( + encoder_task, decoder_task) + mm_processor_kwargs = prompt["mm_processor_kwargs"] + else: + encoder_comps = await self._extract_prompt_components_async( + prompt, + request_id=request_id, + ) + # If there are no decoder components, we assume the + # mm_processor_kwargs are in 
the encoder prompt + mm_processor_kwargs = encoder_comps[-1] if encoder_comps[ + -1] is not None else {} + decoder_comps = None, None, None, None + + return self._build_enc_dec_llm_inputs( + encoder_comps, + decoder_comps, + mm_processor_kwargs, + ) + + def _build_decoder_only_llm_inputs( + self, + prompt_comps: PromptComponents, + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> LLMInputs: + (prompt, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) = prompt_comps + + prompt_token_ids = self._apply_prompt_adapter( + prompt_token_ids, prompt_adapter_request=prompt_adapter_request) + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=prompt, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs) + + def _process_decoder_only_prompt( + self, + prompt: SingletonPrompt, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + ''' + For decoder-only models: + Process an input prompt into an :class:`LLMInputs` instance. + + Arguments: + + * prompt: input prompt + * request_id + * lora_request + * prompt_adapter_request + + Returns: + + * :class:`LLMInputs` instance + ''' + + prompt_comps = self._extract_prompt_components( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + async def _process_decoder_only_prompt_async( + self, + prompt: SingletonPrompt, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> LLMInputs: + """Async version of :meth:`_process_decoder_only_prompt`.""" + prompt_comps = await self._extract_prompt_components_async( + prompt, + request_id=request_id, + lora_request=lora_request, + ) + + return self._build_decoder_only_llm_inputs( + prompt_comps, + prompt_adapter_request=prompt_adapter_request, + ) + + def preprocess( + self, + prompt: PromptType, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Preprocess the input prompt.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return self._process_encoder_decoder_prompt( + prompt, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(prompt): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return self._process_decoder_only_prompt( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + async def preprocess_async( + self, + prompt: PromptType, + request_id: str, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> Union[LLMInputs, EncoderDecoderLLMInputs]: + """Async version of :meth:`preprocess`.""" + if self.is_encoder_decoder_model(): + # Encoder-decoder model requires special mapping of + # input prompts to encoder & decoder + return await self._process_encoder_decoder_prompt_async( + prompt, + request_id=request_id, + ) + + if is_explicit_encoder_decoder_prompt(prompt): + raise ValueError("Cannot pass encoder-decoder prompt " + "to decoder-only models") + + # Decoder-only operation + return await 
self._process_decoder_only_prompt_async( + prompt, + request_id=request_id, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + def is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py new file mode 100644 index 0000000..5bd3e1c --- /dev/null +++ b/vllm/inputs/registry.py @@ -0,0 +1,314 @@ +import functools +from collections import UserDict +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, + Protocol, Tuple, Type) + +from torch import nn +from transformers import PretrainedConfig +from typing_extensions import TypeVar + +from vllm.logger import init_logger +from vllm.utils import (get_allowed_kwarg_only_overrides, print_warning_once, + resolve_mm_processor_kwargs) + +from .data import LLMInputs + +if TYPE_CHECKING: + from vllm.config import ModelConfig + from vllm.multimodal import MultiModalDataDict, MultiModalRegistry + from vllm.sequence import SequenceData + +logger = init_logger(__name__) + +C = TypeVar("C", bound=PretrainedConfig, default=PretrainedConfig) + + +@dataclass(frozen=True) +class InputContext: + """ + Contains information about the model which may be used to + modify the inputs. + """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + def get_hf_config(self, hf_config_type: Type[C] = PretrainedConfig) -> C: + """ + Get the HuggingFace configuration + (:class:`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + TypeError: If the model is not of the specified type. + """ + + hf_config = self.model_config.hf_config + if not isinstance(hf_config, hf_config_type): + raise TypeError("Invalid type of HuggingFace config. " + f"Expected type: {hf_config_type}, but " + f"found type: {type(hf_config)}") + + return hf_config + + def get_hf_image_processor_config(self) -> Dict[str, Any]: + """ + Get the HuggingFace image processor configuration of the model. + """ + + return self.model_config.hf_image_processor_config + + +N = TypeVar("N", bound=Type[nn.Module]) + + +class DummyDataFactory(Protocol): + + def __call__( + self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + **mm_processor_kwargs: Any, + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + """ + Create dummy data to be inputted into the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + + The :code:`mm_processor_kwargs` are overrides provided at + initialization time to values in the config whose values + may affect the number of tokens per instance. + """ + ... + + +class _MultiModalCounts(UserDict): + """ + Wraps `mm_counts` for a more informative error message + when attempting to access a plugin that does not exist. + """ + + def __getitem__(self, key: str) -> int: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"There is no multi-modal plugin with the key: {key}. " + f"Available keys: {set(self.keys())}") + raise KeyError(msg) from exc + + +InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs] +"""Preprocess the inputs to the model.""" + + +class InputRegistry: + """ + A registry to dispatch data processing + according to the target model. 
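+
+    A hypothetical usage sketch (model and processor names are illustrative):
+
+    .. code-block:: python
+
+        @INPUT_REGISTRY.register_input_processor(my_input_processor)
+        class MyModel(nn.Module):
+            ...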
+ """ + + def __init__(self) -> None: + self._dummy_factories_by_model_type: Dict[Type[nn.Module], + DummyDataFactory] = {} + self._dummy_encoder_factories_by_model_type: Dict[ + Type[nn.Module], DummyDataFactory] = {} + self._input_processors_by_model_type: Dict[Type[nn.Module], + InputProcessor] = {} + + def _default_dummy_data_factory( + self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + """ + The default dummy data factory represents the longest possible text + that can be inputted to the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + """ + # Avoid circular import + from vllm.sequence import SequenceData + + dummy_seq_data = SequenceData.from_token_counts((0, seq_len)) + dummy_multi_modal_data = None + + return dummy_seq_data, dummy_multi_modal_data + + def register_dummy_data(self, factory: DummyDataFactory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The resulting memory usage + should be an upper bound of what the model would use at inference time. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_factories_by_model_type: + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_data_factory(self, model_cls: Type[nn.Module]): + return self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + def register_dummy_encoder_data(self, factory: DummyDataFactory): + """ + Register a dummy encoder data factory to a model class + + This is similar to :meth:`~register_dummy_data`, but for encoder input. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_encoder_factories_by_model_type: + logger.warning( + "Model class %s already has dummy encoder data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_encoder_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def _get_dummy_encoder_data_factory(self, model_cls: Type[nn.Module]): + return self._dummy_encoder_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + def dummy_data_for_profiling( + self, + model_config: "ModelConfig", + seq_len: int, + mm_registry: "MultiModalRegistry", + is_encoder_data: bool = False, + ) -> Tuple["SequenceData", Optional["MultiModalDataDict"]]: + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. + + See also: + :ref:`enabling_multimodal_inputs` + + Note: + This should be called after + :meth:`~MultiModalRegistry.init_mm_limits_per_prompt`. 
+ """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + if is_encoder_data: + dummy_factory = self._get_dummy_encoder_data_factory(model_cls) + else: + dummy_factory = self._get_dummy_data_factory(model_cls) + mm_counts = mm_registry.get_mm_limits_per_prompt(model_config) + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + dummy_factory, overrides=model_config.mm_processor_kwargs) + + seq_data, mm_data = dummy_factory(InputContext(model_config), seq_len, + _MultiModalCounts(mm_counts), + **mm_processor_kwargs) + + # Having more tokens is over-conservative but otherwise fine + num_tokens = seq_data.prompt_token_ids + if len(num_tokens) < seq_len: + if is_encoder_data: + print_warning_once( + f"Expected at least {seq_len} dummy encoder tokens for " + f"profiling, but found {len(num_tokens)} tokens instead.") + else: + raise AssertionError( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but found {len(num_tokens)} tokens instead.") + if mm_data is not None: + for k, v in mm_data.items(): + num_items = len(v) if isinstance(v, list) else 1 + num_expected = mm_counts[k] + assert num_items >= num_expected, ( + f"Expected at least {num_expected} dummy '{k}' instances " + f"for profiling, but found {num_items} instances instead.") + + return seq_data, mm_data + + def _default_input_processor(self, ctx: InputContext, + inputs: LLMInputs) -> LLMInputs: + """The default input processor is a no-op.""" + return inputs + + def register_input_processor(self, processor: InputProcessor): + """ + Register an input processor to a model class. + + The provided function is invoked on each input to the model. This + happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`. + + See also: + :ref:`input_processing_pipeline` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_processors_by_model_type: + logger.warning( + "Model class %s already has input processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_processors_by_model_type[model_cls] = processor + + return model_cls + + return wrapper + + def _get_model_input_processor(self, model_cls: Type[nn.Module]): + return self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + + def process_input(self, model_config: "ModelConfig", + inputs: LLMInputs) -> LLMInputs: + """ + Apply an input processor to an instance of model inputs. + + The model is identified by ``model_config``. + + See also: + :ref:`input_processing_pipeline` + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + processor = self._get_model_input_processor(model_cls) + + # Handle multimodal processor kwargs with priority: + # Inference kwargs -> Init kwargs -> {} + # If it's empty, it'll fall back to the default kwarg values + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + inputs.get("mm_processor_kwargs"), + processor, + ) + + return processor(InputContext(model_config), inputs, + **mm_processor_kwargs) + + def create_input_processor(self, model_config: "ModelConfig"): + """ + Create an input processor (see :meth:`_process_input`) for a + specific model. 
+ """ + return functools.partial(self.process_input, model_config) diff --git a/vllm/logger.py b/vllm/logger.py new file mode 100644 index 0000000..77dddbf --- /dev/null +++ b/vllm/logger.py @@ -0,0 +1,155 @@ +"""Logging configuration for vLLM.""" +import datetime +import json +import logging +import os +import sys +from functools import partial +from logging import Logger +from logging.config import dictConfig +from os import path +from typing import Dict, Optional + +import vllm.envs as envs + +VLLM_CONFIGURE_LOGGING = envs.VLLM_CONFIGURE_LOGGING +VLLM_LOGGING_CONFIG_PATH = envs.VLLM_LOGGING_CONFIG_PATH +VLLM_LOGGING_LEVEL = envs.VLLM_LOGGING_LEVEL + +_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + +DEFAULT_LOGGING_CONFIG = { + "formatters": { + "vllm": { + "class": "vllm.logging.NewLineFormatter", + "datefmt": _DATE_FORMAT, + "format": _FORMAT, + }, + }, + "handlers": { + "vllm": { + "class": "logging.StreamHandler", + "formatter": "vllm", + "level": VLLM_LOGGING_LEVEL, + "stream": "ext://sys.stdout", + }, + }, + "loggers": { + "vllm": { + "handlers": ["vllm"], + "level": "DEBUG", + "propagate": False, + }, + }, + "version": 1, + "disable_existing_loggers": False +} + + +def _configure_vllm_root_logger() -> None: + logging_config: Optional[Dict] = None + + if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: + raise RuntimeError( + "VLLM_CONFIGURE_LOGGING evaluated to false, but " + "VLLM_LOGGING_CONFIG_PATH was given. VLLM_LOGGING_CONFIG_PATH " + "implies VLLM_CONFIGURE_LOGGING. Please enable " + "VLLM_CONFIGURE_LOGGING or unset VLLM_LOGGING_CONFIG_PATH.") + + if VLLM_CONFIGURE_LOGGING: + logging_config = DEFAULT_LOGGING_CONFIG + + if VLLM_LOGGING_CONFIG_PATH: + if not path.exists(VLLM_LOGGING_CONFIG_PATH): + raise RuntimeError( + "Could not load logging config. File does not exist: %s", + VLLM_LOGGING_CONFIG_PATH) + with open(VLLM_LOGGING_CONFIG_PATH, encoding="utf-8", + mode="r") as file: + custom_config = json.loads(file.read()) + + if not isinstance(custom_config, dict): + raise ValueError("Invalid logging config. Expected Dict, got %s.", + type(custom_config).__name__) + logging_config = custom_config + + if logging_config: + dictConfig(logging_config) + + +def init_logger(name: str) -> Logger: + """The main purpose of this function is to ensure that loggers are + retrieved in such a way that we can be sure the root vllm logger has + already been configured.""" + + return logging.getLogger(name) + + +# The root logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. 
+_configure_vllm_root_logger() + +logger = init_logger(__name__) + + +def _trace_calls(log_path, root_dir, frame, event, arg=None): + if event in ['call', 'return']: + # Extract the filename, line number, function name, and the code object + filename = frame.f_code.co_filename + lineno = frame.f_lineno + func_name = frame.f_code.co_name + if not filename.startswith(root_dir): + # only log the functions in the vllm root_dir + return + # Log every function call or return + try: + last_frame = frame.f_back + if last_frame is not None: + last_filename = last_frame.f_code.co_filename + last_lineno = last_frame.f_lineno + last_func_name = last_frame.f_code.co_name + else: + # initial frame + last_filename = "" + last_lineno = 0 + last_func_name = "" + with open(log_path, 'a') as f: + if event == 'call': + f.write(f"{datetime.datetime.now()} Call to" + f" {func_name} in {filename}:{lineno}" + f" from {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + else: + f.write(f"{datetime.datetime.now()} Return from" + f" {func_name} in {filename}:{lineno}" + f" to {last_func_name} in {last_filename}:" + f"{last_lineno}\n") + except NameError: + # modules are deleted during shutdown + pass + return partial(_trace_calls, log_path, root_dir) + + +def enable_trace_function_call(log_file_path: str, + root_dir: Optional[str] = None): + """ + Enable tracing of every function call in code under `root_dir`. + This is useful for debugging hangs or crashes. + `log_file_path` is the path to the log file. + `root_dir` is the root directory of the code to trace. If None, it is the + vllm root directory. + + Note that this call is thread-level, any threads calling this function + will have the trace enabled. Other threads will not be affected. + """ + logger.warning( + "VLLM_TRACE_FUNCTION is enabled. It will record every" + " function executed by Python. This will slow down the code. 
It " + "is suggested to be used for debugging hang or crashes only.") + logger.info("Trace frame log is saved to %s", log_file_path) + if root_dir is None: + # by default, this is the vllm root directory + root_dir = os.path.dirname(os.path.dirname(__file__)) + sys.settrace(partial(_trace_calls, log_file_path, root_dir)) diff --git a/vllm/logging/__init__.py b/vllm/logging/__init__.py new file mode 100644 index 0000000..b9aec38 --- /dev/null +++ b/vllm/logging/__init__.py @@ -0,0 +1,5 @@ +from vllm.logging.formatter import NewLineFormatter + +__all__ = [ + "NewLineFormatter", +] diff --git a/vllm/logging/__pycache__/__init__.cpython-310.pyc b/vllm/logging/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..f37e9d2 Binary files /dev/null and b/vllm/logging/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/logging/__pycache__/formatter.cpython-310.pyc b/vllm/logging/__pycache__/formatter.cpython-310.pyc new file mode 100644 index 0000000..afb951e Binary files /dev/null and b/vllm/logging/__pycache__/formatter.cpython-310.pyc differ diff --git a/vllm/logging/formatter.py b/vllm/logging/formatter.py new file mode 100644 index 0000000..b24b4e1 --- /dev/null +++ b/vllm/logging/formatter.py @@ -0,0 +1,15 @@ +import logging + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, style="%"): + logging.Formatter.__init__(self, fmt, datefmt, style) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/__pycache__/__init__.cpython-310.pyc b/vllm/lora/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..73ac80f Binary files /dev/null and b/vllm/lora/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/fully_sharded_layers.cpython-310.pyc b/vllm/lora/__pycache__/fully_sharded_layers.cpython-310.pyc new file mode 100644 index 0000000..652738e Binary files /dev/null and b/vllm/lora/__pycache__/fully_sharded_layers.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/layers.cpython-310.pyc b/vllm/lora/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000..bcc2b16 Binary files /dev/null and b/vllm/lora/__pycache__/layers.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/lora.cpython-310.pyc b/vllm/lora/__pycache__/lora.cpython-310.pyc new file mode 100644 index 0000000..7a5f001 Binary files /dev/null and b/vllm/lora/__pycache__/lora.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/models.cpython-310.pyc b/vllm/lora/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000..0b47ee7 Binary files /dev/null and b/vllm/lora/__pycache__/models.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/punica.cpython-310.pyc b/vllm/lora/__pycache__/punica.cpython-310.pyc new file mode 100644 index 0000000..1b72815 Binary files /dev/null and b/vllm/lora/__pycache__/punica.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/request.cpython-310.pyc b/vllm/lora/__pycache__/request.cpython-310.pyc new file mode 100644 index 0000000..fba3fa2 Binary files /dev/null and b/vllm/lora/__pycache__/request.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/utils.cpython-310.pyc 
b/vllm/lora/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..cf125f8 Binary files /dev/null and b/vllm/lora/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/lora/__pycache__/worker_manager.cpython-310.pyc b/vllm/lora/__pycache__/worker_manager.cpython-310.pyc new file mode 100644 index 0000000..f99a173 Binary files /dev/null and b/vllm/lora/__pycache__/worker_manager.cpython-310.pyc differ diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py new file mode 100644 index 0000000..a7887a0 --- /dev/null +++ b/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,343 @@ +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA) + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs["lora_config"].fully_sharded_loras) + + return dec + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. 
+ """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device, + ) + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + self.punica_wrapper.add_expand(output, + buffer, + self.lora_b_stacked, + add_input=True) + # now have column partitioned output + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): + """ + MergedColumnParallelLinearWithShardedLoRA and + MergedQKVParallelLinearWithShardedLora share the same + LoRa weight application method. + + The main difference is the step by shard_size for lora_b which can + vary for MergedQKVParallelLinearWithShardedLora but is constant for + MergedColumnParallelLinearWithShardedLoRA. + """ + # expecting 2 for column parallel and 3 for qkv + n = len(layer.lora_a_stacked) + output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffers = torch.zeros( + (n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device, + ) + for idx in range(n): + layer.punica_wrapper.add_shrink(buffers[idx], x, + layer.lora_a_stacked[idx], 1.0) + + buffers = tensor_model_parallel_all_gather(buffers) + left_offset = 0 + for idx in range(n): + shard_size = layer.lora_b_stacked[idx].shape[2] + layer.punica_wrapper.add_expand_slice( + output, + buffers[idx], + layer.lora_b_stacked[idx], + left_offset, + shard_size, + add_input=True, + ) + left_offset += shard_size + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. 
+ """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None: + return lora_a + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[0][:, + output_start_idx:output_start_idx + output_shard_size], + lora_a[1][:, + output_start_idx:output_start_idx + output_shard_size], + ] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + self.punica_wrapper.add_expand(output, + buffer, + self.lora_b_stacked, + add_input=True) + # now have column partitioned output + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from MergedQKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. 
+ """ + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None: + return lora_a + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]], + lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]], + lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]], + ] + return lora_a + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. + This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros( + (x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device, + ) + + self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. 
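# --------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of the fusion described in the
# comment above: each rank all-reduces the shrink buffer, adds its
# output-sliced LoRA B expansion only into its own slice of the partial-sum
# row-parallel output, and the single remaining all_reduce then completes
# both the base matmul and the assembled LoRA delta. The collectives are
# simulated with Python sums; all names are local to the sketch.
import torch

def row_parallel_sharded_lora_reference(x, w, lora_a, lora_b, tp_size):
    # x: (tokens, hidden); w: (hidden, out); lora_a: (hidden, rank); lora_b: (rank, out)
    hidden, out = w.shape
    in_shard, out_shard = hidden // tp_size, out // tp_size
    partials, buffers = [], []
    for r in range(tp_size):
        x_r = x[:, r * in_shard:(r + 1) * in_shard]
        w_r = w[r * in_shard:(r + 1) * in_shard]
        a_r = lora_a[r * in_shard:(r + 1) * in_shard]      # input-sliced A
        partials.append(x_r @ w_r)                          # partial base sum
        buffers.append(x_r @ a_r)                           # local shrink
    buffer = sum(buffers)         # all_reduce recovers x @ lora_a everywhere
    outputs = []
    for r in range(tp_size):
        local = partials[r].clone()
        b_r = lora_b[:, r * out_shard:(r + 1) * out_shard]  # output-sliced B
        local[:, r * out_shard:(r + 1) * out_shard] += buffer @ b_r
        outputs.append(local)
    return sum(outputs)           # the one remaining all_reduce

x, w = torch.randn(4, 16), torch.randn(16, 32)
a, b = torch.randn(16, 8), torch.randn(8, 32)
assert torch.allclose(row_parallel_sharded_lora_reference(x, w, a, b, 2),
                      x @ w + x @ a @ b, atol=1e-4)
# --------------------------------------------------------------------------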
User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + self.punica_wrapper.add_expand_slice(output, buffer, + self.lora_b_stacked, start_idx, + shard_size) + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000..6254c67 --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,1312 @@ +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import LoRAConfig +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather) +from vllm.distributed.utils import divide +from vllm.lora.punica import PunicaWrapper +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.rotary_embedding import ( + LinearScalingRotaryEmbedding, RotaryEmbedding) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +if TYPE_CHECKING: + pass + + +def _get_lora_device(base_layer: nn.Module) -> torch.device: + # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34 + """Returns the device for where to place the LoRA tensors.""" + # unquantizedLinear + if hasattr(base_layer, "weight"): + return base_layer.weight.device + # Compressed Tensor + elif hasattr(base_layer, "weight_packed"): + return base_layer.weight_packed.device + # GPTQ/AWQ + elif hasattr(base_layer, "qweight"): + return base_layer.qweight.device + # marlin + elif hasattr(base_layer, "B"): + return base_layer.B.device + else: + raise ValueError(f"Unsupported base layer: {base_layer}") + + +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop("decorate") if "decorate" in kwargs else True + condition = (not kwargs["lora_config"].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + +@dataclass +class LoRAMapping(AdapterMapping): + is_prefill: bool = False + + +class BaseLayerWithLoRA(nn.Module): + + def slice_lora_a( + self, lora_a: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora a if 
splitting for tensor parallelism.""" + ... + + def slice_lora_b( + self, lora_b: Union[torch.Tensor, List[Union[torch.Tensor, None]]] + ) -> Union[torch.Tensor, List[Union[torch.Tensor, None]]]: + """Slice lora b if splitting with tensor parallelism.""" + ... + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + punica_wrapper: PunicaWrapper, + ): + self.punica_wrapper: PunicaWrapper = punica_wrapper + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + raise NotImplementedError + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.embeddings_slice: Optional[Tuple[int, int]] + self.embeddings_weights: Optional[torch.Tensor] + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + if self.base_layer.num_added_embeddings_per_partition > 0: + # We can start adding lora weights + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, 
non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1], ].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2], + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + assert self.embeddings_weights is not None + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + embeddings_indices = self.punica_wrapper.embeddings_indices + indices = embeddings_indices[1].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = embeddings_indices[0].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], + -1, + ) + + # Embedding layer only need expand op + self.punica_wrapper.add_expand(full_output, + full_lora_a_embeddings, + self.lora_b_stacked, + add_input=True) + return full_output.view_as(full_output_org) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is VocabParallelEmbedding + + +class ReplicatedLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: ReplicatedLinear) -> None: + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + lora_a_output_size = lora_config.max_lora_rank + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_a_output_size, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + return output + + def forward(self, input_): + """Forward of ReplicatedLinearWithLoRA + + Args: + 
input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output = self.apply(input_, bias) + + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ReplicatedLinear + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + + LoRA B is sliced for tensor parallelism. + """ + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size = self.base_layer.input_size + self.output_size = self.base_layer.output_size_per_partition + self.device = _get_lora_device(self.base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.output_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.output_dim = self.lora_b_stacked.shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + return output + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. 
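# --------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of what the punica_wrapper.add_lora
# call in apply() above computes on the stacked weights: every token selects
# its adapter slot and receives x @ A_slot.T @ B_slot.T on top of the base
# output (the slot layout follows set_lora(), which stores A and B
# transposed). The per-token slot tensor is an assumption of the sketch; the
# real kernel takes it from the punica metadata.
import torch

def add_lora_reference(output, x, lora_a_stacked, lora_b_stacked,
                       token_lora_slots, scale=1.0):
    # lora_a_stacked: (max_loras, 1, rank, in); lora_b_stacked: (max_loras, 1, out, rank)
    for slot in token_lora_slots.unique():
        if slot < 0:          # -1 means "no adapter" for this token
            continue
        mask = token_lora_slots == slot
        a = lora_a_stacked[slot, 0]                     # (rank, in)
        b = lora_b_stacked[slot, 0]                     # (out, rank)
        output[mask] += scale * (x[mask] @ a.T) @ b.T
    return output

out, x = torch.zeros(3, 6), torch.randn(3, 4)
a = torch.randn(2, 1, 2, 4)   # 2 slots, rank 2, in 4
b = torch.randn(2, 1, 6, 2)   # 2 slots, out 6, rank 2
add_lora_reference(out, x, a, b, torch.tensor([0, -1, 1]))
# --------------------------------------------------------------------------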
+ output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is ColumnParallelLinear or ( + type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 1) + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.output_size // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ) for _ in range(n_slices)) + + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + if lora_b[0] is None or lora_b[1] is None: + return lora_b + shard_size = self.output_dim + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = [ + lora_b[0][:, start_idx:end_idx], + lora_b[1][:, start_idx:end_idx], + ] + return lora_b + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + 
self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora_packed_nslice( + output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, + (self.output_dim, self.output_dim)) + return output + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is MergedColumnParallelLinear + and len(packed_modules_list) == 2) + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ + ColumnParallelLinear layer that is specifically designed for + qkv_proj. Certain models, such as chtglm3 and baichuan-7b, + only contains a single LoRA within their qkv_proj layer. + + During inference with Tensor Parallel, the weights of lora_b + must be accurately partitioned according to the respective ranks. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + self.tp_size = get_tensor_model_parallel_world_size() + self.q_proj_total_size = (self.base_layer.total_num_heads * + self.base_layer.head_size) + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.kv_proj_total_size = (self.base_layer.total_num_kv_heads * + self.base_layer.head_size) + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + lora_b_q = lora_b[:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + k_offset = self.q_proj_total_size + lora_b_k = lora_b[:, k_offset + + self.kv_proj_shard_size * self.kv_shard_id:k_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + v_offset = k_offset + self.kv_proj_total_size + lora_b_v = lora_b[:, v_offset + + self.kv_proj_shard_size * self.kv_shard_id:v_offset + + self.kv_proj_shard_size * (self.kv_shard_id + 1)] + lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=1) + return lora_b + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + return type(source_layer) is QKVParallelLinear and len( + packed_modules_list) == 1 + + +class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). 
+ + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) + # q, k, v + self.lora_a_stacked = ( + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ), + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ), + torch.zeros( + max_loras, + 1, + lora_a_output_size_per_partition, + self.input_size, + dtype=lora_config.lora_dtype, + device=self.device, + ), + ) + self.lora_b_stacked = ( + torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.device, + ), + ) + + self.output_slices = ( + self.q_proj_shard_size, + self.kv_proj_shard_size, + self.kv_proj_shard_size, + ) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + # lazily initialized. 
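# --------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of the column arithmetic used by
# slice_lora_b in the QKV LoRA classes above and below: a packed LoRA B lays
# its columns out as [q | k | v], and each tensor-parallel rank takes its own
# q shard plus the kv shard it replicates
# (kv_shard_id = tp_rank // num_kv_head_replicas).
def packed_qkv_lora_b_ranges(q_shard_id, kv_shard_id,
                             q_proj_shard_size, kv_proj_shard_size,
                             q_proj_total_size, kv_proj_total_size):
    """Return (start, end) column ranges of this rank's q/k/v slices."""
    q = (q_proj_shard_size * q_shard_id,
         q_proj_shard_size * (q_shard_id + 1))
    k_off = q_proj_total_size
    k = (k_off + kv_proj_shard_size * kv_shard_id,
         k_off + kv_proj_shard_size * (kv_shard_id + 1))
    v_off = k_off + kv_proj_total_size
    v = (v_off + kv_proj_shard_size * kv_shard_id,
         v_off + kv_proj_shard_size * (kv_shard_id + 1))
    return q, k, v

# Illustrative numbers (not from any particular model): 2-way TP, 8 query
# heads and 2 kv heads in total, head_size 128. Rank 1 then owns
# q columns (512, 1024), k columns (1152, 1280), v columns (1408, 1536).
assert packed_qkv_lora_b_ranges(1, 1, 512, 128, 1024, 256) == \
    ((512, 1024), (1152, 1280), (1408, 1536))
# --------------------------------------------------------------------------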
+ self.indices: torch.Tensor + self.indices_len: List[int] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def slice_lora_a( + self, lora_a: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + return lora_a + + def slice_lora_b( + self, lora_b: List[Union[torch.Tensor, None]] + ) -> List[Union[torch.Tensor, None]]: + lora_b_q, lora_b_k, lora_b_v = None, None, None + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1), ] + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1), ] + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1), ] + lora_b = [lora_b_q, lora_b_k, lora_b_v] + return lora_b + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + if lora_b[0] is not None: + lora_b_q = lora_b[0] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + self.punica_wrapper.add_lora_packed_nslice(output, x, + self.lora_a_stacked, + self.lora_b_stacked, 1.0, + self.output_slices) + return output + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return (type(source_layer) is QKVParallelLinear + and len(packed_modules_list) == 3) + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + self.input_size = self.base_layer.input_size_per_partition + self.output_size = self.base_layer.output_size + self.device = _get_lora_device(self.base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_config = lora_config + self.tp_rank = get_tensor_model_parallel_rank() + self.lora_a_stacked = torch.zeros( + ( + 
max_loras, + 1, + lora_config.max_lora_rank, + self.input_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + tp_size = get_tensor_model_parallel_world_size() + lora_b_output_size_per_partition = ( + self.output_size if not lora_config.fully_sharded_loras else + divide(self.output_size, tp_size)) + + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + lora_b_output_size_per_partition, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.input_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.base_layer.tp_size > 1: + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) + self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, + self.lora_b_stacked, 1.0) + return output + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return (self.base_layer.weight if hasattr(self.base_layer, "weight") + else self.base_layer.qweight) + + @classmethod + @_not_fully_sharded_can_replace + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + return type(source_layer) is RowParallelLinear + + +class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. 
+ + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[List[int]]) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping + + @property + def logits_as_input(self): + return self.base_layer.logits_as_input + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def scale(self): + return self.base_layer.scale + + @property + def soft_cap(self): + return self.base_layer.soft_cap + + @property + def use_gather(self): + return self.base_layer.use_gather + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + @property + def should_modify_greedy_probs_inplace(self): + return self.base_layer.should_modify_greedy_probs_inplace + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # TODO: Verify if this condition can be further relaxed + if 32000 < self.base_layer.vocab_size > 257024: + raise ValueError("When using LoRA, vocab size must be " + "32000 >= vocab_size <= 257024") + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1], ] = embeddings_tensor + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.linear_method.apply(lm_head, hidden_states) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + if logits is None: + return None + + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + indices_padded = self.punica_wrapper.sampler_indices_padded + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1], ] = lora_logits + + # LogitsProcessorWithLoRA always using bgmv + self.punica_wrapper.add_lora_logits(logits, hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, 1.0) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + # Special handling for the LogitsProcessor. + return False + + +class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA): + """Implements RoPE-scaled embeddings with linear scaling for + multiple LoRA adapters with a specialized kernel. + + Replace LinearScalingRotaryEmbedding with MultiLinearScalingRotaryEmbedding + which can handle multi lora adapters in a specialied kernel. 
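# --------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of the per-token semantics behind
# the lora_logits block in _get_logits above: each token with an active
# adapter gets logits for that adapter's extra vocabulary, written into the
# columns right after the base vocabulary before the LoRA delta is added.
# The real code does this with one batched matmul plus an index_select on
# sampler_indices_padded; the loop below only spells out what that computes.
import torch

def lora_vocab_logits_reference(logits, hidden, embeddings_tensors,
                                token_slot, org_vocab_size):
    # logits: (tokens, padded_vocab); hidden: (tokens, h)
    # embeddings_tensors: (max_loras, extra_vocab, h); token_slot: (tokens,)
    extra = embeddings_tensors.shape[1]
    lora_logits = torch.full((hidden.shape[0], extra), float("-inf"),
                             dtype=logits.dtype, device=logits.device)
    for t in range(hidden.shape[0]):
        slot = int(token_slot[t])
        if slot >= 0:        # -1 means "no adapter"; columns stay at -inf
            lora_logits[t] = embeddings_tensors[slot] @ hidden[t]
    logits[:, org_vocab_size:org_vocab_size + extra] = lora_logits
    return logits
# --------------------------------------------------------------------------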
+ """ + + def __init__(self, base_layer: RotaryEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + @property + def scaling_factors(self): + return self.base_layer.scaling_factors + + @property + def rotary_dim(self): + return self.base_layer.rotary_dim + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + scaling_factors = (list(lora_config.long_lora_scaling_factors) + if lora_config.long_lora_scaling_factors else []) + base_scaling_factor = (self.base_layer.scaling_factor if isinstance( + self.base_layer, LinearScalingRotaryEmbedding) else 1.0) + scaling_factors = sorted( + list(set([base_scaling_factor] + scaling_factors))) + self.base_layer = LinearScalingRotaryEmbedding( + self.base_layer.head_size, + self.base_layer.rotary_dim, + self.base_layer.max_position_embeddings, + self.base_layer.base, + self.base_layer.is_neox_style, + scaling_factors, + self.base_layer.dtype, + ) + + def reset_lora(self, index: int): + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + ... + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.base_layer( + positions, + query, + key, + offsets=self.punica_wrapper.long_lora_indices, + ) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self.base_layer.scaling_factor_to_offset + + @classmethod + def can_replace_layer( + cls, + source_layer: nn.Module, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig], + ) -> bool: + """Returns True if the layer can be replaced by this LoRA layer.""" + return (type(source_layer) is LinearScalingRotaryEmbedding + or type(source_layer) is RotaryEmbedding) + + def extra_repr(self) -> str: + return self.base_layer.extra_repr() diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000..14081b5 --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,169 @@ +from typing import List, Optional +from typing import Sequence as GenericSequence + +import torch +import torch.types + +from vllm.utils import is_pin_memory_available + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return self + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, 
+ input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.types.Device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and is_pin_memory_available() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[Optional[int]], + lora_a: List[Optional[torch.Tensor]], + lora_b: List[Optional[torch.Tensor]], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, # type: ignore + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ # type: ignore + lora_alpha / self.rank # type: ignore # noqa + for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack( + cls, loras: GenericSequence[Optional["LoRALayerWeights"]] + ) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[ + 1 if lora is not None else None # type: ignore + for lora in loras + ]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore + continue + self.lora_b[i] *= self.scaling[i] # type: ignore + self.scaling[i] = 1 # type: ignore + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000..aaadca9 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,747 @@ +import copy +import json +import math +import os +import re +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Type + +import safetensors.torch +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.layers import (BaseLayerWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LoRAMapping) +from 
vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.punica import PunicaWrapper +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + is_regex_target_modules, + parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models import SupportsLoRA, supports_multimodal +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.utils import PPMissingLayer +from vllm.utils import is_pin_memory_available + +logger = init_logger(__name__) + +_GLOBAL_LORA_ID = 0 + + +@dataclass +class LongContextLoRAContext: + """Context for lora adapters that support long context.""" + # The scaling factors to support long context lora fine tuned models. + scaling_factors: List[float] + # dimension to apply rotary embedding. + rot_dim: int + # offsets to the sin_cos_cache for each lora_id loaded. + # This value is dynamically modified. + offsets_by_lora_id: Dict[int, int] = field(default_factory=dict) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel(AdapterModel): + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRALayerWeights], + scaling_factor: Optional[float] = None, + ) -> None: + """ + Args: + lora_model_id: The integer id for the lora model. + rank: lora rank. + loras: module name -> weights for lora-replaced layers. + scaling_factor: Scaling factor to support long context lora model. + None if the lora is not tuned for long context support. + """ + self.id = lora_model_id + # Scaling factor for long context lora model. None if it is not + # fine tuned for the long context. + self.scaling_factor = scaling_factor + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRALayerWeights] = loras + + def clone(self, lora_model_id: int) -> "LoRAModel": + """Return a copy of the object with different ids. 
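# --------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of what LoRALayerWeights.optimize()
# earlier in this patch does (and what from_lora_tensors below relies on):
# the lora_alpha / rank scaling is folded into lora_b once at load time, so
# the kernels can always run with scale 1.0 while the effective weight delta
# stays the same.
import torch

lora_a = torch.randn(64, 8)            # (input_dim, rank)
lora_b = torch.randn(8, 128)           # (rank, output_dim)
lora_alpha, rank = 16, 8
scaling = lora_alpha / rank            # 2.0

delta_before = scaling * (lora_a @ lora_b)
lora_b *= scaling                      # what optimize() does in place
delta_after = lora_a @ lora_b          # scaling is now effectively 1.0
assert torch.allclose(delta_before, delta_after)
# --------------------------------------------------------------------------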
+ + Will share the underlying tensors.""" + return self.__class__( + lora_model_id, + rank=self.rank, + loras=self.loras.copy(), + ) + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + scaling_factor: Optional[float] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and is_pin_memory_available() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + assert embedding_modules is not None + embeddings_module = next( + (k for k in embedding_modules if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]].to( + device=device, dtype=dtype) + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights(module_name, rank, + lora_alpha, None, None, + lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + assert embedding_padding_modules is not None + if any(name in module_name + for name in embedding_padding_modules + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + expected_lora_modules: List[str], + *, + max_position_embeddings: Optional[int] = None, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint. + + Args: + lora_dir: The local path that has lora data. + expected_lora_modules: Name of modules that are expected to be + replaced by lora. + max_position_embeddings: Max position embedding length. Used to + scaling the largest context length. If None, the lora model's + context length is not scaled. + lora_model_id: Lora model id. 
If not given, automatically set by + a global counter. + device: Device where the lora model is loaded. + dtype: dtype of the lora model weights. + + Returns: + Loaded LoRA Model. + """ + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + with open(lora_config_path) as f: + config = json.load(f) + if os.path.isfile(lora_tensor_path): + tensors: Dict[str, torch.Tensor] = {} + # Find unexpected modules. + # Use safetensor key as a source of truth to find expected modules. + # in peft if you have target_modules A, B, C and C does not exist + # in the model it won’t error and model will be trained with A, B + # loraified. C won’t exist in the safetensor but it will exist in + # the target_modules of the adapter_config.json. + unexpected_modules = [] + with safetensors.safe_open(lora_tensor_path, + framework="pt") as f: # type: ignore + for lora_module in f.keys(): # noqa + module_name, _ = parse_fine_tuned_lora_name(lora_module) + part_name = module_name.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module_name) + if unexpected_modules: + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." + f" Please verify that the loaded LoRA module is correct" + ) + # Load tensors if there are only expected modules. + for module in f.keys(): # noqa + tensors[module] = f.get_tensor(module) + elif os.path.isfile(lora_bin_file_path): + # When a bin file is provided, we rely on config to find unexpected + # modules. + unexpected_modules = [] + target_modules = config["target_modules"] + if not isinstance(target_modules, list): + target_modules = [target_modules] + for module in target_modules: + # Compatible with more modules, + # such as:layers.11.self_attn.k_proj + part_name = module.split(".")[-1] + if part_name not in expected_lora_modules: + unexpected_modules.append(module) + # loaded lora's target modules must be a subset of + # expected_lora_modules. It is not reliable. See + # https://github.com/vllm-project/vllm/pull/5909. But there's no + # other better mechanism. + if unexpected_modules and not is_regex_target_modules( + config["target_modules"], expected_lora_modules): + raise ValueError( + f"While loading {lora_dir}, expected" + f" target modules in {expected_lora_modules}" + f" but received {unexpected_modules}." 
+ f" Please verify that the loaded LoRA module is correct") + tensors = torch.load(lora_bin_file_path, map_location=device) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path, + map_location=device) + + rank = config["r"] + lora_alpha = config["lora_alpha"] + context_length = config.get("context_length", None) + scaling_factor = None + if context_length: + if max_position_embeddings is None: + max_position_embeddings = context_length + scaling_factor = float( + math.ceil(context_length / max_position_embeddings)) + + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + scaling_factor=scaling_factor, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + ) + + +class LoRAModelManager(AdapterModelManager): + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: SupportsLoRA, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.long_lora_context: Optional[LongContextLoRAContext] = None + self.punica_wrapper = PunicaWrapper(max_num_batched_tokens, + max_batches=self.max_num_seqs, + device="cuda") + # Scaling factor -> offset to the sin_cos_cache to it. + # Used for long context lora. + self.scaling_factor_to_offset: Dict[float, int] = {} + super().__init__(model) + if hasattr(self.model, "supported_lora_modules"): + self.supported_lora_modules = copy.deepcopy( + self.model.supported_lora_modules) + if lora_config.long_lora_scaling_factors: + # We need to replace rotary emb layer to do batch computation + # for long lora. + self.supported_lora_modules.append("rotary_emb") + self.packed_modules_mapping = copy.deepcopy( + self.model.packed_modules_mapping) + # Used to indicate whether the model is a multimodal model + self.supports_mm: bool = ( + supports_multimodal(self.model) + # In case the model only supports LoRA for + # text modules (e.g. ChatGLM) + and hasattr(self.model, "get_mm_mapping")) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "BaseLayerWithLoRA"] = {} + # Dict instead of a Set for compatibility with LRUCache. 
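# --------------------------------------------------------------------------
# A minimal sketch (not part of the patch) of the long-context scaling factor
# derived in from_local_checkpoint above: an adapter that advertises a
# context_length gets a rotary scaling factor equal to the rounded-up ratio
# against the base model's max_position_embeddings.
import math

def long_lora_scaling_factor(context_length, max_position_embeddings=None):
    if not context_length:
        return None
    if max_position_embeddings is None:
        max_position_embeddings = context_length
    return float(math.ceil(context_length / max_position_embeddings))

# e.g. a 16k-context adapter on a 4k base model -> scaling factor 4.0
assert long_lora_scaling_factor(16384, 4096) == 4.0
assert long_lora_scaling_factor(None) is None
# --------------------------------------------------------------------------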
+ self._last_mapping: Optional[LoRAMapping] = None + self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRa' + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + @property + def adapter_slots(self) -> int: + return self.lora_slots + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_adapters: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] + logger.debug("Activating LoRA. int id: %d, slot index: %d", + lora_model.id, index) + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_adapter(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def _set_long_lora_context(self, lora: LoRAModel): + if self.long_lora_context is None: + return + + if lora.scaling_factor is None: + return + + if (lora.scaling_factor not in self.scaling_factor_to_offset): + raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}" + " has not been initialized.") + + offsets = self.scaling_factor_to_offset.get(lora.scaling_factor) + if offsets: + self.long_lora_context.offsets_by_lora_id[lora.id] = offsets + + def _add_adapter(self, lora: LoRAModel): + self._create_merged_loras_inplace(lora) + self._registered_adapters[lora.id] = lora + self._set_long_lora_context(lora) + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in LoRAModelManager." 
+ "Use LRUCacheLoRAModelManager for pinning") # type: ignore + + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: + # update lora states + self.punica_wrapper.update_metadata( + mapping, + self.lora_index_to_id, + self.lora_slots + 1, + self.vocab_size, + self.lora_config.lora_extra_vocab_size, + self.long_lora_context, + ) + + def remove_all_adapters(self): + """Remove all LoRAModels from the manager.""" + self._registered_adapters.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_adapters.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if isinstance(module, PPMissingLayer): + continue + if not self._match_target_modules(module_name): + continue + # A temporary approach for multimodal models to support LoRA + # TODO: Remove this restriction + if self._filter_unsupported_mm_module(module_name): + logger.warning( + "Regarding multimodal models, vLLM currently only supports " + "adding LoRA to language model, %s will be ignored.", + module_name, + ) + continue + parts = module_name.split(".")[-1] + packed_moduled_lst = self.packed_modules_mapping.get(parts, []) + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + packed_moduled_lst, self.model.config)) + + # LinearScalingRotaryEmbeddingWithLora is used to handle + # long context lora. Register relevant metadata. + if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora): + self.long_lora_context = LongContextLoRAContext( + new_module.scaling_factors, new_module.rotary_dim) + self.scaling_factor_to_offset = \ + new_module.scaling_factor_to_offset + # (yard1): TODO make this more robust + if "lm_head" in module_name: + logits_processor_module = self.model.get_submodule( + "logits_processor") + new_module = replace_submodule( + self.model, "logits_processor", + from_layer_logits_processor(logits_processor_module, + module, self.lora_slots, + self.lora_config, + self.model.config)) + + # In some models, especially multimodal ones, layers with the same + # name may have different types, such as nn.Linear and + # ReplicatedLinear. The nn.Linear layers cannot be replaced with + # LoRA layers, leading to assertion error. The following check + # aims to prevent this error + if self.supports_mm and not isinstance(new_module, + BaseLayerWithLoRA): + continue + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + # All lora layers share the same punica_wrapper based on reference. 
+ new_module.set_mapping(self.punica_wrapper) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora( + self, + lora_id: int, + rank: int, + scaling_factor: Optional[float], + embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}, scaling_factor) + for module_name, module in self.model.named_modules(): + if (not self._match_target_modules(module_name) + or not isinstance(module, BaseLayerWithLoRA) + or isinstance(module, LinearScalingRotaryEmbeddingWithLora) + or self._filter_unsupported_mm_module(module_name)): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + assert embedding_modules is not None + if parts[-1] in embedding_modules: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked.shape[-1], + module.lora_b_stacked.shape[-2], + rank, + module.lora_a_stacked.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras: List[Optional["LoRALayerWeights"]] = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.supported_lora_modules) + + def _filter_unsupported_mm_module(self, module_name: str) -> bool: + """ + Regarding multimodal models, vLLM currently only supports adding LoRA to + language model. LoRA for other modules, such as the vision tower, will + be filtered out. + """ + if self.supports_mm: + prefix = module_name.split(".")[0] + module_mapping: MultiModelKeys = self.model.get_mm_mapping() + return (prefix in module_mapping.connector + or prefix in module_mapping.tower_model) + return False + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name, []) + # When replacements is less than or equal to 1, it indicates that this + # module is not a packed module. + if len(replacements) <= 1: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." 
+ r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras: List[Optional[LoRALayerWeights]] = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], + bool]): + super().__init__(capacity, deactivate_lora_fn) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + logger.debug( + "Adding lora. 
Model id: %d, " + "int id: %d, " + "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(lora.id) + was_added = False + return was_added + + def activate_adapter( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(lora_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, lora_id: int) -> bool: + """Pin a LoRAModel in the manager cache.""" + self._pin_lora_in_cpu_cache(lora_id) + self._pin_lora_in_gpu_cache(lora_id) + return True + + def _pin_lora_in_cpu_cache(self, lora_id: int): + try: + self._registered_adapters.pin(lora_id) + except ValueError as err: + raise ValueError("Pinning failed. " + f"LoRA {lora_id} is not registered.") from err + + def _pin_lora_in_gpu_cache(self, lora_id: int): + if lora_id not in self._active_adapters: + # move lora to gpu if not already active + self.activate_adapter(lora_id) + + self._active_adapters.pin(lora_id) + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not hasattr(model, "supported_lora_modules"): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + **kwargs) + return lora_manager diff --git a/vllm/lora/ops/__init__.py b/vllm/lora/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/ops/__pycache__/__init__.cpython-310.pyc b/vllm/lora/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..9ede199 Binary files /dev/null and b/vllm/lora/ops/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/bgmv_expand.cpython-310.pyc b/vllm/lora/ops/__pycache__/bgmv_expand.cpython-310.pyc new file mode 100644 index 0000000..217f22f Binary files /dev/null and b/vllm/lora/ops/__pycache__/bgmv_expand.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/bgmv_expand_slice.cpython-310.pyc b/vllm/lora/ops/__pycache__/bgmv_expand_slice.cpython-310.pyc new file mode 100644 index 0000000..da16abe Binary files /dev/null and b/vllm/lora/ops/__pycache__/bgmv_expand_slice.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/bgmv_shrink.cpython-310.pyc b/vllm/lora/ops/__pycache__/bgmv_shrink.cpython-310.pyc new file mode 100644 index 0000000..0f98091 Binary files /dev/null and b/vllm/lora/ops/__pycache__/bgmv_shrink.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/sgmv_expand.cpython-310.pyc b/vllm/lora/ops/__pycache__/sgmv_expand.cpython-310.pyc new file mode 100644 index 0000000..08cfc9f Binary files /dev/null and b/vllm/lora/ops/__pycache__/sgmv_expand.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/sgmv_expand_slice.cpython-310.pyc 
b/vllm/lora/ops/__pycache__/sgmv_expand_slice.cpython-310.pyc new file mode 100644 index 0000000..fdc90b3 Binary files /dev/null and b/vllm/lora/ops/__pycache__/sgmv_expand_slice.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/sgmv_shrink.cpython-310.pyc b/vllm/lora/ops/__pycache__/sgmv_shrink.cpython-310.pyc new file mode 100644 index 0000000..e7f1bfa Binary files /dev/null and b/vllm/lora/ops/__pycache__/sgmv_shrink.cpython-310.pyc differ diff --git a/vllm/lora/ops/__pycache__/utils.cpython-310.pyc b/vllm/lora/ops/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..dd89854 Binary files /dev/null and b/vllm/lora/ops/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py new file mode 100644 index 0000000..6a32387 --- /dev/null +++ b/vllm/lora/ops/bgmv_expand.py @@ -0,0 +1,168 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + current_n_c = tl.max_contiguous(current_n, BLOCK_N) + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n_c[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _bgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). 
The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + batches (int): batch size + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + batches = lora_indices_tensor.size(0) + config = get_lora_op_configs("expand", batches, N) + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return + + +try: + bgmv_expand = torch.library.custom_op("lora::bgmv_expand", + _bgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py new file mode 100644 index 0000000..73628fd --- /dev/null +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -0,0 +1,181 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_N: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's + performance + """ + pid_sn = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + offset_k = tl.arange(0, BLOCK_K) + offset_n = tl.arange(0, BLOCK_N) + if EVEN_K: + tiled_a = tl.load(input_ptr + cur_batch * xm_stride + + offset_k * xk_stride, ) # [BLOCK_K] + else: + tiled_a = tl.load( + input_ptr + cur_batch * xm_stride + offset_k * xk_stride, + mask=offset_k < K, + other=0, + ) # [BLOCK_K] + # N must be divisible by SPLIT_N + split_n_length = tl.cdiv(N, SPLIT_N) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + # sliding to next row-block + b_ptr = (lora_ptr + l0_stride * lora_index + + pid_sn * split_n_length * lora_k_stride) + c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length + + slice_offset * cn_stride) + + for n in range(0, split_n_length, BLOCK_N): + current_n = n + offset_n + b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :] + < K) + c_mask = current_n < split_n_length + tiled_b = tl.load( + b_ptr + current_n[:, None] * lora_k_stride + + offset_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + if ADD_INPUTS: + tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out + else: + accumulator = tl.sum(tiled_a * tiled_b, 1) + + tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _bgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch, An index of -1 means no lora should be + applied. + slice_offset (int): output_tensor's offset + slice_size (int): current output_tensor's size + batches (int): batch size + add_inputs (bool, optional): Defaults to False. 
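Functionally, this expand-slice op is a per-token matrix-vector product whose result lands in a column slice of the output. The pure-PyTorch sketch below mirrors those semantics for reference only; `ref_bgmv_expand_slice` is an illustrative name, not an API added by this patch.

import torch

def ref_bgmv_expand_slice(inputs, lora_b_weights, output_tensor,
                          lora_indices_tensor, slice_offset, slice_size,
                          add_inputs=True):
    # inputs:         (num_tokens, rank)
    # lora_b_weights: (num_loras, slice_size, rank)  -- B matrix per LoRA slot
    # output_tensor:  (num_tokens, hidden_size); only columns
    #                 [slice_offset, slice_offset + slice_size) are written
    for i, idx in enumerate(lora_indices_tensor.tolist()):
        if idx == -1:        # -1 means "no LoRA" for this token
            continue
        delta = lora_b_weights[idx] @ inputs[i].to(lora_b_weights.dtype)
        cols = slice(slice_offset, slice_offset + slice_size)
        if add_inputs:
            output_tensor[i, cols] += delta
        else:
            output_tensor[i, cols] = delta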
+ """ + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_K = triton.next_power_of_2(K) + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + batches = lora_indices_tensor.size(0) + + config = get_lora_op_configs("expand", batches, N) + + grid = lambda META: ( + META["SPLIT_N"], + batches, + ) + _bgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_K=BLOCK_K, + EVEN_K=EVEN_K, + ADD_INPUTS=ADD_INPUTS, + CAST_TYPE=CAST_TYPE, + **config, + ) + return + + +try: + bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", + _bgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py new file mode 100644 index 0000000..0846ff3 --- /dev/null +++ b/vllm/lora/ops/bgmv_shrink.py @@ -0,0 +1,150 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from .utils import get_lora_op_configs + + +@triton.jit +def _bgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_indices, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's + performance + """ + pid_sk = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + + offset_n = tl.arange(0, BLOCK_N) + offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K + a_ptr = input_ptr + cur_batch * xm_stride + b_ptr = lora_ptr + l0_stride * lora_index + accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32) + for k in range(0, K, BLOCK_K * SPLIT_K): + current_k = k + offset_k + current_k_c = tl.max_contiguous(current_k, BLOCK_K) + tiled_a = tl.load( + a_ptr + current_k_c, + mask=current_k < K, + other=0.0, + ) # [BLOCK_K] + b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K) + + tiled_b = tl.load( + b_ptr + offset_n[:, None] * lora_k_stride + + current_k[None, :] * lora_n_stride, + mask=b_ptr_mask, + other=0.0, + ) # [BLOCK_N,BLOCK_K] + + accumulator += tl.sum(tiled_a * tiled_b, 1) + accumulator *= scaling + offset_cn = tl.arange(0, BLOCK_N) + c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride + c_mask = offset_cn < N + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _bgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + scaling (float): Scaling factor. 
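The shrink op projects each token from hidden_size down to the LoRA rank using that token's A matrix, then applies `scaling`. A pure-PyTorch reference of the same behavior (illustrative only; `ref_bgmv_shrink` is a made-up name, and the output is assumed to be zero-initialized by the caller):

import torch

def ref_bgmv_shrink(inputs, lora_a_weights, output_tensor,
                    lora_indices_tensor, scaling=1.0):
    # inputs:         (num_tokens, hidden_size)
    # lora_a_weights: (num_loras, rank, hidden_size)  -- A matrix per LoRA slot
    # output_tensor:  (num_tokens, rank), assumed zero-initialized
    rank = lora_a_weights.size(-2)
    for i, idx in enumerate(lora_indices_tensor.tolist()):
        if idx == -1:        # -1 means "no LoRA" for this token
            continue
        output_tensor[i, :rank] = scaling * (lora_a_weights[idx] @ inputs[i])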
+ """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + batches = lora_indices_tensor.size(0) + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_N = triton.next_power_of_2(N) + # First try to load optimal config from the file + config = get_lora_op_configs("bgmv_shrink", batches, K) + + grid = lambda META: ( + META["SPLIT_K"], + batches, + ) + _bgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_N=BLOCK_N, + **config, + ) + return + + +try: + bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", + _bgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py new file mode 100644 index 0000000..adb3ab5 --- /dev/null +++ b/vllm/lora/ops/sgmv_expand.py @@ -0,0 +1,204 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + The sgmv's expand triton kernel is based on GroupGEMM. 
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _sgmv_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. 
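In plain PyTorch terms, the sgmv expand treats each sequence in the grouped prefill batch as one GEMM against that sequence's B matrix, using b_seq_start_loc and seq_len_tensor to locate its token rows. A reference sketch of those semantics (illustrative only; `ref_sgmv_expand` is a made-up helper and omits the batches/token_nums arguments the real op uses for validation):

import torch

def ref_sgmv_expand(inputs, lora_b_weights, output_tensor,
                    b_seq_start_loc, seq_len_tensor, lora_indices_tensor,
                    add_inputs=False):
    # inputs:         (token_nums, rank); prefill tokens, sequences contiguous
    # lora_b_weights: (num_loras, hidden_size, rank)
    # One GEMM per batch segment against that segment's B matrix.
    for start, length, idx in zip(b_seq_start_loc.tolist(),
                                  seq_len_tensor.tolist(),
                                  lora_indices_tensor.tolist()):
        if idx == -1:        # this segment uses no LoRA
            continue
        seg = inputs[start:start + length].to(lora_b_weights.dtype)
        delta = seg @ lora_b_weights[idx].transpose(0, 1)
        if add_inputs:
            output_tensor[start:start + length, :delta.size(1)] += delta
        else:
            output_tensor[start:start + length, :delta.size(1)] = delta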
+ """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return + + +try: + sgmv_expand = torch.library.custom_op("lora::sgmv_expand", + _sgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py new file mode 100644 index 0000000..efa2345 --- /dev/null +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -0,0 +1,217 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + slice_offset, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + """ + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. 
+ """ + pid = tl.program_id(axis=0) + cur_batch = tl.program_id(axis=1) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = tl.arange(0, BLOCK_K) + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride, ) + b_ptr = (lora_ptr + l0_stride * lora_index + + offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_n_stride + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + M = tl.load(seq_lens + cur_batch) + c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < + (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _sgmv_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + """_summary_ + + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4, 10]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences + in the batch + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + slice_offset (int): output_tensor's offset + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. 
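The slice variant exists so that several LoRA-B expands can share one output buffer, each writing a disjoint column range, as happens when sub-layers of a packed projection are fused into a single matmul. A schematic of the offsets involved (the sizes here are hypothetical, purely for illustration):

# Hypothetical packed projection made of three fused sub-outputs; each
# sub-layer's expand is issued with its own (slice_offset, slice_size) so the
# three results land in adjacent, non-overlapping column ranges.
q_size, kv_size = 4096, 1024
slices = [(0, q_size),                  # first sub-output
          (q_size, kv_size),            # second sub-output
          (q_size + kv_size, kv_size)]  # third sub-output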
+ """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_b_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert slice_size == lora_b_weights.size(-2) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + batches, + ) + _sgmv_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + inputs.stride(0), + inputs.stride(1), + lora_b_weights.stride(0), + lora_b_weights.stride(1), + lora_b_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + slice_offset, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + ) + return + + +try: + sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", + _sgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py new file mode 100644 index 0000000..c003f3d --- /dev/null +++ b/vllm/lora/ops/sgmv_shrink.py @@ -0,0 +1,201 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. +https://arxiv.org/abs/2310.18547 +""" + +import torch +import triton +import triton.language as tl + +from vllm.triton_utils import libentry + + +@libentry() +@triton.jit +def _sgmv_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + xm_stride, # hidden_size + xk_stride, # 1 + l0_stride, # hidden_size*max_rank + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, +): + """ + The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. + The GEMM of Multi-LoRA can be considered as GroupGEMM. 
Additionally, + introducing SPLIT-K can improve performance + """ + pid = tl.program_id(axis=0) + pid_sk = tl.program_id(axis=1) + cur_batch = tl.program_id(axis=2) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid // cta_n_num + pid_n = pid % cta_n_num + + M = tl.load(seq_lens + cur_batch) + if pid_m * BLOCK_M > M: + return + lora_index = tl.load(lora_indices + cur_batch) + if lora_index == -1: + return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) + offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K) + + ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + + a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + + offset_k[None, :] * xk_stride) + b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + + offset_k[:, None] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < k_remaining, + other=0.0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < k_remaining, + other=0.0) + accumulator += tl.dot(tiled_a, tiled_b) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M + + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + + offset_cn[None, :] * cn_stride) + c_mask = (offset_cm[:, None] < + (cur_seq_start + M)) & (offset_cn[None, :] < N) + accumulator *= scaling + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(c_ptr, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptr, accumulator, mask=c_mask) + + +@torch.inference_mode() +def _sgmv_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + """ + + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative + sequence lengths of the sequences in the batch, used to index + into sequence. E.g., if the sequence length is [4, 6], it is + [0, 4]. + seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence + length of the sequences in the batch. + lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index + corresponding to each batch. An index of -1 means no lora should be + applied. + batches (int): batch size + max_seq_length (int): The max sequence lengths of the sequences in the + batch. + token_nums (int): The token numbers in the batch. Used to verify if the + token numbers in the inputs matches the one in the metadata. + scaling (float): Scaling factor. 
+ """ + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(0) == token_nums + assert inputs.size(1) == lora_a_weights.size(-1) + assert b_seq_start_loc.size(0) == batches + assert lora_indices_tensor.size(0) == batches + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + # TODO tuning this config + N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 32 + SPLIT_K = 8 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + grid = ( + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + SPLIT_K, + batches, + ) + + _sgmv_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + scaling, + inputs.stride(0), + inputs.stride(1), + lora_a_weights.stride(0), + lora_a_weights.stride(1), + lora_a_weights.stride(2), + output_tensor.stride(0), + output_tensor.stride(1), + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + ) + return + + +try: + sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", + _sgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_shrink = _sgmv_shrink diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py new file mode 100644 index 0000000..7c3e273 --- /dev/null +++ b/vllm/lora/ops/utils.py @@ -0,0 +1,46 @@ +import functools +from typing import Dict + + +@functools.lru_cache +def _get_op_configs(op_type: str, batch: int, hidden_size: int): + # TODO: add optimal configurations + return None + + +def _check_divisibility(hidden_size: int): + # The bgmv_expand kernel requires that the hidden_size be divisible by + # the number below. + divisibility = [2, 4, 8, 16, 32, 64] + divisibility.sort(reverse=True) + for div in divisibility: + if hidden_size % div == 0: + return div + # hidden_size is an odd number + return 1 + + +def _get_default_config(op_type: str, batch: int, hidden_size: int): + if op_type == "expand": + return { + "BLOCK_N": 256, + "SPLIT_N": _check_divisibility(hidden_size), + "num_warps": 8 + } + else: + return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8} + + +def get_lora_op_configs(op_type: str, batch: int, + hidden_size: int) -> Dict[str, int]: + """Inspired by `fused_moe_kernel` + The return value will be a dictionary mapping an irregular grid of batch + sizes and hidden_size to configurations of the bgmv-related kernel. + NOTE: It currently only supports the default configuration. We plan to + generate optimal configurations for different hardware in the future using + scripts similar to `benchmark_moe.py`. + """ + config = _get_op_configs(op_type, batch, hidden_size) + if not config: + config = _get_default_config(op_type, batch, hidden_size) + return config diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000..1aa9fab --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,629 @@ +""" +Based on: +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. 
+https://arxiv.org/abs/2310.18547 +""" + +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union + +import torch + +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.bgmv_expand import bgmv_expand + from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice + from vllm.lora.ops.bgmv_shrink import bgmv_shrink + from vllm.lora.ops.sgmv_expand import sgmv_expand + from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice + from vllm.lora.ops.sgmv_shrink import sgmv_shrink + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.layers import LoRAMapping + from vllm.lora.models import LongContextLoRAContext +from vllm import _custom_ops as ops + +def compute_meta( + token_lora_tensor: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: + """ + Get the information required for the sgmv kernel. With the features: + 1. If consecutive requests in the batch use the same LoRA, this function + will combine them into a single request, improving sgmv kernel inference + performance. + 2. At the beginning of each prefill stage inference, recalculations are + needed based on the input, but only once. + """ + + lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( + token_lora_tensor, return_counts=True) + cum_result = torch.cumsum(seq_length_tensor, dim=0) + b_seq_start_tensor = torch.zeros_like(seq_length_tensor) + b_seq_start_tensor[1:].copy_(cum_result[:-1]) + max_length = seq_length_tensor.max().item() + token_nums = seq_length_tensor.sum().item() + batch_size = lora_indices_tensor.size(0) + no_lora = False + # -1 means no lora should be applied. Use `no_lora` to determine whether + # the current step requires LoRA. If LoRA is not needed, the prefill stage + # does not need to launch the triton kernel, which can improve performance + if batch_size == 1 and lora_indices_tensor == -1: + no_lora = True + return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, no_lora) + + +# TODO see if this can be vectorized +def convert_mapping( + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + long_lora_context: Passed if there are long context lora in a batch. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. 
+ long_lora_indices: Tensor of shape [batch_size] mapping + requests to RoPE offsets and rot dims for long LoRAs. + None if long context lora doesn't exist. + indices_len: List of lengths of the above tensors. It contains + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, long_lora_indices). + """ + index_mapping_indices: List[int] = list(mapping.index_mapping).copy() + embedding_indices = index_mapping_indices.copy() + lora_indices = index_mapping_indices.copy() + long_lora_offsets: Optional[torch.Tensor] = None + if long_lora_context: + long_lora_offsets = torch.zeros(len(index_mapping_indices), + device="cuda", + dtype=torch.long) + prompt_mapping: List[int] = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(index_mapping_indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) + if index_mapping_indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 + lora_indices[i] = lora_idx + if long_lora_context: + assert long_lora_offsets is not None + lora_offset: int = long_lora_context.offsets_by_lora_id.get( + index_mapping_indices[i], 0) + long_lora_offsets[i] = lora_offset + + indices_list: List[Union[List[int], torch.Tensor]] = [ + index_mapping_indices, + lora_indices, + embedding_indices, + ] + if long_lora_context: + assert long_lora_offsets is not None + indices_list.append(long_lora_offsets) + indices = torch.tensor(indices_list, dtype=torch.long, device="cuda") + prompt_mapping_tensor = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size), + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping_tensor + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + ( + sampler_indices_padded * len(sampler_indices_padded)) + long_lora_indices = None + long_lora_indices_len: Optional[int] = None + if long_lora_context: + long_lora_indices = indices[3] + long_lora_indices_len = long_lora_indices.shape[-1] + # Contain length of indices tensors. Used to index into each tensor. + indices_len = [ + base_indices.shape[-1], + sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1], + ] + if long_lora_indices_len is not None: + indices_len.append(long_lora_indices_len) + else: + # If long_lora doesn't exist,append None + indices_len.append(None) + + return ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_indices, + indices_len, + ) + + +class PunicaWrapper: + """ + PunicaWrapper is designed to manage and provide metadata for the punica + kernel. The main function is to maintain the state information for + Multi-LoRA, and to provide the interface for the punica kernel. 
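Conceptually, the wrapper's shrink and expand stages compose into the standard low-rank update y += scale * x @ A^T @ B^T. The snippet below is a standalone numeric sanity check of that composition in plain PyTorch; it does not go through the kernels or the vendor ops, and all sizes are arbitrary.

import torch

hidden, out_dim, rank, scale = 16, 32, 4, 0.5
x = torch.randn(hidden)
lora_a = torch.randn(rank, hidden)       # A: project down to the LoRA rank
lora_b = torch.randn(out_dim, rank)      # B: project back up to out_dim
v = scale * (lora_a @ x)                 # "shrink" stage
delta = lora_b @ v                       # "expand" stage
assert torch.allclose(delta, scale * (lora_b @ lora_a) @ x, atol=1e-5)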
+ """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: str): + self._token_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._sampler_indices_padded = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + self._embeddings_indices = torch.empty(2, + max_num_batched_tokens, + dtype=torch.long, + device=device) + self._long_lora_indices = torch.empty(max_num_batched_tokens, + dtype=torch.long, + device=device) + + # 5 is the number of indicies tensors. + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices,long_lora_indices + self.indices_len: List[Optional[int]] = [None] * 5 + # these attributes are the information required for sgmv kernel + self._seq_start_locs = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._seq_lengths = torch.empty(max_batches, + dtype=torch.long, + device=device) + self._lora_indices_per_batch = torch.empty(max_batches, + dtype=torch.long, + device=device) + self.max_length: int = 0 + self.token_nums: int = 0 + self.batch_size: int = -1 + self.is_prefill = False + self.no_lora = False + + def update_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + + self._update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + if mapping.is_prefill: + # Update metadata required for prefill-related operators. + self._update_prefill_metada(self.token_lora_indices) + self.is_prefill = True + else: + self.is_prefill = False + + def _update_base_metadata( + self, + mapping: "LoRAMapping", + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + ): + ( + base_indices, + sampler_indices, + sampler_indices_padded, + embeddings_indices, + long_lora_offsets_tensor, + indices_len, + ) = convert_mapping( + mapping, + lora_index_to_id, + max_loras, + vocab_size, + extra_vocab_size, + long_lora_context, + ) + self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) + self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self._embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + if long_lora_offsets_tensor is not None: + self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( + long_lora_offsets_tensor) + else: + self._long_lora_indices.zero_() + self.indices_len[:] = indices_len + + def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: + + (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, + batch_size, max_length, token_nums, + no_lora) = compute_meta(token_lora_tensor) + + self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + b_seq_start_tensor) + self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + lora_indices_tensor) + self.batch_size = batch_size + self.max_length = max_length + self.token_nums = token_nums + self.no_lora = no_lora + + @property + def prefill_metadata( + self + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]: + """ + This property provides a convenient way to access the necessary + metadata for prefill-related kernel computations. + 1. seq_start_locs: Tensor of sequence start positions. + 2. seq_lengths: Tensor of sequence lengths. + 3. lora_indices_per_batch: Tensor of lora indices, and an index of + -1 means no lora should be applied. + 4. batch_size: Batch size after clustering identical lora indices. + 5. max_length: The maximum sequence length in the batch. + 6. token_nums: The token numbers in the batch. + """ + return (self._seq_start_locs[:self.batch_size], + self._seq_lengths[:self.batch_size], + self._lora_indices_per_batch[:self.batch_size], + self.batch_size, self.max_length, self.token_nums) + + @property + def token_lora_indices(self) -> torch.Tensor: + """ + This property provides the lora indices corresponding to each token + in the batch. An index of -1 means no lora should be applied. + """ + token_lora_len = self.indices_len[0] + return self._token_lora_indices[:token_lora_len] + + @property + def sampler_indices(self) -> torch.Tensor: + """ + This property is used to access the lora indices specifically for + LogitsProcessorWithLoRA. + """ + sampler_indices_len = self.indices_len[1] + return self._sampler_indices[:sampler_indices_len] + + @property + def sampler_indices_padded(self) -> torch.Tensor: + """ + This property provides access to padded sampler indices. + """ + indices_padded_len = self.indices_len[2] + return self._sampler_indices_padded[:indices_padded_len] + + @property + def embeddings_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for lora embeddings, + specifically for VocabParallelEmbeddingWithLoRA. + """ + embeddings_indices_len = self.indices_len[3] + return self._embeddings_indices[:, :embeddings_indices_len] + + @property + def long_lora_indices(self) -> torch.Tensor: + """ + This property provides access to the indices used for long context + lora, specifically for LinearScalingRotaryEmbeddingWithLora. 
+ """ + long_lora_len = self.indices_len[4] + return self._long_lora_indices[:long_lora_len] + + def shrink_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_shrink( + x, + w_t_all, + y, + *self.prefill_metadata, + scale, + ) + + def shrink_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) + + def expand_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand( + x, + w_t_all, + y, + *self.prefill_metadata, + add_input, + ) + + def expand_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool, + ): + bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input) + + def expand_slice_prefill( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + sgmv_expand_slice( + x, + w_t_all, + y, + *self.prefill_metadata, + y_offset, + y_slice_size, + add_input, + ) + + def expand_slice_decode( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool, + ): + bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, + y_slice_size, add_input) + + def add_shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the shrink_decode function + should be called. + """ + # shrink_fun: Callable = (self.shrink_prefill + # if self.is_prefill else self.shrink_decode) + # shrink_fun(y, x, w_t_all, scale) + if self.is_prefill: + if self.no_lora: + return y + y = ops.sbgmv_shrink(x, w_t_all, y, *self.prefill_metadata, scale=scale) + else: + y = ops.sbgmv_shrink(x, w_t_all, y, lora_indices_tensor=self.token_lora_indices, scale=scale) + return y + + def add_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_input: bool = True, + ): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'b. + When `is_prefill` is true, it indicates that it is currently the + prefill stage, and the `expand_prefill` function should be called. + Otherwise, it is the decode stage, and the expand_decode function + should be called. 
+ """ + + # expand_fun: Callable = (self.expand_prefill + # if self.is_prefill else self.expand_decode) + # expand_fun(y, x, w_t_all, add_input) + if self.is_prefill: + if self.no_lora: + return y + y = ops.sbgmv_expand(x, w_t_all, y, *self.prefill_metadata, add_input=add_input) + else: + y = ops.sbgmv_expand(x, w_t_all, y, lora_indices_tensor=self.token_lora_indices, add_input=add_input) + + def add_expand_slice(self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_input: bool = True): + """ + Similar to `add_expand` + """ + + # expand_slice_fun: Callable = (self.expand_slice_prefill + # if self.is_prefill else + # self.expand_slice_decode) + # expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + if self.is_prefill: + if self.no_lora: + return y + ops.sbgmv_expand(x, w_t_all, y[:, y_offset:y_offset+y_slice_size], *self.prefill_metadata, add_input=add_input) + else: + ops.sbgmv_expand(x, w_t_all, y[:, y_offset:y_offset+y_slice_size], lora_indices_tensor=self.token_lora_indices, add_input=add_input) + + def add_lora(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale: float, + y_offset: Optional[int] = None, + y_slice_size: Optional[int] = None, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + wa_t_all (torch.Tensor): lora_a's weight + wb_t_all (torch.Tensor): lora_b's weight + scale (float): Scaling factor. + y_offset (Optional[int], optional): Offset to apply to the starting + column of y. + y_slice_size (Optional[int], optional): Size of the y column slice. + buffer (Optional[torch.Tensor], optional): Defaults to None. + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + buffer = self.add_shrink(buffer, x, wa_t_all, scale) + if y_offset is None and y_slice_size is None: + self.add_expand(y, buffer, wb_t_all, add_input=True) + else: + self.add_expand_slice(y, + buffer, + wb_t_all, + y_offset, + y_slice_size, + add_input=True) + y = y.view_as(y_org) + + def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, + torch.Tensor, + torch.Tensor], + scale: float, + output_slices: Tuple[int, ...]) -> None: + """ + Applies lora to each input. Similar to add_lora, This method is + used for layers that are composed of multiple sublayers + (slices) packed together. 
+ """ + y_org = y + x = x.view(-1, x.shape[-1]) + y = y.view(-1, y.shape[-1]) + offset_left = 0 + # TODO fuse these kernels + for slice_idx in range(len(output_slices)): + self.add_lora(y, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], scale, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None) -> None: + """ + LogitsProcessorWithLoRA always using bgmv + """ + y_org = y + y = y.view(-1, y.shape[-1]) + x = x.view(-1, x.shape[-1]) + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + # bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale) + buffer = ops.sbgmv_shrink(x, wa_t_all, buffer, lora_indices_tensor=self.sampler_indices, scale=scale) + # bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True) + y = ops.sbgmv_expand(buffer, wb_t_all, y, lora_indices_tensor=self.sampler_indices, add_input=True) + y = y.view_as(y_org) diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000..c4b26dc --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,95 @@ +import warnings +from typing import Optional + +import msgspec + +from vllm.adapter_commons.request import AdapterRequest + + +class LoRARequest( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """ + Request for a LoRA adapter. + + Note that this class should be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + __metaclass__ = AdapterRequest + + lora_name: str + lora_int_id: int + lora_path: str = "" + lora_local_path: Optional[str] = msgspec.field(default=None) + long_lora_max_len: Optional[int] = None + base_model_name: Optional[str] = msgspec.field(default=None) + + def __post_init__(self): + if 'lora_local_path' in self.__struct_fields__: + warnings.warn( + "The 'lora_local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'lora_path' instead.", + DeprecationWarning, + stacklevel=2) + if not self.lora_path: + self.lora_path = self.lora_local_path or "" + + # Ensure lora_path is not empty + assert self.lora_path, "lora_path cannot be empty" + + @property + def adapter_id(self): + return self.lora_int_id + + @property + def name(self): + return self.lora_name + + @property + def path(self): + return self.lora_path + + @property + def local_path(self): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. " + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + return self.lora_path + + @local_path.setter + def local_path(self, value): + warnings.warn( + "The 'local_path' attribute is deprecated " + "and will be removed in a future version. 
" + "Please use 'path' instead.", + DeprecationWarning, + stacklevel=2) + self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000..a780429 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,189 @@ +import os +import re +from typing import List, Optional, Set, Tuple, Type, Union + +import huggingface_hub +from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, + HFValidationError, RepositoryNotFoundError) +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LinearScalingRotaryEmbeddingWithLora, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + ReplicatedLinearWithLoRA, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead + +logger = init_logger(__name__) + +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, + LogitsProcessorWithLoRA, + ColumnParallelLinearWithShardedLoRA, + QKVParallelLinearWithShardedLora, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, + RowParallelLinearWithShardedLoRA, + LinearScalingRotaryEmbeddingWithLora, +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device, + 
lm_head.get_sharded_to_full_mapping()) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + """ + parts = name.split(".") + + if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model": + if parts[-1] == "weight": + if parts[-2] == "lora_A" or parts[-2] == "lora_B": + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported LoRA weight") + + +def is_regex_target_modules(load_modules: Union[str, List[str]], + expected_lora_modules: List[str]) -> bool: + """ + PEFT supports passing `target_modules` in the form of regular expressions, + such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to + determine whether the suffix in the regular expression is present in the + `expected_lora_modules`. + """ + + def is_valid_regex(pattern): + try: + re.compile(pattern) + return True + except re.error: + return False + + def is_subset(sub_list, full_list): + return set(sub_list).issubset(set(full_list)) + + # Similar to PEFT's processing logic, regex-related operations are only + # executed when the load_modules is a `str`. + if not isinstance(load_modules, str): + return False + + if is_valid_regex(load_modules): + match = re.search(r"\((.*?)\)\$?$", load_modules) + if match: + suffix = match.group(1).split("|") + return is_subset(suffix, expected_lora_modules) + return False + + +def get_adapter_absolute_path(lora_path: str) -> str: + """ + Resolves the given lora_path to an absolute local path. + + If the lora_path is identified as a Hugging Face model identifier, + it will download the model and return the local snapshot path. + Otherwise, it treats the lora_path as a local file path and + converts it to an absolute path. + + Parameters: + lora_path (str): The path to the lora model, which can be an absolute path, + a relative path, or a Hugging Face model identifier. + + Returns: + str: The resolved absolute local path to the lora model. + """ + + # Check if the path is an absolute path. Return it no matter exists or not. + if os.path.isabs(lora_path): + return lora_path + + # If the path starts with ~, expand the user home directory. + if lora_path.startswith('~'): + return os.path.expanduser(lora_path) + + # Check if the expanded relative path exists locally. + if os.path.exists(lora_path): + return os.path.abspath(lora_path) + + # If the path does not exist locally, assume it's a Hugging Face repo. 
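    # Editorial sketch (not part of this patch): illustrative resolution
    # behavior of get_adapter_absolute_path; the paths below are hypothetical.
    #   get_adapter_absolute_path("/data/loras/sql")    -> "/data/loras/sql" (absolute, returned as-is)
    #   get_adapter_absolute_path("~/loras/sql")        -> home-expanded absolute path
    #   get_adapter_absolute_path("./loras/sql")        -> absolute path, if it exists locally
    #   get_adapter_absolute_path("some-org/some-lora") -> Hugging Face snapshot directory,
    #                                                      or the original string if the download fails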
+ try: + local_snapshot_path = huggingface_hub.snapshot_download( + repo_id=lora_path) + except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError, + HFValidationError): + # Handle errors that may occur during the download + # Return original path instead instead of throwing error here + logger.exception("Error downloading the HuggingFace model") + return lora_path + + return local_snapshot_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000..724c308 --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,212 @@ +from contextlib import contextmanager +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import LoRAConfig +from vllm.logger import init_logger +from vllm.lora.models import (LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.request import LoRARequest +from vllm.lora.utils import get_adapter_absolute_path + +logger = init_logger(__name__) + + +class WorkerLoRAManager(AbstractWorkerManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + embedding_modules: Dict[str, str], + embedding_padding_modules: List[str], + lora_model_cls: Type[LoRAModel] = LoRAModel, + max_position_embeddings: Optional[int] = None, + ): + self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) + # Lazily initialized by create_lora_manager. 
+ self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._manager_cls, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: + try: + model = self._adapter_manager.model + supported_lora_modules = model.supported_lora_modules + packed_modules_mapping = model.packed_modules_mapping + expected_lora_modules: List[str] = [] + for module in supported_lora_modules: + if module in packed_modules_mapping: + expected_lora_modules.extend( + packed_modules_mapping[module]) + else: + expected_lora_modules.append(module) + lora_path = get_adapter_absolute_path(lora_request.lora_path) + lora = self._lora_model_cls.from_local_checkpoint( + lora_path, + expected_lora_modules, + max_position_embeddings=self.max_position_embeddings, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + ) + except Exception as e: + raise RuntimeError(f"Loading lora {lora_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} " + f"is greater than lora_extra_vocab_size " + f"{self.lora_config.lora_extra_vocab_size}.") + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_adapters(): + return False + if isinstance(self._cached_dummy_lora, LoRAModel): + dummy_lora = self._cached_dummy_lora.clone( + lora_request.lora_int_id) + else: + dummy_lora = self._adapter_manager.create_dummy_lora( + lora_request.lora_int_id, rank, 1, self.embedding_modules) + if self._cached_dummy_lora is None: + self._cached_dummy_lora = dummy_lora + return self._adapter_manager.add_adapter(dummy_lora) + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + 
def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + lora_manager_cls=self._manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._adapter_manager = lora_manager + return lora_manager.model + + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._adapter_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._adapter_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_adapter(lora) + + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): + # Remove before we load the new lora to save memory + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + lora = self._load_adapter(lora_request) + loaded = self._adapter_manager.add_adapter(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + lora_request.lora_int_id) is not None + self._adapter_manager.activate_adapter(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py new file mode 100644 index 0000000..7278c7f --- /dev/null +++ b/vllm/model_executor/__init__.py @@ -0,0 +1,13 @@ +from vllm.model_executor.parameter import (BasevLLMParameter, + PackedvLLMParameter) +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingMetadataCache) +from vllm.model_executor.utils import set_random_seed + +__all__ = [ + "SamplingMetadata", + "SamplingMetadataCache", + "set_random_seed", + "BasevLLMParameter", + "PackedvLLMParameter", +] diff --git a/vllm/model_executor/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3e896cf Binary files /dev/null and b/vllm/model_executor/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/__pycache__/custom_op.cpython-310.pyc b/vllm/model_executor/__pycache__/custom_op.cpython-310.pyc new file mode 100644 index 0000000..3548469 Binary files /dev/null and b/vllm/model_executor/__pycache__/custom_op.cpython-310.pyc differ diff --git a/vllm/model_executor/__pycache__/parameter.cpython-310.pyc b/vllm/model_executor/__pycache__/parameter.cpython-310.pyc new file mode 100644 index 0000000..fe4df37 Binary files /dev/null and b/vllm/model_executor/__pycache__/parameter.cpython-310.pyc differ diff --git 
a/vllm/model_executor/__pycache__/pooling_metadata.cpython-310.pyc b/vllm/model_executor/__pycache__/pooling_metadata.cpython-310.pyc new file mode 100644 index 0000000..354b163 Binary files /dev/null and b/vllm/model_executor/__pycache__/pooling_metadata.cpython-310.pyc differ diff --git a/vllm/model_executor/__pycache__/sampling_metadata.cpython-310.pyc b/vllm/model_executor/__pycache__/sampling_metadata.cpython-310.pyc new file mode 100644 index 0000000..c02d291 Binary files /dev/null and b/vllm/model_executor/__pycache__/sampling_metadata.cpython-310.pyc differ diff --git a/vllm/model_executor/__pycache__/utils.cpython-310.pyc b/vllm/model_executor/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..4036015 Binary files /dev/null and b/vllm/model_executor/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py new file mode 100644 index 0000000..d0e9024 --- /dev/null +++ b/vllm/model_executor/custom_op.py @@ -0,0 +1,71 @@ +import torch.nn as nn + +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.platforms import current_platform +from vllm.utils import is_cpu, is_hip, is_xpu + + +class CustomOp(nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__() + self._forward_method = self.dispatch_forward() + + def forward(self, *args, **kwargs): + return self._forward_method(*args, **kwargs) + + def forward_native(self, *args, **kwargs): + """PyTorch-native implementation of the forward method. + + This method is optional. If implemented, it can be used with compilers + such as torch.compile or PyTorch XLA. Also, it can be used for testing + purposes. + """ + raise NotImplementedError + + def forward_cuda(self, *args, **kwargs): + raise NotImplementedError + + def forward_hip(self, *args, **kwargs): + # By default, we assume that HIP ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_xpu(self, *args, **kwargs): + # By default, we assume that XPU ops are compatible with the + # PyTorch-native implementation. + return self.forward_native(*args, **kwargs) + + def forward_cpu(self, *args, **kwargs): + # By default, we assume that CPU ops are compatible with CUDA ops. + return self.forward_cuda(*args, **kwargs) + + def forward_tpu(self, *args, **kwargs): + # By default, we assume that TPU ops are compatible with the + # PyTorch-native implementation. + # NOTE(woosuk): This is a placeholder for future extensions. + return self.forward_native(*args, **kwargs) + + def forward_gaudi(self, *args, **kwargs): + # By default, we assume that Gaudi ops are compatible with the + # PyTorch-native implementation. + # NOTE(woosuk): This is a placeholder for future extensions. + return self.forward_native(*args, **kwargs) + + def dispatch_forward(self): + # NOTE(woosuk): Here we assume that vLLM was built for only one + # specific backend. Currently, we do not support dynamic dispatching. 
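        # Editorial sketch (not part of this patch): a minimal CustomOp
        # subclass only needs to supply the backend forwards it cares about,
        # e.g. (ScaleOp is a hypothetical example):
        #
        #     class ScaleOp(CustomOp):
        #         def forward_native(self, x):
        #             return 2 * x
        #         def forward_cuda(self, x):
        #             return 2 * x  # a fused kernel could be called here instead
        #
        # dispatch_forward() below binds exactly one of these methods to
        # self._forward_method at construction time, so forward() performs no
        # per-call backend branching.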
+ + if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR: + return self.forward_native + + if is_hip(): + return self.forward_hip + elif is_cpu(): + return self.forward_cpu + elif current_platform.is_tpu(): + return self.forward_tpu + elif is_xpu(): + return self.forward_xpu + else: + return self.forward_cuda diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py new file mode 100644 index 0000000..368436a --- /dev/null +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -0,0 +1,45 @@ +from typing import Optional + +from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor + + +async def get_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer) -> Optional[LogitsProcessor]: + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines' or guided_params.grammar: + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_outlines_guided_decoding_logits_processor) + return await get_outlines_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. " + "Must be one of 'outlines, 'lm-format-enforcer'") + + +def get_local_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer) -> Optional[LogitsProcessor]: + # CFG grammar not supported by LMFE, so we use outlines instead + if guided_params.backend == 'outlines' or guided_params.grammar: + # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 + from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa + get_local_outlines_guided_decoding_logits_processor) + return get_local_outlines_guided_decoding_logits_processor( + guided_params, tokenizer) + if guided_params.backend == 'lm-format-enforcer': + from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import ( # noqa + get_local_lm_format_enforcer_guided_decoding_logits_processor) + return get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params, tokenizer) + + raise ValueError( + f"Unknown guided decoding backend '{guided_params.backend}'. 
" + "Must be one of 'outlines, 'lm-format-enforcer'") diff --git a/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..35cbb04 Binary files /dev/null and b/vllm/model_executor/guided_decoding/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-310.pyc b/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-310.pyc new file mode 100644 index 0000000..8db14b7 Binary files /dev/null and b/vllm/model_executor/guided_decoding/__pycache__/guided_fields.cpython-310.pyc differ diff --git a/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-310.pyc b/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-310.pyc new file mode 100644 index 0000000..071dcb0 Binary files /dev/null and b/vllm/model_executor/guided_decoding/__pycache__/lm_format_enforcer_decoding.cpython-310.pyc differ diff --git a/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-310.pyc b/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-310.pyc new file mode 100644 index 0000000..8de5d11 Binary files /dev/null and b/vllm/model_executor/guided_decoding/__pycache__/outlines_decoding.cpython-310.pyc differ diff --git a/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-310.pyc b/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-310.pyc new file mode 100644 index 0000000..833a519 Binary files /dev/null and b/vllm/model_executor/guided_decoding/__pycache__/outlines_logits_processors.cpython-310.pyc differ diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py new file mode 100644 index 0000000..8deb4c9 --- /dev/null +++ b/vllm/model_executor/guided_decoding/guided_fields.py @@ -0,0 +1,39 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, TypedDict, Union + +from pydantic import BaseModel + + +# These classes are deprecated, see SamplingParams +class LLMGuidedOptions(TypedDict, total=False): + guided_json: Union[Dict, BaseModel, str] + guided_regex: str + guided_choice: List[str] + guided_grammar: str + guided_decoding_backend: str + guided_whitespace_pattern: str + guided_json_object: bool + + +@dataclass +class GuidedDecodingRequest: + """One of the fields will be used to retrieve the logit processor.""" + guided_json: Optional[Union[Dict, BaseModel, str]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + guided_grammar: Optional[str] = None + guided_decoding_backend: Optional[str] = None + guided_whitespace_pattern: Optional[str] = None + guided_json_object: Optional[bool] = None + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.guided_json is not None, self.guided_regex is not None, + self.guided_choice is not None, self.guided_grammar is not None, + self.guided_json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py new file mode 100644 index 0000000..cf2162e --- /dev/null +++ 
b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -0,0 +1,63 @@ +from functools import lru_cache +from json import loads as json_loads +from typing import Optional, Union + +from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser, + RegexParser, StringParser, + TokenEnforcerTokenizerData, UnionParser) +from lmformatenforcer.integrations.vllm import ( + build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data) +from transformers import PreTrainedTokenizerBase + +from vllm.sampling_params import GuidedDecodingParams, LogitsProcessor + + +def get_local_lm_format_enforcer_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer) -> Optional[LogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + + tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer) + character_level_parser: CharacterLevelParser + if guided_params.json: + schema_dict = _normalize_json_schema_object(guided_params.json) + character_level_parser = JsonSchemaParser(schema_dict) + elif guided_params.choice: + character_level_parser = UnionParser( + [StringParser(choice) for choice in guided_params.choice]) + elif guided_params.regex: + character_level_parser = RegexParser(guided_params.regex) + elif guided_params.grammar: + # CFG grammar not supported by LMFE + raise ValueError("Cannot construct a guided decoding logits processor" + " using the grammar option with the" + " lm_format_enforcer backend.") + elif guided_params.json_object: + # None means any json object + character_level_parser = JsonSchemaParser(None) + else: + return None + + logits_processor = build_vllm_logits_processor(tokenizer_data, + character_level_parser) + return logits_processor + + +def _normalize_json_schema_object(schema: Union[str, dict]) -> dict: + if isinstance(schema, str): + return json_loads(schema) + if isinstance(schema, dict): + return schema + raise AssertionError(f"Unsupported schema type {schema}") + + +@lru_cache +def _cached_build_vllm_token_enforcer_tokenizer_data( + tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData: + return build_vllm_token_enforcer_tokenizer_data(tokenizer) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py new file mode 100644 index 0000000..8a7ff38 --- /dev/null +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -0,0 +1,133 @@ +import asyncio +import concurrent.futures +from enum import Enum +from json import dumps as json_dumps +from re import escape as regex_escape +from typing import Tuple, Union + +from transformers import PreTrainedTokenizerBase + +from vllm.model_executor.guided_decoding.outlines_logits_processors import ( + CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams + + +class GuidedDecodingMode(Enum): + JSON = "json" + REGEX = "regex" + CHOICE = "choice" + GRAMMAR = "grammar" + + +# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark +# the main difference is that we changed the start: value to +# start: object | array, so we are denying scalar values as the root of the +# JSON. Starting with scalars as the root seems to cause llama to generate +# without stop. 
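# Editorial note (not part of this patch): with the root restricted to
# object | array, completions such as '{"answer": 42}' or '[1, 2, 3]' remain
# valid, while a bare scalar root like '42' or '"hello"' is rejected.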
+JSON_GRAMMAR = r""" +?start: object | array + +?value: object +| array +| UNESCAPED_STRING +| SIGNED_NUMBER -> number +| "true" -> true +| "false" -> false +| "null" -> null + +array : "[" [value ("," value)*] "]" +object : "{" [pair ("," pair)*] "}" +pair : UNESCAPED_STRING ":" value + +%import common.UNESCAPED_STRING +%import common.SIGNED_NUMBER +%import common.WS + +%ignore WS +""" + +global_thread_pool = None # used for generating logits processor fsm + + +async def get_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + global global_thread_pool + guide, mode = _get_guide_and_mode(guided_params) + if not guide or not mode: + return None + + if global_thread_pool is None: + global_thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=2) + loop = asyncio.get_running_loop() + + return await loop.run_in_executor(global_thread_pool, + _get_logits_processor, guide, tokenizer, + mode, guided_params.whitespace_pattern) + + +def get_local_outlines_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizerBase +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, + None]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. 
+ """ + guide, mode = _get_guide_and_mode(guided_params) + if not guide or not mode: + return None + + return _get_logits_processor(guide, tokenizer, mode, + guided_params.whitespace_pattern) + + +def _get_guide_and_mode( + guided_params: GuidedDecodingParams +) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: + if guided_params.json: + if isinstance(guided_params.json, dict): + # turn dict into hashable string + json = json_dumps(guided_params.json) + else: + json = guided_params.json + return json, GuidedDecodingMode.JSON + elif guided_params.regex: + return guided_params.regex, GuidedDecodingMode.REGEX + elif guided_params.choice: + # choice just uses regex + choices = [ + regex_escape(str(choice)) for choice in guided_params.choice + ] + choices_regex = "(" + "|".join(choices) + ")" + return choices_regex, GuidedDecodingMode.CHOICE + elif guided_params.grammar: + return guided_params.grammar, GuidedDecodingMode.GRAMMAR + elif guided_params.json_object: + return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR + else: + return None, None + + +def _get_logits_processor( + guide: str, tokenizer: PreTrainedTokenizerBase, mode: GuidedDecodingMode, + whitespace_pattern: Union[str, None] +) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: + if mode == GuidedDecodingMode.JSON: + return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern) + elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: + return RegexLogitsProcessor(guide, tokenizer) + elif mode == GuidedDecodingMode.GRAMMAR: + return CFGLogitsProcessor(guide, tokenizer) + else: + raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py new file mode 100644 index 0000000..c28bd71 --- /dev/null +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -0,0 +1,214 @@ +# Copyright 2024- the Outlines developers +# This file is adapted from +# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import json +import math +from collections import defaultdict +from functools import lru_cache +from typing import Callable, DefaultDict, Dict, List, Union + +import torch +from lark import Lark +from outlines import grammars +from outlines.caching import cache +from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write +from outlines.fsm.json_schema import build_regex_from_schema +from pydantic import BaseModel +from transformers import PreTrainedTokenizerBase + + +class BaseLogitsProcessor: + + def __init__(self, guide: Guide): + self._guide: Guide = guide + self._fsm_state: DefaultDict[int, int] = defaultdict(int) + + def __call__(self, input_ids: List[int], + scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + seq_id = hash(tuple(input_ids)) + + if len(input_ids) > 0: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self._fsm_state[seq_id] = self._guide.get_next_state( + state=self._fsm_state[last_seq_id], token_id=last_token) + else: + # Note: this is a hack. + # Lark pickling does not work properly (silent failure), + # which breaks the RPC (which uses python pickleing). + # We need to find a better solution. + # On the first time this is called, we simply re-create + # the Lark object. + if isinstance(self._guide, CFGGuide): + self._guide.parser = Lark( + self._guide.cfg_string, + parser="lalr", + lexer="contextual", + propagate_positions=False, + maybe_placeholders=False, + regex=True, + import_paths=[grammars.GRAMMAR_PATH], + ) + + instruction = self._guide.get_next_instruction( + state=self._fsm_state[seq_id]) + + if type(instruction) == Generate: # noqa: E721 + allowed_tokens = instruction.tokens + elif type(instruction) == Write: # noqa: E721 + # TODO: support fast forward tokens + allowed_tokens = [instruction.tokens[0]] + else: + raise TypeError( + f"Unsupported instruction type {type(instruction)}") + + mask = torch.full((scores.shape[-1], ), + -math.inf, + device=scores.device) + mask[allowed_tokens] = 0 + scores.add_(mask) + return scores + + +class RegexLogitsProcessor(BaseLogitsProcessor): + + @classmethod + @cache() + def _get_guide(cls, regex_string: str, + tokenizer: PreTrainedTokenizerBase) -> Guide: + tokenizer = _adapt_tokenizer(tokenizer) + return RegexGuide(regex_string, tokenizer) + + def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + The model's tokenizer + + """ + super().__init__( + RegexLogitsProcessor._get_guide(regex_string, tokenizer)) + + +class JSONLogitsProcessor(RegexLogitsProcessor): + + def __init__(self, schema: Union[str, Dict, BaseModel], + tokenizer: PreTrainedTokenizerBase, + whitespace_pattern: Union[str, None]): + """Compile the FSM that drives the JSON-guided generation. 
+ + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to + generate + tokenizer + The model's tokenizer + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact + string literals) + Example: allow only a single space or newline with + `whitespace_pattern=r"[\n ]?"` + """ + if isinstance(schema, type(BaseModel)): + schema_str = json.dumps(schema.model_json_schema()) + elif isinstance(schema, Dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + else: + raise ValueError( + f"Cannot parse schema {schema}. The schema must be either " + f"a Pydantic object, a dictionary or a string that contains " + f"the JSON Schema specification") + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string, tokenizer) + + +class CFGLogitsProcessor(BaseLogitsProcessor): + + @classmethod + @cache() + def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: + tokenizer = _adapt_tokenizer(tokenizer) + return CFGGuide(cfg, tokenizer) + + def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): + """Compile the FSM that drives the context free grammar generation. + + Parameters + ---------- + cfg + A string that represents a context-free grammar + tokenizer + The model's tokenizer + + """ + super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer)) + self._guide = self._guide.copy() + + +@lru_cache(maxsize=32) +def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. The decoder of outlines, returns a list whereas + the decode of vLLM returns an str. To sync the vLLM decoder with + outlines internal api, the decoder should be adapted. In addition + we need to handle the missing spaces to Llama's tokenizer to be + able to compile FSMs for this model. 
+ + """ + if getattr(tokenizer, "_outlines_adapted", False): + return tokenizer + + tokenizer = copy.deepcopy(tokenizer) + + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + def change_decoder( + decoder: Callable[[List[int]], + str]) -> Callable[[List[int]], List[str]]: + """Sync vLLM's decoder with the outlines by returning list.""" + + def new_decoder(inp_tokens: List[int]) -> List[str]: + return [decoder(inp_tokens)] + + return new_decoder + + tokenizer.convert_token_to_string = convert_token_to_string + tokenizer.decode = change_decoder(tokenizer.decode) + setattr(tokenizer, "_outlines_adapted", True) # noqa: B010 + + return tokenizer diff --git a/vllm/model_executor/layers/__init__.py b/vllm/model_executor/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..49d5dd0 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/activation.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/activation.cpython-310.pyc new file mode 100644 index 0000000..a6a2ab0 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/activation.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/layernorm.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/layernorm.cpython-310.pyc new file mode 100644 index 0000000..0f09c8c Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/layernorm.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/linear.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/linear.cpython-310.pyc new file mode 100644 index 0000000..be35bf4 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/linear.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/logits_processor.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/logits_processor.cpython-310.pyc new file mode 100644 index 0000000..bf46f27 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/logits_processor.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/pooler.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/pooler.cpython-310.pyc new file mode 100644 index 0000000..a97b752 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/pooler.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/rejection_sampler.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/rejection_sampler.cpython-310.pyc new file mode 100644 index 0000000..e99e319 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/rejection_sampler.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/resampler.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/resampler.cpython-310.pyc new file mode 100644 index 0000000..7402f04 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/resampler.cpython-310.pyc differ diff --git 
a/vllm/model_executor/layers/__pycache__/rotary_embedding.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/rotary_embedding.cpython-310.pyc new file mode 100644 index 0000000..4e5f06b Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/rotary_embedding.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/sampler.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/sampler.cpython-310.pyc new file mode 100644 index 0000000..22d4863 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/sampler.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/spec_decode_base_sampler.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/spec_decode_base_sampler.cpython-310.pyc new file mode 100644 index 0000000..879c57f Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/spec_decode_base_sampler.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/typical_acceptance_sampler.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/typical_acceptance_sampler.cpython-310.pyc new file mode 100644 index 0000000..6c3ed20 Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/typical_acceptance_sampler.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/__pycache__/vocab_parallel_embedding.cpython-310.pyc b/vllm/model_executor/layers/__pycache__/vocab_parallel_embedding.cpython-310.pyc new file mode 100644 index 0000000..2c1fcda Binary files /dev/null and b/vllm/model_executor/layers/__pycache__/vocab_parallel_embedding.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py new file mode 100644 index 0000000..4305678 --- /dev/null +++ b/vllm/model_executor/layers/activation.py @@ -0,0 +1,252 @@ +"""Custom activation functions.""" +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.utils import set_weight_attrs + + +class SiluAndMul(CustomOp): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d) + return: (num_tokens, d) or (batch_size, seq_len, d) + """ + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + from vllm import _custom_ops as ops + + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + ops.silu_and_mul(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + ops.silu_and_mul(out, x) + return out + + +class GeluAndMul(CustomOp): + """An activation function for GeGLU. + + The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
+ + Shapes: + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) + """ + + def __init__(self, approximate: str = "none"): + super().__init__() + self.approximate = approximate + if approximate not in ("none", "tanh"): + raise ValueError(f"Unknown approximate mode: {approximate}") + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + from vllm import _custom_ops as ops + + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if self.approximate == "none": + ops.gelu_and_mul(out, x) + elif self.approximate == "tanh": + ops.gelu_tanh_and_mul(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + if self.approximate == "none": + ops.gelu_and_mul(out, x) + elif self.approximate == "tanh": + ops.gelu_tanh_and_mul(out, x) + return out + + def extra_repr(self) -> str: + return f'approximate={repr(self.approximate)}' + + +class NewGELU(CustomOp): + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + c = math.sqrt(2.0 / math.pi) + return 0.5 * x * (1.0 + torch.tanh(c * + (x + 0.044715 * torch.pow(x, 3.0)))) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + from vllm import _custom_ops as ops + + out = torch.empty_like(x) + ops.gelu_new(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + return ops.gelu_new(x) + + +class FastGELU(CustomOp): + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * + (1.0 + 0.044715 * x * x))) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + from vllm import _custom_ops as ops + + out = torch.empty_like(x) + ops.gelu_fast(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + return ops.gelu_fast(x) + + +class QuickGELU(CustomOp): + + # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90 + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return x * torch.sigmoid(1.702 * x) + + def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + from vllm import _custom_ops as ops + + out = torch.empty_like(x) + ops.gelu_quick(out, x) + return out + + def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + from vllm._ipex_ops import ipex_ops as ops + + out = torch.empty_like(x) + ops.gelu_quick(out, x) + return out + + # TODO implement forward_xpu for QuickGELU + # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: + + +class ReLUSquaredActivation(CustomOp): + """ + Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2 + """ + + def forward_native(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return torch.square(F.relu(x)) + + def forward_cuda(self, x: 
torch.Tensor) -> torch.Tensor: + return self.forward_native(x) + + +class ScaledActivation(nn.Module): + """An activation function with post-scale parameters. + + This is used for some quantization methods like AWQ. + """ + + def __init__( + self, + act_module: nn.Module, + intermediate_size: int, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.act = act_module + self.input_is_parallel = input_is_parallel + if input_is_parallel: + tp_size = get_tensor_model_parallel_world_size() + intermediate_size_per_partition = divide(intermediate_size, + tp_size) + else: + intermediate_size_per_partition = intermediate_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.scales = nn.Parameter( + torch.empty(intermediate_size_per_partition, dtype=params_dtype)) + set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(x) / self.scales + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +_ACTIVATION_REGISTRY = { + "gelu": nn.GELU(), + "gelu_fast": FastGELU(), + "gelu_new": NewGELU(), + "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), + "relu": nn.ReLU(), + "relu2": ReLUSquaredActivation(), + "quick_gelu": QuickGELU(), +} + + +def get_act_fn( + act_fn_name: str, + quant_config: Optional[QuantizationConfig] = None, + intermediate_size: Optional[int] = None, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + act_fn = _ACTIVATION_REGISTRY[act_fn_name] + if (quant_config is not None + and act_fn_name in quant_config.get_scaled_act_names()): + if intermediate_size is None: + raise ValueError("intermediate_size must be specified for scaled " + "activation functions.") + return ScaledActivation(act_fn, intermediate_size, input_is_parallel, + params_dtype) + return act_fn diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 0000000..031c2ca --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,26 @@ +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.triton_utils import HAS_TRITON + +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", +] + +# if HAS_TRITON: +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) + +__all__ += [ + "fused_marlin_moe", + "single_marlin_moe", + "fused_moe", + "fused_topk", + "fused_experts", + "get_config_file_name", + "grouped_topk", +] diff --git a/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-310.pyc new 
file mode 100644 index 0000000..51e1059 Binary files /dev/null and b/vllm/model_executor/layers/fused_moe/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-310.pyc b/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-310.pyc new file mode 100644 index 0000000..eea9303 Binary files /dev/null and b/vllm/model_executor/layers/fused_moe/__pycache__/fused_marlin_moe.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-310.pyc b/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-310.pyc new file mode 100644 index 0000000..ed50e46 Binary files /dev/null and b/vllm/model_executor/layers/fused_moe/__pycache__/fused_moe.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-310.pyc b/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-310.pyc new file mode 100644 index 0000000..a3b03b7 Binary files /dev/null and b/vllm/model_executor/layers/fused_moe/__pycache__/layer.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-310.pyc b/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-310.pyc new file mode 100644 index 0000000..a09e693 Binary files /dev/null and b/vllm/model_executor/layers/fused_moe/__pycache__/moe_pallas.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py new file mode 100644 index 0000000..5964d5a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -0,0 +1,313 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size, try_get_optimal_moe_config) +from vllm.scalar_type import scalar_types + + +def get_scalar_type(num_bits: int, has_zp: bool): + if has_zp: + assert num_bits == 4 + return scalar_types.uint4 + else: + return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 + + +def single_marlin_moe( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + g_idx: Optional[torch.Tensor] = None, + sort_indices: Optional[torch.Tensor] = None, + w_zeros: Optional[torch.Tensor] = None, + override_config: Optional[Dict[str, Any]] = None, + num_bits: int = 8, + is_k_full: bool = True, +) -> torch.Tensor: + """ + This function computes the multiplication of hidden_states with expert + weights used in Marlin MoE, using weights w and top-k gating mechanism. + Its purpose is testing and debugging the fused MoE kernel. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. + - w (torch.Tensor): The set of expert weights. + - scales (torch.Tensor): The quantization scales. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx (Optional[torch.Tensor]): Optional act_order indices. + - sort_indices (Optional[torch.Tensor]): Optional act_order input + permutation. + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - w_zeros (Optional[torch.Tensor]): Optional zero points to be used for w. 
+ - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - num_bits (bool): The number of bits in expert weights quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // (num_bits // 2) + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device=hidden_states.device, + requires_grad=False) + + has_zero_point = w_zeros is not None + if w_zeros is None: + w_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if g_idx is None: + g_idx = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + if sort_indices is None: + sort_indices = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type = get_scalar_type(num_bits, has_zero_point) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + w_zeros, g_idx, sort_indices, workspace, scalar_type, M, N, K, + is_k_full, E, topk, block_size_m, True, False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + gating_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + g_idx1: Optional[torch.Tensor] = None, + g_idx2: Optional[torch.Tensor] = None, + sort_indices1: Optional[torch.Tensor] = None, + sort_indices2: Optional[torch.Tensor] = None, + w1_zeros: Optional[torch.Tensor] = None, + w2_zeros: Optional[torch.Tensor] = None, + override_config: Optional[Dict[str, Any]] = None, + num_bits: int = 8, + is_k_full: bool = True, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - w1_scale (torch.Tensor): Scale to be used for w1. + - w2_scale (torch.Tensor): Scale to be used for w2. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx1 (Optional[torch.Tensor]): The first set of act_order indices. 
+ - g_idx2 (Optional[torch.Tensor]): The second set of act_order indices. + - sort_indices1 (Optional[torch.Tensor]): The first act_order input + permutation. + - sort_indices2 (Optional[torch.Tensor]): The second act_order input + permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - w1_zeros (Optional[torch.Tensor]): Optional zero points to be used for w1. + - w2_zeros (Optional[torch.Tensor]): Optional zero points to be used for w2. + - num_bits (bool): The number of bits in expert weights quantization. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[1] == w2.shape[2] // ( + num_bits // 2), "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype == torch.float16 + assert num_bits in [4, 8] + + has_no_act_order = (g_idx1 is None and g_idx2 is None + and sort_indices1 is None and sort_indices2 is None) + has_all_act_order = (g_idx1 is not None and g_idx2 is not None + and sort_indices1 is not None + and sort_indices2 is not None) + assert has_no_act_order or has_all_act_order, ( + "g_idx and sorted_indices " + "must be all not None or must be all None") + + has_no_zp = w1_zeros is None and w2_zeros is None + has_all_zp = w1_zeros is not None and w2_zeros is not None + assert has_no_zp or has_all_zp, ("zero points must be both not None or " + "must be both None") + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + topk = topk_ids.shape[1] + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + if has_no_zp: + w1_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + w2_zeros = torch.empty((0, 0), + dtype=hidden_states.dtype, + device=hidden_states.device, + requires_grad=False) + + if has_no_act_order: + g_idx1 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + g_idx2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + sort_indices1 = torch.empty((0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + sort_indices2 = torch.empty((0, 0), + dtype=torch.int32, + device=hidden_states.device, + requires_grad=False) + + scalar_type1 = get_scalar_type(num_bits, has_all_zp) + scalar_type2 = get_scalar_type(num_bits, has_all_zp) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + 
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + w1_zeros, + g_idx1, + sort_indices1, + workspace, + scalar_type1, + M, + 2 * N, + K, + is_k_full, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + w2_zeros, + g_idx2, + sort_indices2, + workspace, + scalar_type2, + M, + K, + N, + is_k_full, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 0000000..03ffc2f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,693 @@ +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import triton +import triton.language as tl + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. 
The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + if use_int8_w8a16: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. 
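+ # The strides are expressed in elements, so advancing by BLOCK_SIZE_K moves the
+ # pointers to the next K tile of A and B; the masked loads above already
+ # zero-fill the K tail, so no extra bounds handling is needed here.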
+ a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. 
+ """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + # max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + max_num_m_blocks = topk_ids.numel() + num_experts + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: + ops.invoke_fused_moe_kernel(A,B,C,A_scale,B_scale,topk_weights,topk_ids,sorted_token_ids,expert_ids,num_tokens_post_padded,mul_routed_weight,top_k,config,compute_type,use_fp8_w8a8,use_int8_w8a16) + return + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if use_fp8_w8a8: + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + elif use_int8_w8a16: + assert B_scale is not None + else: + assert A_scale is None + assert B_scale is None + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, + B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + ) + + +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: + device_name = current_platform.get_device_name().replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + + +@functools.lru_cache +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. 
+ """ + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ("Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s"), config_file_path) + return None + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + numel = M * topk + if numel <= 64: + config['BLOCK_SIZE_M'] = 32 + elif numel <= 1024: + config['BLOCK_SIZE_M'] = 64 + else: + config['BLOCK_SIZE_M'] = 256 + return config + + +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + # configs = get_moe_configs(E, N, dtype) + configs = None + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin) + return config + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + M, _ = hidden_states.shape + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + + ops.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. 
+ + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +# This is used by the Deepseek-V2 model +def grouped_topk(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0): + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def get_config_dtype_str(dtype: torch.dtype, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + if use_fp8_w8a8 or use_int8_w8a16: + raise NotImplementedError("fused_experts has not implemented fp8_w8a8 and int8_w8a16 yet.") + if a1_scale is not None or a2_scale is not None: + raise NotImplementedError("fused_experts has not implemented static_w8a8 yet.") + + # Check constraints. 
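+ # Expected shapes: hidden_states is (num_tokens, K), w1 is (E, 2 * I, K) and
+ # w2 is (E, K, I), where E is the number of experts and I is the per-expert
+ # intermediate size.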
+ assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + dtype=hidden_states.dtype) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + config_dtype, + override_config=override_config, + ) + + config = get_config_func(M) + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. 
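+ # Slicing only narrows the existing buffers (views, no reallocation), and the
+ # kernel config is re-derived because the best block sizes depend on how many
+ # tokens are left in this final chunk.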
+ intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + + invoke_fused_moe_kernel(curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) + + torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx]) + return out_hidden_states + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
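+ # The router must emit one logit per expert, so gating_output's second
+ # dimension has to match the leading (expert) dimension of w1.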
+ assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) + + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py new file mode 100644 index 0000000..69ffb0f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -0,0 +1,717 @@ +from abc import abstractmethod +from enum import Enum +from typing import Callable, List, Optional, Tuple + +import torch + +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +# from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig +from vllm.model_executor.layers.quantization.compressed_tensors.utils import CompressionFormat, QuantizationStrategy +# from vllm.model_executor.layers.quantization.utils.w8a8_utils import create_per_channel_scale_param +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class FusedMoeWeightScaleSupported(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, x: torch.Tensor, + router_logits: torch.Tensor, top_k: int, renormalize: bool, + use_grouped_topk: bool) -> torch.Tensor: + raise NotImplementedError + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, 
+ num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: + + return self.forward(x=x, + layer=layer, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_experts) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_experts(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True) + + def forward_cpu(self, *args, **kwargs): + raise NotImplementedError( + "The CPU backend currently does not support MoE.") + + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + assert custom_routing_function is None + return fused_moe(hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk=top_k, + gating_output=router_logits, + renormalize=renormalize) + +class W8A8QuantizedFusedMoEMethod(FusedMoEMethodBase): + """MoE method W8A8 quantization. 
This class is for compressed-tensors format loading""" + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + self.strategy = extra_weight_attrs['quant_config'].target_scheme_map['Linear']['weights'].strategy + self.is_static_input_scheme = not extra_weight_attrs['quant_config'].target_scheme_map['Linear']['input_activations'].dynamic + # assert self.is_static_input_scheme, "W8A8 int quantization only support static input activation for now" + + self.quant_config = extra_weight_attrs["quant_config"] + self.weight_loader = extra_weight_attrs["weight_loader"] + + self.logical_widths_13 = [intermediate_size * 2] + self.logical_widths_2 = [intermediate_size * 2] + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=torch.int8), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, { + "input_dim": 1, + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + # WEIGHT SCALE + layer_kwargs = {"weight_loader": self.weight_loader} + if self.strategy == QuantizationStrategy.CHANNEL: + scale = torch.nn.Parameter(torch.empty((num_experts, intermediate_size * 2, 1), + dtype=torch.float32), + requires_grad=False) + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, {"input_dim": 1, "output_dim": 0, **layer_kwargs}) + else: + assert self.strategy == QuantizationStrategy.TENSOR + scale = torch.nn.Parameter(torch.empty((num_experts, 2), dtype=torch.float32), + requires_grad=False) + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, { + "needs_scalar_to_array": True, + **layer_kwargs + }) + set_weight_attrs(scale, {"is_int8_weight_scale": True}) + layer.register_parameter("w13_weight_scale", scale) + + + # INPUT SCALE + if self.is_static_input_scheme: + scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + set_weight_attrs(scale, { + "needs_scalar_to_array": True, + **layer_kwargs + }) + set_weight_attrs(scale, {"is_int8_input_scale": True}) + layer.register_parameter("w13_input_scale", scale) + + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=torch.int8), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + + set_weight_attrs(w2_weight, { + "input_dim": 1, + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + scale = torch.nn.Parameter(torch.empty((num_experts, hidden_size, 1), + dtype=torch.float32), + requires_grad=False) + scale[:] = torch.finfo(torch.float32).min + set_weight_attrs(scale, {"input_dim": 0, "output_dim": 1, **layer_kwargs}) + + else: + assert self.strategy == QuantizationStrategy.TENSOR + scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + set_weight_attrs(scale, { + "needs_scalar_to_array": True, + **layer_kwargs + }) + set_weight_attrs(scale, {"is_int8_weight_scale": True}) + layer.register_parameter("w2_weight_scale", scale) + + # INPUT SCALE + if self.is_static_input_scheme: + scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + set_weight_attrs(scale, { + "needs_scalar_to_array": True, + **layer_kwargs + }) + set_weight_attrs(scale, {"is_int8_input_scale": 
True}) + layer.register_parameter("w2_input_scale", scale) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + assert False, "fused_moe w8a8 use ixformer.contrib.vllm.layers.mixtral_decoder_layer_forward" + from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe + if self.is_static_input_scheme: + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + use_int8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + else: + return fused_moe(x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + use_int8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None) + +class FusedMoE(torch.nn.Module): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. 
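+ use_grouped_topk: Whether to use grouped top-k routing (DeepSeek-V2 style);
+ requires num_expert_group and topk_group to be set.
+ tp_size: Optional tensor-parallel size override; defaults to the current
+ tensor-parallel world size.
+ custom_routing_function: Optional callable that replaces the default
+ top-k routing.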
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = (tp_size if tp_size is not None else + get_tensor_model_parallel_world_size()) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + self.custom_routing_function = custom_routing_function + + from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod()) + elif (isinstance(quant_config, CompressedTensorsConfig) + and quant_config.quant_format == CompressionFormat.int_quantized.value + and quant_config.target_scheme_map['Linear']['input_activations'].num_bits == 8 + and quant_config.target_scheme_map['Linear']['weights'].num_bits == 8): + self.quant_method: Optional[QuantizeMethodBase] = ( + W8A8QuantizedFusedMoEMethod()) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + quant_config=quant_config) + + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. 
+ idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # Load grouped weight scales for group quantization + # or model weights + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. 
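+ # One scalar is stored per expert; for input scales the caller
+ # (weight_loader) verifies that w1 and w3 supply the same value before
+ # this overwrite.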
+ param_data[expert_id] = loaded_weight + + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + + # compressed-tensors checkpoints with packed weights are stored flipped + # TODO (mgoin): check self.quant_method.quant_config.quant_format + # against known CompressionFormat enum values that have this quality + loaded_weight = loaded_weight.t().contiguous() if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsWNA16MoEMethod") else loaded_weight + + if shard_id not in ("w1", "w2", "w3"): + raise ValueError(f"shard_id must be ['w1','w2','w3'] but " + f"got {shard_id}.") + + # Special case for fp8 scales. + if getattr(param, "is_fp8_scale", False): + self._load_fp8_scale(param.data, loaded_weight, weight_name, + shard_id, expert_id) + return + elif getattr(param, "is_int8_input_scale", False): + self._load_int8_input_scale(param.data, loaded_weight, + shard_id, expert_id) + return + + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size is used. + SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} + + expert_data = param.data[expert_id] + tp_rank = get_tensor_model_parallel_rank() + + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + shard_dim = ~shard_dim + + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + # this is needed for compressed-tensors only + loaded_weight = loaded_weight.to(param.data.device) + + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. 
{loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + # Case weight scales and zero_points + if ("scale" in weight_name or "zero" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", FusedMoeWeightScaleSupported.CHANNEL.value) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return + + # Case weight_shape + if "weight_shape" in weight_name: + # only required by compressed-tensors + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + @staticmethod + def select_experts(hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None): + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, grouped_topk) + + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) + + return topk_weights, topk_ids + + def forward(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. 
+ final_hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int) -> List[Tuple[str, str, int, str]]: + + return [ + # (param_name, weight_name, expert_id, shard_id) + ("experts.w13_" if weight_name + in [ckpt_gate_proj_name, ckpt_up_proj_name] else "experts.w2_", + f"experts.{expert_id}.{weight_name}.", expert_id, shard_id) + for expert_id in range(num_experts) for shard_id, weight_name in [ + ("w1", ckpt_gate_proj_name), + ("w2", ckpt_down_proj_name), + ("w3", ckpt_up_proj_name), + ] + ] + + def _load_fp8_scale(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, weight_name: str, + shard_id: str, expert_id: int) -> None: + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if param_data[expert_id] != 1 and (param_data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. {loaded_weight}") + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + else: + param_data[expert_id] = loaded_weight + + def _load_int8_input_scale(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + shard_id: str, expert_id: int) -> None: + param_data = param.data + # Input scales can be loaded directly and should be equal. + if param_data[expert_id] != 1 and (param_data[expert_id].to(loaded_weight.device) - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. 
{loaded_weight}") + param_data[expert_id] = loaded_weight \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py new file mode 100644 index 0000000..563ee18 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -0,0 +1,62 @@ +import torch +import torch.nn.functional as F +from torch_xla.experimental.custom_kernel import _histogram + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +) -> torch.Tensor: + """ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + device = hidden_states.device + dtype = hidden_states.dtype + assert (num_tokens * topk) % 16 == 0, ( + "The Pallas GMM kernel requires num_tokens * topk to be a multiple of " + f"16 but got {num_tokens * topk}") + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, topk_indices = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + topk_indices = topk_indices.flatten() + topk_argsort_indices = topk_indices.argsort() + topk_argsort_revert_indices = topk_argsort_indices.argsort() + token_indices = torch.arange(num_tokens, + device=device).repeat_interleave(topk) + token_indices = token_indices[topk_argsort_indices] + group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1) + + # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout + # from HF Transformers. + w1 = w1.transpose(1, 2) + w2 = w2.transpose(1, 2) + + x = hidden_states[token_indices] + x = torch.ops.xla.gmm(x, w1, group_sizes) + x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:] + x = torch.ops.xla.gmm(x, w2, group_sizes) + x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) + + x = x * topk_weights.unsqueeze_(dim=-1) + x = x.sum(dim=-2) + x = x.reshape(orig_shape) + return x diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py new file mode 100644 index 0000000..eb5be2d --- /dev/null +++ b/vllm/model_executor/layers/layernorm.py @@ -0,0 +1,187 @@ +"""Custom normalization layers.""" +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.model_executor.custom_op import CustomOp + + +class RMSNorm(CustomOp): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 
+ Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.variance_epsilon = eps + self.variance_size_override = (None if var_hidden_size == hidden_size + else var_hidden_size) + + self.weight = nn.Parameter(torch.ones(hidden_size)) + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError("Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}") + + if self.variance_size_override is None: + x_var = x + else: + if hidden_size < self.variance_size_override: + raise ValueError( + "Expected hidden_size to be at least " + f"{self.variance_size_override}, but found: {hidden_size}") + + x_var = x[:, :, :self.variance_size_override] + + variance = x_var.pow(2).mean(dim=-1, keepdim=True) + + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + residual_alpha: Optional[float] = 1.0, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + from vllm import _custom_ops as ops + + if residual is not None: + ops.fused_add_rms_norm( + x, + residual, + self.weight.data, + self.variance_epsilon, + residual_alpha, + ) + return x, residual + out = torch.empty_like(x) + ops.rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + return out + + def forward_xpu( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) + + from vllm._ipex_ops import ipex_ops as ops + + if residual is not None: + ops.fused_add_rms_norm( + x, + residual, + self.weight.data, + self.variance_epsilon, + ) + return x, residual + return ops.rms_norm( + x, + self.weight.data, + self.variance_epsilon, + ) + + def extra_repr(self) -> str: + s = f"hidden_size={self.weight.data.size(0)}" + s += f", eps={self.variance_epsilon}" + return s + + +class GemmaRMSNorm(CustomOp): + """RMS normalization for Gemma. + + Two differences from the above RMSNorm: + 1. x * (1 + w) instead of x * w. + 2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w. 
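+ The second difference matters numerically: the (1 + w) multiply is done
+ in float32 before casting back, matching the reference Gemma
+ implementation (see the HF Transformers PR linked in forward_static).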
+ """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + @staticmethod + def forward_static( + weight: torch.Tensor, + variance_epsilon: float, + x: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + if residual is not None: + x = x + residual + residual = x + + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + variance_epsilon) + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + x = x * (1.0 + weight.float()) + x = x.to(orig_dtype) + return x if residual is None else (x, residual) + + def forward_native( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + return self.forward_static(self.weight.data, self.variance_epsilon, x, + residual) + + def forward_cuda( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # if torch.compiler.is_compiling(): + # return self.forward_native(x, residual) + + # if not getattr(self, "_is_compiled", False): + # self.forward_static = torch.compile( # type: ignore + # self.forward_static) + # self._is_compiled = True + return self.forward_native(x, residual) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py new file mode 100644 index 0000000..8ed74ef --- /dev/null +++ b/vllm/model_executor/layers/linear.py @@ -0,0 +1,1103 @@ +from abc import abstractmethod +from typing import Dict, List, Optional, Tuple + +import torch +#import ixformer.inference.functions as F +import ixformer.functions as F +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.parameter import (BasevLLMParameter, + PackedColumnParameter, + PackedvLLMParameter, + PerTensorScaleParameter, + RowvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + +WEIGHT_LOADER_V2_SUPPORTED = [ + "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", + "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", + "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod" +] + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def adjust_bitsandbytes_4bit_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: + """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" + + 
total, _ = qkv_offsets["total"] + orig_offset, orig_size = qkv_offsets[loaded_shard_id] + + quantized_total = param.data.shape[0] + quantized_offset = orig_offset * quantized_total // total + quantized_size = orig_size * quantized_total // total + + return quantized_size, quantized_offset + + +def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): + """For fused modules (QKV and MLP) we have an array of length + N that holds 1 scale for each "logical" matrix. So the param + is an array of length N. The loaded_weight corresponds to + one of the shards on disk. Here, we slice the param based on + the shard_id for loading. + """ + qkv_idxs = {"q": 0, "k": 1, "v": 2} + + if isinstance(shard_id, str): + shard_id = qkv_idxs[shard_id] + elif not isinstance(shard_id, int): + raise ValueError(f"Unknown Shard Id {shard_id}") + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + return param[shard_id], loaded_weight + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ + raise NotImplementedError + + @abstractmethod + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization.""" + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if (x.shape[0] == 16384 or x.shape[0] == 15360): + if bias is None: + return x @ layer.weight.T + else: + return x @ layer.weight.T + bias + return F.linear(x, layer.weight, bias) + + +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. 
+ params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, + prefix=prefix) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix=prefix) + + # All the linear layer supports quant method. + assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size, [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + # If the weight on disk does not have a shape, give it one + # (such scales for AutoFp8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + return s + + +class ColumnParallelLinear(LinearBase): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. 
+ gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + output_sizes: list of output sizes packed into one output, like for QKV + the list would be size 3. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[List[int]] = None, + prefix: str = ""): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config, prefix) + + self.gather_output = gather_output + + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + assert self.quant_method is not None + self.output_size_per_partition = divide(self.output_size, tp_size) + self.output_partition_sizes = [self.output_size_per_partition] + # If QKV or MergedColumn, use output size of each partition. + if hasattr(self, "output_sizes"): + self.output_partition_sizes = [ + divide(output_size, tp_size) + for output_size in self.output_sizes + ] + + if output_sizes is None: + output_sizes = [output_size] + + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + + param_data = param.data + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if output_dim is not None and not use_bitsandbytes_4bit: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). 
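+ # e.g. an AutoFP8 per-tensor scale is serialized as a 0-dim tensor, so it
+ # is reshaped to (1,) for the shape check and copy below.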
+ if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + param.load_column_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + def extra_repr(self) -> str: + s = f"in_features={self.input_size}" + s += f", output_features={self.output_size_per_partition}" + s += f", bias={self.bias is not None}" + s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", gather_output={self.gather_output}" + return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. 
model.layers.0.qkv_proj) + """ + + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + return + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 2: + self.qweight = param.materialize_nested() + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + # Special case for per-tensor scale to load scalar into fused array. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). + if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. 
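+ # Marlin stores weights in tiles, so the pack-adjusted offset and size
+ # are additionally scaled by marlin_tile_size (see adjust_marlin_shard
+ # defined above).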
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where MLP layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + + current_shard_offset = 0 + shard_offsets: List[Tuple[int, int, int]] = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
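+ # Packed parameters (e.g. int4 weights) hold several logical columns per
+ # physical element, so shard_size and shard_offset are rescaled by the
+ # packing factor before narrowing.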
+ if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + if loaded_shard_id is None: + if isinstance(param, PerTensorScaleParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_merged_column_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id < len(self.output_sizes) + + tp_size = get_tensor_model_parallel_world_size() + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + + param.load_merged_column_weight(loaded_weight=loaded_weight, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + prefix: The name of the layer in the state dict, including all parents + (e.g. model.layers.0.qkv_proj) + """ + + def __init__(self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
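+ # Example (illustrative): with total_num_heads=32, total_num_kv_heads=8,
+ # head_size=128 and tp_size=4, each rank keeps num_heads=8 query heads
+ # and num_kv_heads=2 KV heads, i.e. a per-rank QKV output width of
+ # (8 + 2 * 2) * 128 = 1536. If tp_size exceeded total_num_kv_heads, each
+ # KV head would instead be replicated on tp_size // total_num_kv_heads
+ # ranks.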
+ tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] + + super().__init__(input_size=input_size, + output_size=output_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + + def _get_shard_offset_mapping(self, loaded_shard_id: str): + shard_offset_mapping = { + "q": 0, + "k": self.num_heads * self.head_size, + "v": (self.num_heads + self.num_kv_heads) * self.head_size, + "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size + } + return shard_offset_mapping.get(loaded_shard_id) + + def _get_shard_size_mapping(self, loaded_shard_id: str): + shard_size_mapping = { + "q": self.num_heads * self.head_size, + "k": self.num_kv_heads * self.head_size, + "v": self.num_kv_heads * self.head_size, + } + return shard_size_mapping.get(loaded_shard_id) + + def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + """ + Handle special case for models where QKV layers are already + fused on disk. In this case, we have no shard id. This function + determmines the shard id by splitting these layers and then calls + the weight loader using the shard id. + + An example of a model with these fused layers: + https://huggingface.co/microsoft/Phi-3-mini-4k-instruct + """ + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", + (self.total_num_heads + self.total_num_kv_heads) * self.head_size, + self.total_num_kv_heads * self.head_size), + ] + + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: + shard_size, shard_offset = \ + param.adjust_shard_indexes_for_packing( + shard_size=shard_size, shard_offset=shard_offset) + + loaded_weight_shard = loaded_weight.narrow(param.output_dim, + shard_offset, + shard_size) + self.weight_loader_v2(param, loaded_weight_shard, shard_id) + + def weight_loader_v2(self, + param: BasevLLMParameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + if loaded_shard_id is None: # special case for certain models + if isinstance(param, PerTensorScaleParameter): + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) + return + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) + return + # TODO: @dsikka - move to parameter.py + self._load_fused_module_from_checkpoint(param, loaded_weight) + return + + assert loaded_shard_id in ["q", "k", "v"] + + shard_offset = self._get_shard_offset_mapping(loaded_shard_id) + shard_size = self._get_shard_size_mapping(loaded_shard_id) + + param.load_qkv_weight(loaded_weight=loaded_weight, + num_heads=self.num_kv_head_replicas, + shard_id=loaded_shard_id, + shard_offset=shard_offset, + shard_size=shard_size) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + + # Special case for GGUF + # initialize GGUF param after we know the quantize type + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type and loaded_shard_id is not None: + idx_map = {"q": 0, "k": 1, "v": 2} + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + return + + if is_gguf_weight: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + output_dim = getattr(param, "output_dim", None) + shard_size = loaded_weight.size(output_dim) // tp_size + start_idx = tp_rank * shard_size + + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 3: + self.qweight = param.materialize_nested() + return + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + # Special case for per-tensor scales in fused case. + needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) + + if loaded_shard_id is None: + # Loaded weight is already fused on disk (qkv/mlp). 
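+ # A single fused qkv_proj tensor from the checkpoint is split into its
+ # q/k/v shards below, and each shard is fed back through this loader
+ # with an explicit shard id.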
+ if output_dim is None: + if needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, 0) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.total_num_heads * self.head_size), + "k": (self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + "v": + ((self.total_num_heads + self.total_num_kv_heads) * + self.head_size, + self.total_num_kv_heads * self.head_size), + "total": + ((self.total_num_heads + 2 * self.total_num_kv_heads) * + self.head_size, 0) + } + + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, shard_id) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + + # If output dim is defined, use the default loading process. + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # Special case for Quantized Weights. + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # Special case for Marlin. 
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: + orig_qkv_offsets = { + "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) + } + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, loaded_shard_id) + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if not use_bitsandbytes_4bit: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + + # Special case for for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_index = ["q", "k", "v"].index(loaded_shard_id) + param_data = param_data.narrow(0, shard_index * shard_size, + shard_size) + # Special case for per-tensor scales in fused case. + elif needs_scalar_to_array: + param_data, loaded_weight = adjust_scalar_to_fused_array( + param_data, loaded_weight, loaded_shard_id) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(LinearBase): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__(self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config, prefix) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + # Divide the weight matrix along the last dimension. 
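+ # i.e. along the input dimension of the stored (out_features,
+ # in_features) weight: each rank keeps input_size // tp_size of the
+ # input features.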
+ self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=[self.output_size], + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + input_dim = getattr(param, "input_dim", None) + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) + + # Special case for GGUF + is_gguf_weight = getattr(param, "is_gguf_weight", False) + is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) + if is_gguf_weight_type: + param.weight_type = loaded_weight.item() + + # Materialize GGUF UninitializedParameter + if is_gguf_weight and isinstance(param, UninitializedParameter): + weight_shape = list(loaded_weight.shape) + if input_dim: + weight_shape[input_dim] = weight_shape[input_dim] // tp_size + param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) + + param_data = param.data + # bitsandbytes loads the weights of the specific portion + # no need to narrow here + if input_dim is not None and not use_bitsandbytes_4bit: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def weight_loader_v2(self, param: BasevLLMParameter, + loaded_weight: torch.Tensor): + + # Special case for loading scales off disk, which often do not + # have a shape (such as in the case of AutoFP8). + if len(loaded_weight.shape) == 0: + assert loaded_weight.numel() == 1 + loaded_weight = loaded_weight.reshape(1) + + param.load_row_parallel_weight(loaded_weight=loaded_weight) + + def forward(self, input_): + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
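+ # Each rank multiplies its slice of the input features by its weight
+ # shard, producing a partial sum of the full output; the partial sums
+ # are combined with an all-reduce below when reduce_results is enabled.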
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + return output, output_bias + + def extra_repr(self) -> str: + s = f"input_features={self.input_size_per_partition}" + s += f", output_features={self.output_size}" + s += f", bias={self.bias is not None}" + s += f", tp_size={self.tp_size}" + s += f", reduce_results={self.reduce_results}" + return s diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py new file mode 100644 index 0000000..e985c1b --- /dev/null +++ b/vllm/model_executor/layers/logits_processor.py @@ -0,0 +1,156 @@ +"""A layer that compute logits from hidden_stats.""" +import inspect +from typing import Optional + +import torch +import torch.nn as nn + +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_gather) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform + + +class LogitsProcessor(nn.Module): + """Process logits and apply logits processors from sampling metadata. + + This layer does the following: + 1. Gather logits from model hidden_states. + 2. Scale logits if needed. + 3. Apply logits processors (if any). + """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None, + scale: float = 1.0, + logits_as_input: bool = False, + soft_cap: Optional[float] = None) -> None: + """ + Args: + scale: A scaling factor to apply to the logits. + """ + super().__init__() + self.scale = scale + self.vocab_size = vocab_size + # Whether the input is logits (default is hidden states). + self.logits_as_input = logits_as_input + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + # Soft cap the logits. Used in Gemma 2. + self.soft_cap = soft_cap + # Whether to use gather or all-gather to gather the logits. + self.use_gather = not current_platform.is_tpu() + + def forward( + self, + lm_head: VocabParallelEmbedding, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + if self.logits_as_input: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + if hidden_states.shape[0] > 0: + logits = self._get_logits(hidden_states, lm_head, embedding_bias) + else: + logits = torch.empty([0, lm_head.weight.shape[0]], device=hidden_states.device, dtype=hidden_states.dtype) + if logits is not None: + if self.soft_cap is not None: + logits = logits / self.soft_cap + logits = torch.tanh(logits) + logits = logits * self.soft_cap + + if self.scale != 1.0: + logits *= self.scale + + # Apply logits processors (if any). 
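+ # Logits processors come from each request's SamplingParams. Every
+ # processor is a callable taking the generated token ids (and optionally
+ # the prompt token ids) plus one logits row and returning an adjusted
+ # row, e.g. (illustrative) lambda past_ids, row: row.clamp_(max=10.0).
+ # See _apply_logits_processors below.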
+ logits = _apply_logits_processors(logits, sampling_metadata) + + return logits + + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: + # Get the logits for the next tokens. + logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + if self.use_gather: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + else: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[..., :self.org_vocab_size] + return logits + + def extra_repr(self) -> str: + s = f"vocab_size={self.vocab_size}" + s += f", forg_vocab_size={self.org_vocab_size}" + s += f", scale={self.scale}, logits_as_input={self.logits_as_input}" + return s + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + return hidden_states.index_select(0, + sampling_metadata.selected_token_indices) + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + found_logits_processors = False + logits_processed = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + + for seq_id, logits_row_idx in zip(seq_ids, + seq_group.sample_indices): + logits_row = logits[logits_row_idx] + past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids + prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids + + for logits_processor in logits_processors: + parameters = inspect.signature(logits_processor).parameters + if len(parameters) == 3: + logits_row = logits_processor(prompt_tokens_ids, + past_tokens_ids, + logits_row) + else: + logits_row = logits_processor(past_tokens_ids, + logits_row) + + logits[logits_row_idx] = logits_row + + logits_processed += len(seq_group.sample_indices) + len( + seq_group.prompt_logprob_indices) + + if found_logits_processors: + # verifies that no rows in logits were missed unexpectedly + assert logits_processed == logits.shape[0] + return logits diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/mamba/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/mamba/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..888ccc9 Binary files /dev/null and b/vllm/model_executor/layers/mamba/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/mamba/ops/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/mamba/ops/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..a095ccb Binary files /dev/null and b/vllm/model_executor/layers/mamba/ops/__pycache__/__init__.cpython-310.pyc differ diff --git 
a/vllm/model_executor/layers/mamba/ops/__pycache__/causal_conv1d.cpython-310.pyc b/vllm/model_executor/layers/mamba/ops/__pycache__/causal_conv1d.cpython-310.pyc new file mode 100644 index 0000000..5b34390 Binary files /dev/null and b/vllm/model_executor/layers/mamba/ops/__pycache__/causal_conv1d.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/mamba/ops/__pycache__/mamba_ssm.cpython-310.pyc b/vllm/model_executor/layers/mamba/ops/__pycache__/mamba_ssm.cpython-310.pyc new file mode 100644 index 0000000..0613b1e Binary files /dev/null and b/vllm/model_executor/layers/mamba/ops/__pycache__/mamba_ssm.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py new file mode 100644 index 0000000..ed7241a --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + out = ops.causal_conv1d_fwd(x, weight, bias, conv_states, query_start_loc, + cache_indices, has_initial_state, activation + in ["silu", "swish"]) + return out + + +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
+ + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = ops.causal_conv1d_update(x, conv_state, weight, bias, activation_val, + cache_seqlens, conv_state_indices) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 0000000..08b016c --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,395 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py + +from typing import Tuple + +import torch +import triton +import triton.language as tl +from packaging import version + +from vllm import _custom_ops as ops + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics({ + "HAS_STATE_BATCH_INDICES": + lambda args: args["state_batch_indices_ptr"] is not None +}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. 
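+ # Offset every pointer to the (batch, head) slice handled by this program instance before loading.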
+ if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr) + state_ptr += (state_batch_idx * stride_state_batch + + pid_h * stride_state_head) + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, 
dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch, ) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn( + u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + u: (dim, total_length) for varlen or (batch, dim, seqlen) + delta: (dim, total_length) for varlen or (batch, dim, seqlen) + A: (dim, dstate) + B: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + C: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + D: (dim,) + z: (dim, total_length) for varlen or (batch, dim, seqlen) + dt_bias: (dim,) or (dim) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended with 0. 
+ for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + A tensor with each cell is a correspondent + input and output ssm_state index + has_initial_state: (batch) bool + A tensor populated with ones and zeros, + indicate if the ssm_state at the corresponding index should be + used as initial state. Not providing argument assumes + there's no initial state + + returns + output: (dim, total_length) for varlen or (batch, dim, seqlen) + supports inplace replacement + last_state has shape (batch, dim, dstate). + supports inplace replacement if ssm_state was provided + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3 and query_start_loc is None: + B = B.unsqueeze(1) + if B.dim() == 2 and query_start_loc is not None: + B = B.unsqueeze(0) + if C.dim() == 3 and query_start_loc is None: + C = C.unsqueeze(1) + if C.dim() == 2 and query_start_loc is not None: + C = C.unsqueeze(0) + + ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus, + query_start_loc, cache_indices, has_initial_state, + ssm_states) + + if z is None: + return delta # output written inplace to delta + else: + return z # output written inplace to z diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py new file mode 100644 index 0000000..76ccb3d --- /dev/null +++ b/vllm/model_executor/layers/pooler.py @@ -0,0 +1,63 @@ +from enum import IntEnum + +import torch +import torch.nn as nn + +from vllm.model_executor.pooling_metadata import (PoolingMetadata, + PoolingTensors) +from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput + + +class PoolingType(IntEnum): + """Enumeration for different types of pooling methods.""" + LAST = 0 + ALL = 1 + + +class Pooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use (LAST, AVERAGE, MAX). + normalize: Whether to normalize the pooled data. 
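+ Note: only PoolingType.LAST and PoolingType.ALL are implemented in forward() below; any other value raises ValueError. A hypothetical Pooler(pooling_type=PoolingType.LAST, normalize=True), for example, takes the final hidden state of each prompt and L2-normalizes it.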
+ """ + + def __init__(self, pooling_type: PoolingType, normalize: bool): + super().__init__() + self.pooling_type = pooling_type + self.normalize = normalize + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools specific information from hidden states based on metadata.""" + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + if self.pooling_type == PoolingType.LAST: + last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1 + pooled_data = hidden_states[last_token_flat_indices] + elif self.pooling_type == PoolingType.ALL: + offset = 0 + pooled_data = [] + for prompt_len in prompt_lens: + pooled_data.append(hidden_states[offset:offset + prompt_len]) + offset += prompt_len + else: + raise ValueError(f"Invalid pooling type: {self.pooling_type}") + + if self.normalize: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data + ] + + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py new file mode 100644 index 0000000..fa51d42 --- /dev/null +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -0,0 +1,69 @@ +from typing import Dict, Type + +from vllm.model_executor.layers.quantization.aqlm import AQLMConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.bitsandbytes import ( + BitsAndBytesConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig) +from vllm.model_executor.layers.quantization.experts_int8 import ( + ExpertsInt8Config) +from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.gguf import GGUFConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQMarlin24Config) +from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig +from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config +from vllm.model_executor.layers.quantization.neuron_quant import ( + NeuronQuantConfig) +from vllm.model_executor.layers.quantization.qqq import QQQConfig +from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig +from vllm.model_executor.layers.quantization.w8a16 import W8a16Config + +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "tpu_int8": Int8TpuConfig, + "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, + "modelopt": ModelOptFp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) 
+ "marlin": MarlinConfig, + "gguf": GGUFConfig, + "gptq_marlin_24": GPTQMarlin24Config, + # "gptq_marlin": GPTQMarlinConfig, + "awq_marlin": AWQMarlinConfig, + "gptq": GPTQConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, + "qqq": QQQConfig, + "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, + "ipex": IPEXConfig, + "w8a16": W8a16Config, +} + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in QUANTIZATION_METHODS: + raise ValueError(f"Invalid quantization method: {quantization}") + return QUANTIZATION_METHODS[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", + "QUANTIZATION_METHODS", +] diff --git a/vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..41b1bb6 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/aqlm.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/aqlm.cpython-310.pyc new file mode 100644 index 0000000..970f5d3 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/aqlm.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/awq.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/awq.cpython-310.pyc new file mode 100644 index 0000000..c6291af Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/awq.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-310.pyc new file mode 100644 index 0000000..47ed177 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/awq_triton.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/awq_triton.cpython-310.pyc new file mode 100644 index 0000000..7f1d394 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/awq_triton.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/base_config.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/base_config.cpython-310.pyc new file mode 100644 index 0000000..0759d0c Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/base_config.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/bitsandbytes.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/bitsandbytes.cpython-310.pyc new file mode 100644 index 0000000..a4f4198 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/bitsandbytes.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/deepspeedfp.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/deepspeedfp.cpython-310.pyc new file mode 100644 index 0000000..6213ec2 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/deepspeedfp.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/experts_int8.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/experts_int8.cpython-310.pyc new 
file mode 100644 index 0000000..ea4f3a1 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/experts_int8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/fbgemm_fp8.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/fbgemm_fp8.cpython-310.pyc new file mode 100644 index 0000000..ac659f7 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/fbgemm_fp8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/fp8.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/fp8.cpython-310.pyc new file mode 100644 index 0000000..e60b9d0 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/fp8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/gguf.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/gguf.cpython-310.pyc new file mode 100644 index 0000000..63e7832 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/gguf.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/gptq.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/gptq.cpython-310.pyc new file mode 100644 index 0000000..7be428a Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/gptq.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin.cpython-310.pyc new file mode 100644 index 0000000..3e2ccc9 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin_24.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin_24.cpython-310.pyc new file mode 100644 index 0000000..7041cbc Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin_24.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/ipex_quant.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/ipex_quant.cpython-310.pyc new file mode 100644 index 0000000..f3d9348 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/ipex_quant.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/kv_cache.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/kv_cache.cpython-310.pyc new file mode 100644 index 0000000..e9a9e83 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/kv_cache.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/marlin.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/marlin.cpython-310.pyc new file mode 100644 index 0000000..326ec50 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/marlin.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/modelopt.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/modelopt.cpython-310.pyc new file mode 100644 index 0000000..acafb63 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/modelopt.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/neuron_quant.cpython-310.pyc 
b/vllm/model_executor/layers/quantization/__pycache__/neuron_quant.cpython-310.pyc new file mode 100644 index 0000000..8e7cd18 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/neuron_quant.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/qqq.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/qqq.cpython-310.pyc new file mode 100644 index 0000000..0993354 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/qqq.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/schema.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/schema.cpython-310.pyc new file mode 100644 index 0000000..b0cf65f Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/schema.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/tpu_int8.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/tpu_int8.cpython-310.pyc new file mode 100644 index 0000000..39b8434 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/tpu_int8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/__pycache__/w8a16.cpython-310.pyc b/vllm/model_executor/layers/quantization/__pycache__/w8a16.cpython-310.pyc new file mode 100644 index 0000000..cc44e16 Binary files /dev/null and b/vllm/model_executor/layers/quantization/__pycache__/w8a16.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py new file mode 100644 index 0000000..c35eac1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -0,0 +1,374 @@ +# Supports AQLM compression, see https://github.com/Vahe1994/AQLM +# and https://arxiv.org/pdf/2401.06118.pdf + +import math +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +def get_int_dtype(nbits: int) -> torch.dtype: + if nbits <= 8: + return torch.int8 + if nbits <= 16: + return torch.int16 + if nbits <= 32: + return torch.int32 + if nbits <= 64: + return torch.int64 + raise ValueError(f"No dtype available for {nbits}-bit codebooks") + + +@torch.inference_mode() +def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor: + return data.to(torch.int64) % (2**nbits) + + +def dequantize_weight(codes: torch.Tensor, + codebooks: torch.Tensor, + scales: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Decode float weights from quantization codes. Differentiable. 
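+ Each weight group is reconstructed by summing, over all codebooks, the codebook vector selected by its code, then multiplying by scales when provided.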
+ :param codes: tensor of integer quantization codes, shape + [*dims, num_out_groups, num_in_groups, num_codebooks] + :param codebooks: tensor of vectors for each quantization code, + [num_codebooks, codebook_size, out_group_size, in_group_size] + :param scales: weight will be multiplied by this factor, must be + broadcastble with + [*dims, out_groups, num_in_groups, out_group_size, in_group_size] + :return: reconstructed weight tensor of shape + [*dims, num_in_groups*group_size] + """ + num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:] + num_codebooks, codebook_size, out_group_size, in_group_size = \ + codebooks.shape + out_features = num_out_groups * out_group_size + in_features = num_in_groups * in_group_size + codebook_offsets = torch.arange( + 0, num_codebooks * codebook_size, codebook_size, + device=codes.device) # shape: [num_codebooks] + reconstructed_weight_flat = F.embedding_bag( + codes.flatten(0, -2) + codebook_offsets, + codebooks.flatten(0, 1).flatten(-2, -1), + mode="sum" + ) # [prod(dims) * num_out_groups * num_in_groups, out_group_size + # * in_group_size] + + reconstructed_weight_groupwise = reconstructed_weight_flat.view( + list(codes.shape[:-3]) + + [num_out_groups, num_in_groups, out_group_size, in_group_size]) + if scales is not None: + reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul( + scales) + return reconstructed_weight_groupwise.swapaxes( + -3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features]) + + +def dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + bias: Optional[torch.Tensor], +) -> torch.Tensor: + dequantized_weight = dequantize_weight( + unpack_int_data(codes, codebooks.shape[1].bit_length() - 1), + codebooks, + scales, + ) + return F.linear(input, dequantized_weight, bias) + + +# Generic dequantization, slow but flexible. +def generic_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: List[int], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + output_shape = input.shape[:-1] + (scales.shape[0], ) + output = torch.empty(output_shape, dtype=input.dtype, device=input.device) + num_outputs = len(output_partition_sizes) + + # break the inputs and codebooks apart then combine the outputs. + # Surprisingly (to me) this is faster than doing 3 de-quants and 1 big + # multiply at the end. 
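+ # Each output partition owns an equal slice of codes/scales and its own block of codebooks, so dequantize and matmul shard by shard, writing into the matching slice of the output.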
+ num_codebooks = codebooks.shape[0] // num_outputs + assert (scales.shape[0] == codes.shape[0]) + assert (sum(output_partition_sizes) == scales.shape[0]) + output_offset = 0 + codebooks_offset = 0 + for output_size in output_partition_sizes: + shard_output = dequantize_gemm( + input, codes.narrow(0, output_offset, output_size), + codebooks.narrow(0, codebooks_offset, num_codebooks), + scales.narrow(0, output_offset, output_size), None + if bias is None else bias.narrow(0, output_offset, output_size)) + + output_slice = output.narrow(-1, output_offset, output_size) + assert (output_slice.shape == shard_output.shape) + output_slice.copy_(shard_output) + output_offset += output_size + codebooks_offset += num_codebooks + return output + + +# Optimized dequnantize/decompression kernels, supports 1x16 and 2x8 +# at 6 and 9 times faster than the generic version above, respectively. +def optimized_dequantize_gemm( + input: torch.Tensor, # [..., in_features] + codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] + codebooks: torch. + Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] + scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + output_partition_sizes: List[int], + bias: Optional[torch.Tensor], +) -> torch.Tensor: + weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) + + if bias is None: + # scaling the output is fastest, so we do that when possible. + output = F.linear(input, weights, bias) + orig_shape = output.shape + flattened_output = output.view(-1, output.size(-1)) + f_scales = scales.view(-1, scales.shape[0]) + b_scales = f_scales.expand(flattened_output.shape[0], -1) + flattened_output *= b_scales + return output.view(orig_shape) + else: + b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( + -1, weights.shape[1]) + weights *= b_scales + return F.linear(input, weights, bias) + + +class AQLMConfig(QuantizationConfig): + """Config class for AQLM. + + Reference: https://github.com/Vahe1994/AQLM + """ + + def __init__( + self, + in_group_size: int, + nbits_per_codebook: int, + num_codebooks: int, + out_group_size: int, + ) -> None: + self.in_group_size = in_group_size + self.nbits_per_codebook = nbits_per_codebook + self.num_codebooks = num_codebooks + self.out_group_size = out_group_size + + # out_group_size > 1 is untested, and probably won't work as-is. + assert (self.out_group_size == 1) + self.pack_factor = (self.in_group_size * self.out_group_size) + + def __repr__(self) -> str: + return (f"AQLMConfig(in_group_size={self.in_group_size}, " + f"nbits_per_codebook={self.nbits_per_codebook}, " + f"num_codebooks={self.num_codebooks}, " + f"out_group_size={self.out_group_size})") + + @classmethod + def get_name(cls) -> str: + return "aqlm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. 
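+ # from_config (below) reads exactly these four keys from the checkpoint's quantization config; a typical AQLM "1x16" setup would look something like (hypothetical values) in_group_size=8, nbits_per_codebook=16, num_codebooks=1, out_group_size=1.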
+ + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": + in_group_size = cls.get_from_keys(config, ["in_group_size"]) + nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"]) + num_code_books = cls.get_from_keys(config, ["num_codebooks"]) + out_group_size = cls.get_from_keys(config, ["out_group_size"]) + return cls(in_group_size, nbits_per_codebook, num_code_books, + out_group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AQLMLinearMethod"]: + if isinstance(layer, LinearBase): + return AQLMLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class AQLMLinearMethod(LinearMethodBase): + """Linear method for AQLM. + + Args: + quant_config: The AQLM quantization config. + """ + + def __init__(self, quant_config: AQLMConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + del output_size # Unused. + del input_size # Unused. + + if params_dtype != torch.half: + raise ValueError("Only half is currently supported by aqlm") + if input_size_per_partition % self.quant_config.in_group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.out_group_size != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + codes = Parameter( + torch.empty( + # There could actually be two pack factors, one along input and + # one along output, but we don't currently support + # out_group_size, and only the one along output needs to be + # marked with "packed_dim" in order for QKVLinear to work. 
+ output_size_per_partition, + input_size_per_partition // self.quant_config.pack_factor, + self.quant_config.num_codebooks, + dtype=get_int_dtype(self.quant_config.nbits_per_codebook), + ), + requires_grad=False, + ) + + set_weight_attrs( + codes, + { + "input_dim": 1, + "output_dim": 0, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }, + ) + + codebooks = Parameter( + torch.empty( + self.quant_config.num_codebooks * len(output_partition_sizes), + 2**self.quant_config.nbits_per_codebook, + self.quant_config.out_group_size, + self.quant_config.in_group_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + codebooks, + { + # metadata indicates fixed size concatenated along dim 0 + "is_metadata": True, + "output_partition_sizes": output_partition_sizes + }, + ) + + scales = Parameter( + torch.empty( + ( + output_size_per_partition // + self.quant_config.out_group_size, + 1, + 1, + 1, + ), + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "output_dim": 0, + "packed_dim": 0, + "pack_factor": self.quant_config.out_group_size + }, + ) + + layer.register_parameter("codes", codes) + set_weight_attrs(codes, extra_weight_attrs) + layer.register_parameter("codebooks", codebooks) + set_weight_attrs(codebooks, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + codebooks = layer.codebooks + codes = layer.codes + scales = layer.scales + output_partition_sizes = getattr(codebooks, "output_partition_sizes", + []) + + nbooks = codes.shape[2] + ingroups = codebooks.shape[3] + outgroups = codebooks.shape[2] + bits = codebooks.shape[1] + + # We support these formats with dedicated gemm and decompression + # kernels. + if ingroups == 8 and outgroups == 1 and ( + (bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)): + + # thresholds determined by timings on an A6000, one GPU + use_gemv = math.prod(x.shape[:-1]) <= 6 + + return ops.aqlm_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) if use_gemv else optimized_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) + + # fall back all unoptimized formats + return generic_dequantize_gemm( + x, + codes, + codebooks, + scales, + output_partition_sizes, + bias, + ) diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py new file mode 100644 index 0000000..8849e4b --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq.py @@ -0,0 +1,173 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. 
+ + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point})") + + def get_name(self) -> str: + return "awq" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. + return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + return cls(weight_bits, group_size, zero_point) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["AWQLinearMethod"]: + if isinstance(layer, LinearBase): + return AWQLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + +class AWQLinearMethod(LinearMethodBase): + """Linear method for AWQ. + + Args: + quant_config: The AWQ quantization config. + """ + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + weight_loader = extra_weight_attrs.get("weight_loader") + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + qzeros = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.qweight = torch.nn.Parameter(layer.qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.qweight + scales = layer.scales + qzeros = layer.qzeros + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + + # num_tokens >= threshold + # FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 + FP16_MATMUL_HEURISTIC_CONDITION = False + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + out = torch.matmul(reshaped_x, out) + else: + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor, group_size=self.quant_config.group_size) + if bias is not None: + out.add_(bias) + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py new file mode 100644 index 0000000..b3d93b2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -0,0 +1,464 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, moe_awq_to_marlin_zero_points, + verify_marlin_supported, verify_marlin_supports_shape) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter 
import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.platforms import current_platform +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class AWQMarlinConfig(QuantizationConfig): + """Config class for AWQ Marlin""" + + # num_bits -> type + TYPE_MAP = { + 4: scalar_types.uint4, + 8: scalar_types.uint8, + } + + def __init__(self, weight_bits: int, group_size: int, has_zp: bool, + lm_head_quantized: bool) -> None: + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.has_zp = has_zp + self.lm_head_quantized = lm_head_quantized + self.weight_bits = weight_bits + + if self.weight_bits not in self.TYPE_MAP: + raise ValueError(f"Unsupported num_bits = {self.weight_bits}. " + f"Supported num_bits = {self.TYPE_MAP.keys()}") + + self.quant_type = self.TYPE_MAP[self.weight_bits] + + verify_marlin_supported(self.quant_type, + group_size=self.group_size, + has_zp=self.has_zp) + + def __repr__(self) -> str: + return (f"AWQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"has_zp={self.has_zp}, " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> str: + return "awq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + has_zp = cls.get_from_keys(config, ["zero_point"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, has_zp, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "awq_marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "awq": + logger.info("Detected that the model can run with awq_marlin" + ", however you specified quantization=awq explicitly," + " so forcing awq. Use quantization=awq_marlin for" + " faster inference") + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return AWQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return AWQMoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + has_zp = quant_config.get("zero_point") + + if not current_platform.is_cuda(): + return False + + if quant_method != "awq": + return False + + # If we cannot find the info needed in the config, cannot convert. 
+ if (num_bits is None or group_size is None or has_zp is None): + return False + + if num_bits not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[num_bits], + group_size=group_size, + has_zp=has_zp) + + +class AWQMarlinLinearMethod(LinearMethodBase): + """Linear method for AWQ Marlin. + + Args: + quant_config: The AWQ Marlin quantization config. + """ + + def __init__(self, quant_config: AWQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + verify_marlin_supports_shape( + output_size_per_partition=output_size_per_partition, + input_size_per_partition=input_size_per_partition, + input_size=input_size, + group_size=group_size) + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + num_groups = input_size_per_partition // group_size + + qzeros = PackedvLLMParameter( + data=torch.empty( + num_groups, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + scales = GroupQuantScaleParameter(data=torch.empty( + num_groups, + output_size_per_partition, + dtype=params_dtype, + ), + input_dim=0, + output_dim=1, + weight_loader=weight_loader) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.num_groups = num_groups + + # TODO: Update this docs + # Checkpoints are serialized in AutoAWQ format, which is different from the + # marlin format. This function is called after the weights are loaded. + # Here, we handle the repacking + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = layer.qweight.device + layer.qweight = torch.nn.Parameter(layer.qweight.data, + requires_grad=False) + layer.qzeros = torch.nn.Parameter(layer.qzeros.data, + requires_grad=False) + layer.scales = torch.nn.Parameter(layer.scales.data, + requires_grad=False) + + # Allocate marlin workspace + layer.workspace = marlin_make_workspace( + layer.output_size_per_partition, device) + + # Repack weights from AWQ format to marlin format. + marlin_qweight = ops.awq_marlin_repack( + layer.qweight, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits) + replace_parameter(layer, "qweight", marlin_qweight) + + # Permute scales from AWQ format to marlin format. 
+ marlin_scales = marlin_permute_scales( + layer.scales, + size_k=layer.input_size_per_partition, + size_n=layer.output_size_per_partition, + group_size=self.quant_config.group_size) + replace_parameter(layer, "scales", marlin_scales) + + # Permute zero-points from AWQ format to marlin format. + marlin_zp = awq_to_marlin_zero_points( + layer.qzeros, + size_k=layer.num_groups, + size_n=layer.output_size_per_partition, + num_bits=self.quant_config.quant_type.size_bits) + replace_parameter(layer, "qzeros", marlin_zp) + + # Not-used + layer.g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_awq_marlin_linear( + input=x, + weight=layer.qweight, + weight_scale=layer.scales, + weight_zp=layer.qzeros, + g_idx=layer.g_idx, + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=layer.workspace, + quant_type=self.quant_config.quant_type, + output_size_per_partition=layer.output_size_per_partition, + input_size_per_partition=layer.input_size_per_partition, + bias=bias) + + +class AWQMoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: AWQMarlinConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + extra_weight_attrs.update({ + "is_transposed": + True, + "quant_method": + FusedMoeWeightScaleSupported.GROUP.value, + }) + + w13_qweight = Parameter(torch.empty(num_experts, + hidden_size, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + + w2_qweight = Parameter(torch.empty(num_experts, + intermediate_size, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + + num_groups_w13 = hidden_size // self.quant_config.group_size + num_groups_w2 = intermediate_size // self.quant_config.group_size + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + w13_scales = Parameter(torch.empty(num_experts, + num_groups_w13, + intermediate_size * 2, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + + w2_scales = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + + # WEIGHT_ZERO_POINT + # Allocate 2 zero points for w1 and w3 respectively. 
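        # Shape summary for the AWQ MoE tensors registered in this method
        # (E = num_experts, H = hidden_size, I = intermediate_size,
        #  p = pack_factor, g = group_size):
        #   w13_qweight [E, H,      2I // p]    w2_qweight [E, I,      H // p]
        #   w13_scales  [E, H // g, 2I]         w2_scales  [E, I // g, H]
        #   w13_qzeros  [E, H // g, 2I // p]    w2_qzeros  [E, I // g, H // p]
        # (the zero-point tensors are allocated just below)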
+ w13_qzeros = Parameter(torch.empty(num_experts, + num_groups_w13, + 2 * intermediate_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + + w2_qzeros = Parameter(torch.empty(num_experts, + num_groups_w2, + hidden_size // + self.quant_config.pack_factor, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + num_experts = layer.w13_qweight.shape[0] + device = layer.w13_qweight.device + + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.awq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + size_k=layer.w13_qweight.shape[1], + size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + + marlin_w2_qweight = ops.awq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + size_k=layer.w2_qweight.shape[1], + size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + + # Why does this take the intermediate size for size_k? + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + + replace_parameter(layer, "w13_scales", marlin_w13_scales) + + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + marlin_w13_zp = moe_awq_to_marlin_zero_points( + layer.w13_qzeros, + size_k=layer.w13_qzeros.shape[1], + size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_parameter(layer, "w13_qzeros", marlin_w13_zp) + + marlin_w2_zp = moe_awq_to_marlin_zero_points( + layer.w2_qzeros, + size_k=layer.w2_qzeros.shape[1], + size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor, + num_bits=self.quant_config.weight_bits) + replace_parameter(layer, "w2_qzeros", marlin_w2_zp) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + 
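            # By the time apply() runs, process_weights_after_loading has
            # already repacked w13_qweight / w2_qweight and permuted the
            # scales and zero points into the marlin layout, so
            # fused_marlin_moe consumes marlin-format weights, scales, and
            # zero points here.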
layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + w1_zeros=layer.w13_qzeros, + w2_zeros=layer.w2_qzeros, + num_bits=self.quant_config.weight_bits, + ) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py new file mode 100644 index 0000000..bbb7fc8 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -0,0 +1,317 @@ +import torch +import triton +import triton.language as tl + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( + 0, BLOCK_SIZE_X * 8) + result_offsets = (8 * num_cols * result_offsets_y[:, None] + + result_offsets_x[None, :]) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. 
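    # Recap of the shift-based unpacking above (the zeros use the same trick):
    # reverse_awq_order_tensor is [0, 4, 1, 5, 2, 6, 3, 7], so the per-copy bit
    # shifts are [0, 16, 4, 20, 8, 24, 12, 28]; each packed int32 is replicated
    # 8x by the three tl.interleave calls, and (word >> shift) & 0xF then pulls
    # out the eight 4-bit values in logical column order. The scales handled
    # next need no unpacking: they are stored as plain floats, one row per group.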
+ scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + + tl.arange(0, BLOCK_SIZE_X * 8)) + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + + scale_offsets_x[None, :]) + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, + group_size, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], + (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + + # Dequantize b. 
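        # What "dequantize b" below does, per K-tile: load the packed zero
        # points and the scales for the quantization group this tile falls in,
        # unpack b and zeros with the same shift trick as the dequantize
        # kernel, then compute (b - zeros) * scales before feeding tl.dot.
        # With SPLIT_K > 1 each program accumulates only a 1/SPLIT_K slice of
        # K; the SPLIT_K partial [M, N] results are summed on the host side in
        # awq_gemm_triton afterwards.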
+ offsets_szk = ( + (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + + tl.arange(0, 1)) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton(qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty(qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META['BLOCK_SIZE_X']), + triton.cdiv(Y, META['BLOCK_SIZE_Y']), + ) + awq_dequantize_kernel[grid](qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. 
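# A minimal shape-level usage sketch for the function below (names and sizes
# are hypothetical; assumes a CUDA device and the layouts documented above):
#
#   M, K, N, G = 16, 4096, 4096, 128
#   x       = torch.rand((M, K), dtype=torch.float16, device="cuda")
#   qweight = torch.randint(0, 2**31 - 1, (K, N // 8), dtype=torch.int32, device="cuda")
#   scales  = torch.rand((K // G, N), dtype=torch.float16, device="cuda")
#   qzeros  = torch.randint(0, 2**31 - 1, (K // G, N // 8), dtype=torch.int32, device="cuda")
#   out = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters=8)  # -> [M, N]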
+def awq_gemm_triton(input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), + split_k_iters, + ) + + result = torch.zeros((split_k_iters, M, N), + dtype=scales.dtype, + device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid](input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters) + + result = result.sum(0) + + return result diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py new file mode 100644 index 0000000..75fa824 --- /dev/null +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -0,0 +1,143 @@ +import inspect +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Type + +import torch +from torch import nn + + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, *weight_args, + **extra_weight_attrs): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + # Not required functions + def embedding(self, layer: torch.nn.Module, *args, + **kwargs) -> torch.Tensor: + """Gather embeddings in the layer based on indices in the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return + + +def method_has_implemented_embedding( + method_class: Type[QuantizeMethodBase]) -> bool: + """ + Not all quant methods have embedding implemented, so we need to check that + it exists for our given method. We check this by making sure the function + has been changed from the base implementation. 
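    For example, a linear-only method such as AWQMarlinLinearMethod (defined
    earlier in this patch) does not override ``embedding``, so this check
    returns False for it.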
+ """ + base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", + None) + class_embedding = inspect.getattr_static(method_class, "embedding", None) + + return (class_embedding is not None + and class_embedding is not base_embedding) + + +class QuantizationConfig(ABC): + """Base class for quantization configs.""" + + @abstractmethod + def get_name(self) -> str: + """Name of the quantization method.""" + raise NotImplementedError + + @abstractmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + """ + Detects if this quantization method can support a given checkpoint + format by overriding the user specified quantization method -- + this method should only be overwritten by subclasses in exceptional + circumstances + """ + return None + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "quantization config.") + + @staticmethod + def get_from_keys_or(config: Dict[str, Any], keys: List[str], + default: Any) -> Any: + """Get a optional value from the model's quantization config.""" + try: + return QuantizationConfig.get_from_keys(config, keys) + except ValueError: + return default + + @abstractmethod + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + prefix: The full name of the layer in the state dict + Returns: + The quantize method. None if the given layer doesn't support quant + method. + """ + raise NotImplementedError + + @abstractmethod + def get_scaled_act_names(self) -> List[str]: + """Returns the activation function names that should be post-scaled. + + For now, this is only used by AWQ. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py new file mode 100644 index 0000000..faa8d92 --- /dev/null +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -0,0 +1,316 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class BitsAndBytesConfig(QuantizationConfig): + """Config class for BitsAndBytes Quantization. 
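    Covers both 8-bit (LLM.int8()) and 4-bit (FP4/NF4) loading; see the
    load_in_8bit / load_in_4bit and bnb_4bit_* options below.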
+ + Reference: https://arxiv.org/abs/2305.14314 + """ + + def __init__( + self, + load_in_8bit: bool = False, + load_in_4bit: bool = True, + bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_type: str = "fp4", + bnb_4bit_use_double_quant: bool = False, + llm_int8_enable_fp32_cpu_offload: bool = False, + llm_int8_has_fp16_weight: bool = False, + llm_int8_skip_modules: Optional[Any] = None, + llm_int8_threshold: float = 0.0, + ) -> None: + + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_threshold = llm_int8_threshold + + def __repr__(self) -> str: + return "BitsAndBytesConfig" + + @classmethod + def get_name(self) -> str: + return "bitsandbytes" + + @classmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.float32, torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "adapter_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": + + def get_safe_value(config, keys, default_value=None): + try: + value = cls.get_from_keys(config, keys) + return value if value is not None else default_value + except ValueError: + return default_value + + load_in_8bit = get_safe_value(config, ["load_in_8bit"], + default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], + default_value=True) + bnb_4bit_compute_dtype = get_safe_value(config, + ["bnb_4bit_compute_dtype"], + default_value="float32") + bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], + default_value="fp4") + bnb_4bit_use_double_quant = get_safe_value( + config, ["bnb_4bit_use_double_quant"], default_value=False) + llm_int8_enable_fp32_cpu_offload = get_safe_value( + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) + llm_int8_has_fp16_weight = get_safe_value(config, + ["llm_int8_has_fp16_weight"], + default_value=False) + llm_int8_skip_modules = get_safe_value(config, + ["llm_int8_skip_modules"], + default_value=[]) + llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], + default_value=0.0) + + return cls( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + llm_int8_skip_modules=llm_int8_skip_modules, + llm_int8_threshold=llm_int8_threshold) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["BitsAndBytesLinearMethod"]: + if isinstance(layer, LinearBase): + return BitsAndBytesLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class BitsAndBytesLinearMethod(LinearMethodBase): + """Linear method for BitsAndBytes. + + Args: + quant_config: The BitsAndBytes quantization config. 
+ """ + + def __init__(self, quant_config: BitsAndBytesConfig): + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.44.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.44.0.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " + "bitsandbytes quantizer.") from err + + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + from bitsandbytes.nn import Int8Params + + def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + def create_qweight_for_8bit(): + qweight = Int8Params( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": 1, + "use_bitsandbytes_8bit": True, + "generation": 0 + }) + return qweight + + def create_qweight_for_4bit(): + quant_ratio = calculate_quant_ratio(params_dtype) + + total_size = input_size_per_partition * sum(output_partition_sizes) + if total_size % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape.") + + qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, + 1, + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True + }) + return qweight + + if self.quant_config.load_in_8bit: + qweight = create_qweight_for_8bit() + else: + qweight = create_qweight_for_4bit() + + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.quant_config.load_in_8bit: + return self._apply_8bit_weight(layer, x, bias) + else: + return self._apply_4bit_weight(layer, x, bias) + + def _apply_8bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import MatmulLtState, matmul + + original_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + offsets = qweight.bnb_shard_offsets + quant_states = qweight.bnb_quant_state + matmul_states = qweight.matmul_state + generation = qweight.generation + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.float16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + + # in profile_run or the first generation of inference, + # create new matmul_states + if generation == 0 or generation == 1: + matmul_states[i] = MatmulLtState() + matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].SCB = quant_states[i].to(x.device) + matmul_states[i].threshold = ( + self.quant_config.llm_int8_threshold) + matmul_states[i].has_fp16_weights = ( + 
self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].is_training = False + if matmul_states[i].threshold > 0.0 and not matmul_states[ + i].has_fp16_weights: + matmul_states[i].use_pool = True + + new_x = bf_x.unsqueeze(0) + + out[:, current_index:current_index + output_size] = matmul( + new_x, + qweight[offsets[i]:offsets[i + 1]], + state=matmul_states[i]) + + current_index += output_size + + # only update the matmul_states if it is not profile_run + if (generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None): + del matmul_states[i].CB + qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + + out = out.to(original_type) + + if bias is not None: + out += bias + + qweight.generation += 1 + + return out + + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import matmul_4bit + + original_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + quant_states = qweight.bnb_quant_state + offsets = qweight.bnb_shard_offsets + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.bfloat16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + # It is more efficient to use out kwarg like + # matmul_4bit(..., out = ...). Infeasible now due to the bug + # https://github.com/TimDettmers/bitsandbytes/issues/1235. + # Need to change after the bug is fixed. + out[:, current_index:current_index + output_size] = matmul_4bit( + bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i]) + + current_index += output_size + + out = out.to(original_type) + + if bias is not None: + out += bias + + return out diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6452a20 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors.cpython-310.pyc new file mode 100644 index 0000000..d9325cc Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors_moe.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors_moe.cpython-310.pyc new file mode 100644 index 0000000..73b4a52 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/compressed_tensors_moe.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/utils.cpython-310.pyc 
b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..1e28ce3 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py new file mode 100644 index 0000000..abb18d3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -0,0 +1,412 @@ +from typing import Any, Dict, List, Optional, cast + +import torch +from pydantic import BaseModel + +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, + CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, + CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + CompressionFormat, QuantizationArgs, QuantizationStrategy, + QuantizationType, find_matched_target, is_activation_quantization_format, + should_ignore_layer) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import current_platform + +__all__ = ["CompressedTensorsLinearMethod"] + + +class CompressedTensorsConfig(QuantizationConfig): + + def __init__(self, + target_scheme_map: Dict[str, Any], + ignore: List[str], + quant_format: str, + kv_cache_scheme: Optional[Dict[str, Any]] = None): + + self.ignore = ignore + self.quant_format = quant_format + # Map from [target -> scheme] + self.target_scheme_map = target_scheme_map + self.kv_cache_scheme = kv_cache_scheme + + def get_linear_method(self) -> "CompressedTensorsLinearMethod": + return CompressedTensorsLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "compressed_tensors" + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. 
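        # The `ignore` list and `target_scheme_map` consulted below come from a
        # checkpoint config shaped roughly like this (hypothetical values):
        #   {"format": "int-quantized",
        #    "ignore": ["lm_head"],
        #    "config_groups": {"group_0": {
        #        "targets": ["Linear"],
        #        "weights": {"num_bits": 8, "type": "int", "strategy": "tensor",
        #                    "symmetric": true, "dynamic": false},
        #        "input_activations": {"num_bits": 8, "type": "int",
        #                              "strategy": "tensor", "dynamic": false}}}}
        # See from_config below for how this is parsed.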
+ # TODO (@robertgshaw2): support module names + if should_ignore_layer(prefix, ignore=self.ignore): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return CompressedTensorsLinearMethod(self) + if isinstance(layer, Attention): + return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod.get_moe_method(self) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": + target_scheme_map: Dict[str, Any] = dict() + ignore = cast(List[str], config.get("ignore")) + quant_format = cast(str, config.get("format")) + + # The quant_config has multiple config_groups, each containing + # an input_activations key with details about how the activations are + # quantized, a weights key indicating how the weights are quantized, + # and a list of targets under the `targets` key, dictating which + # layers are impacted by the quantization details. The quantization + # details follow the structure defined by the QuantizationArgs + # pydantic model, which is used to verify the structure of the + # quant_config and also store the details for later use. + for _, quant_config in config["config_groups"].items(): + targets = quant_config.get("targets") + for target in targets: + target_scheme_map[target] = {} + target_scheme_map[target][ + "weights"] = QuantizationArgs.parse_obj( + quant_config.get("weights")) + try: + target_scheme_map[target][ + "input_activations"] = QuantizationArgs.parse_obj( + quant_config.get("input_activations")) + except Exception: + target_scheme_map[target]["input_activations"] = None + + return cls(target_scheme_map=target_scheme_map, + ignore=ignore, + quant_format=quant_format, + kv_cache_scheme=config.get("kv_cache_scheme")) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _check_scheme_supported(self, + min_capability: int, + error: bool = True) -> bool: + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") + return supported + else: + return False + + def _is_static_tensor_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_tensor = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TENSOR.value) + is_static = not weight_quant.dynamic and not input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. 
+ return is_8_bits and is_tensor and weight_quant.symmetric and is_static + + def _is_dynamic_token_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8 + weight_strategy = ( + weight_quant.strategy == QuantizationStrategy.TENSOR.value + or weight_quant.strategy == QuantizationStrategy.CHANNEL.value) + is_token = (weight_strategy and input_quant.strategy + == QuantizationStrategy.TOKEN.value) + is_dynamic = not weight_quant.dynamic and input_quant.dynamic + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_8_bits and is_token and weight_quant.symmetric and is_dynamic + + def _is_fp8_w8a8(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights and activations quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported. + is_floating_point = (weight_quant.type == QuantizationType.FLOAT + and input_quant.type == QuantizationType.FLOAT) + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_floating_point and is_symmetric_weight and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.dynamic: + return True + + # Confirm activation scheme is supported. + is_symmetric_activation = input_quant.symmetric + is_per_tensor_activation = ( + input_quant.strategy == QuantizationStrategy.TENSOR) + return is_symmetric_activation and is_per_tensor_activation + + def _is_fp8_w8a16(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + # Confirm weights quantized. + if weight_quant is None: + return False + + # Confirm we have floating points. + if weight_quant.type != QuantizationType.FLOAT: + return False + + # Confirm weight scheme is supported. + is_symmetric_weight = weight_quant.symmetric + is_static_weight = not weight_quant.dynamic + is_per_tensor_or_channel_weight = (weight_quant.strategy in [ + QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL + ]) + if not (is_symmetric_weight and is_static_weight # noqa: SIM103 + and is_per_tensor_or_channel_weight): + return False + + # All conditions satisfied. 
+ return True + + def _is_wNa16_group_channel(self, weight_quant: BaseModel, + input_quant: BaseModel) -> bool: + input_quant_none = input_quant is None + is_symmetric = weight_quant.symmetric + is_channel_group = ( + weight_quant.strategy == QuantizationStrategy.CHANNEL.value + or weight_quant.strategy == QuantizationStrategy.GROUP.value) + is_static = not weight_quant.dynamic + + return (is_channel_group and input_quant_none and is_symmetric + and is_static) + + def _get_scheme_from_parts( + self, weight_quant: BaseModel, + input_quant: BaseModel) -> "CompressedTensorsScheme": + + # Detect If Mixed Precision + if self._is_wNa16_group_channel(weight_quant, input_quant): + if (self.quant_format == CompressionFormat.marlin_24.value + and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): + return CompressedTensorsW4A16Sparse24( + strategy=weight_quant.strategy, + num_bits=weight_quant.num_bits, + group_size=weight_quant.group_size) + if (self.quant_format == CompressionFormat.pack_quantized.value + and weight_quant.num_bits in WNA16_SUPPORTED_BITS): + return CompressedTensorsWNA16( + num_bits=weight_quant.num_bits, + strategy=weight_quant.strategy, + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) + + # Detect If Activation Quantization. + # TODO @dsikka: clean-up conditions + if is_activation_quantization_format(self.quant_format): + if self._is_fp8_w8a8(weight_quant, input_quant): + is_fp8_w8a8_supported = self._check_scheme_supported( + CompressedTensorsW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + return CompressedTensorsW8A8Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) + else: + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) + + if self._is_fp8_w8a16(weight_quant, input_quant): + return CompressedTensorsW8A16Fp8( + strategy=weight_quant.strategy, + is_static_input_scheme=(input_quant + and not input_quant.dynamic)) + + if self._is_static_tensor_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=True, + input_symmetric=input_quant.symmetric) + + if self._is_dynamic_token_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Int8( + strategy=weight_quant.strategy, + is_static_input_scheme=False, + input_symmetric=input_quant.symmetric) + + raise NotImplementedError( + "No compressed-tensors compatible scheme was found.") + + def get_scheme( + self, + layer: torch.nn.Module, + layer_name: Optional[str] = None) -> "CompressedTensorsScheme": + """ + compressed-tensors supports non uniform in the following way: + + ignore: List of layer_names or nn.Module names to be ignored. + targets of config_groups: There can be N config_groups which each + have a quantization scheme. Each config_group has a list of targets + which can be a full layer_name, a regex for a layer_name, or + an nn.Module name. + + We first check whether a layer is in the ignore group and use + CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer + + We then detect whether a layer_name is found in any target and + use the quantization scheme corresponding to the matched target + to select the CompressedTensorsScheme used for infernece. + """ + + # Find the "target" in the compressed-tensors config + # that our layer conforms to. 
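        # Example (hypothetical names): with targets ["Linear"] and
        # ignore ["lm_head"], a layer named "model.layers.0.mlp.down_proj"
        # matches the "Linear" target and receives that group's scheme, while
        # "lm_head" is skipped in get_quant_method and stays unquantized.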
+ # TODO (@robertgshaw): add compressed-tensors as dep + # so we do not have to re-write these functions + # need to make accelerate optional in ct to do this + matched_target = find_matched_target( + layer_name=layer_name, + module=layer, + targets=self.target_scheme_map.keys()) + + # Find the quant_scheme + scheme_dict = self.target_scheme_map[matched_target] + scheme = self._get_scheme_from_parts( + weight_quant=scheme_dict["weights"], + input_quant=scheme_dict["input_activations"]) + + # Raise error if device does not support the scheme + # (e.g. fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + + +class CompressedTensorsLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: CompressedTensorsConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class CompressedTensorsKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from compressed-tensors + checkpoints. + """ + + def __init__(self, quant_config: CompressedTensorsConfig): + self.validate_kv_cache_scheme(quant_config.kv_cache_scheme) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]): + """ + Validator for the kv cache scheme. Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_scheme: the compressed-tensors kv cache scheme + """ + if kv_cache_scheme is None: + return + + type_ = kv_cache_scheme.get("type") + num_bits = kv_cache_scheme.get("num_bits") + + if type_ != "float" and num_bits != 8: + raise NotImplementedError( + "Currently supported kv cache quantization is " + "num_bits=8, type=float, however " + f"received num_bits={num_bits}, type={type_}") + + strategy = kv_cache_scheme.get("strategy") + if strategy != "tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for compressed-tensors KV cache. " + f"Expected strategy: tensor, found strategy: {strategy}") + + is_symmetric = kv_cache_scheme.get("symmetric") + if not is_symmetric: + raise NotImplementedError( + "Only support symmetric scaling factor " + "for compressed-tensors KV cache. 
" + f"However found symmetric: {is_symmetric}") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 0000000..af04d72 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,511 @@ +import enum +from enum import Enum +from typing import Callable, List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + CompressionFormat, QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_hip, print_warning_once + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = [ + "CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod", + "CompressedTensorsWNA16MoEMethod" +] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + return CompressedTensorsWNA16MoEMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}") + + +class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + self.weight_quant = self.quant_config.target_scheme_map["Linear"].get( + "weights") + self.input_quant = self.quant_config.target_scheme_map["Linear"].get( + "input_activations") + + if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR + and self.input_quant.strategy == QuantizationStrategy.TENSOR): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales" + "for weights and activations are supported. 
Found " + f"{self.weight_quant}, {self.input_quant}") + + self.static_input_scales = not self.input_quant.dynamic + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. 
") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # If rocm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. 
+ config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy.value + self.group_size = config.group_size + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // + self.packed_factor, + 2 * intermediate_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size // + self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = intermediate_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, 
+ dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int, num_bits: int): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, + scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): + num_experts = s.shape[0] + output = torch.empty((num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + size_k2 = layer.w2_weight_packed.shape[2] + size_k13 = layer.w13_weight_packed.shape[2] + + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + size_k13, + layer.w13_weight_scale.shape[2], + self.group_size, + self.num_bits, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] * self.packed_factor, + size_k2, + self.group_size, + self.num_bits, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + 
renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.num_bits, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py new file mode 100644 index 0000000..5d259ec --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -0,0 +1,19 @@ +from .compressed_tensors_scheme import CompressedTensorsScheme +from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, + CompressedTensorsW4A16Sparse24) +from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 +from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8 +from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8 +from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS, + CompressedTensorsWNA16) + +__all__ = [ + "CompressedTensorsScheme", + "CompressedTensorsWNA16", + "CompressedTensorsW8A16Fp8", + "CompressedTensorsW4A16Sparse24", + "CompressedTensorsW8A8Int8", + "CompressedTensorsW8A8Fp8", + "WNA16_SUPPORTED_BITS", + "W4A16SPARSE24_SUPPORTED_BITS", +] diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..b30bdb5 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_scheme.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_scheme.cpython-310.pyc new file mode 100644 index 0000000..8bdd453 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_scheme.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w4a16_24.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w4a16_24.cpython-310.pyc new file mode 100644 index 0000000..4c5d1b6 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w4a16_24.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a16_fp8.cpython-310.pyc 
b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a16_fp8.cpython-310.pyc new file mode 100644 index 0000000..2f1b90c Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a16_fp8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a8_fp8.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a8_fp8.cpython-310.pyc new file mode 100644 index 0000000..be73cc2 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a8_fp8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a8_int8.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a8_int8.cpython-310.pyc new file mode 100644 index 0000000..8776b45 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_w8a8_int8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_wNa16.cpython-310.pyc b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_wNa16.cpython-310.pyc new file mode 100644 index 0000000..7fea985 Binary files /dev/null and b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__pycache__/compressed_tensors_wNa16.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py new file mode 100644 index 0000000..b4bab33 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["CompressedTensorsScheme"] + + +class CompressedTensorsScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by CompressedTensors. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. 
+ """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py new file mode 100644 index 0000000..9ad61a6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py @@ -0,0 +1,153 @@ +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +__all__ = ["CompressedTensorsW4A16Sparse24"] +W4A16SPARSE24_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, +} +W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None): + self.strategy = strategy + self.group_size = group_size + self.tile_size = 16 + + if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}") + + self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits] + + if self.strategy == "group" and self.group_size is None: + raise ValueError( + "group_size must be given when using strategy group") + + @classmethod + def get_min_capability(cls) -> int: + # ampere + up + return 80 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile to be torch.nn.Parameter + layer.weight_packed = Parameter(layer.weight_packed.data, + requires_grad=False) + layer.scale_packed = Parameter(layer.scale_packed.data, + requires_grad=False) + layer.meta = Parameter(layer.meta.data, requires_grad=False) + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + pack_factor = 32 // self.quant_type.size_bits + output_size_per_partition = sum(output_partition_sizes) + + qweight = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // self.tile_size // 2, + output_size_per_partition * self.tile_size // pack_factor, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=pack_factor, + marlin_tile_size=self.tile_size, + weight_loader=weight_loader) + + input_groups = (1 if self.group_size is None else + input_size_per_partition // self.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if self.group_size is not None: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + else: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 
2 // 2, + output_size_per_partition * 2, + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", qweight) + layer.register_parameter("weight_shape", weight_shape) + layer.register_parameter("scale_packed", scales) + layer.register_parameter("meta", meta) + + max_workspace_size = ( + output_size_per_partition // + GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL + + workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int), + requires_grad=False) + layer.workspace = workspace + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + qweight = layer.weight_packed + meta = layer.meta + scales = layer.scale_packed + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, self.quant_type, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py new file mode 100644 index 0000000..3d55d55 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py @@ -0,0 +1,118 @@ +from typing import Callable, List, Optional + +import torch + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +__all__ = ["CompressedTensorsW8A16Fp8"] + +SUPPORTED_STRATEGIES = [ + QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR +] + + +class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + # W8A8-Fp8 kernels support only per-tensor and per-channel cases. + # So if we have a fused module (QKV, MLP) with per tensor scales, + # we expand each scale to its shard's channels. 
+ def process_weights_after_loading(self, layer) -> None: + if self.strategy == QuantizationStrategy.TENSOR: + ws_channelwise = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + layer.weight_scale = torch.nn.Parameter(ws_channelwise, + requires_grad=False) + else: + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + + # Weights must be transposed for marlin + layer.weight = torch.nn.Parameter(layer.weight.t(), + requires_grad=False) + + if self.is_static_input_scheme: + # required by torch.compile to be torch.nn.Parameter + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + prepare_fp8_layer_for_marlin(layer, strategy="channel") + + def create_weights(self, layer: torch.nn.Module, input_size: int, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + elif self.strategy == QuantizationStrategy.TENSOR: + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + else: + raise ValueError( + f"Unsupported weight strategy={self.strategy}, " + f"supported strategies are {SUPPORTED_STRATEGIES}") + + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE (to deal with converted checkpoints) + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_marlin_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py new file mode 100644 index 0000000..5931ec3 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -0,0 +1,143 @@ +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) +from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.utils import is_hip + +__all__ = ["CompressedTensorsW8A8Fp8"] + + +class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + self.cutlass_fp8_supported = cutlass_fp8_supported() + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.strategy == QuantizationStrategy.TENSOR: + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + if is_hip(): + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. + elif self.strategy == QuantizationStrategy.CHANNEL: + weight = layer.weight + + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError(f"Unknown quantization strategy {self.strategy}") + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min 
requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py new file mode 100644 index 0000000..b93f35e --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py @@ -0,0 +1,155 @@ +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + QuantizationStrategy) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_int8_linear, convert_to_channelwise) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + + +class CompressedTensorsW8A8Int8(CompressedTensorsScheme): + + def __init__(self, strategy: str, is_static_input_scheme: bool, + input_symmetric: bool): + self.strategy = strategy + self.is_static_input_scheme = is_static_input_scheme + self.input_symmetric = input_symmetric + + @classmethod + def get_min_capability(cls) -> int: + # turing and up + return 75 + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = layer.weight + layer.weight = Parameter(weight.t(), requires_grad=False) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(self.logical_widths) > 1 + if is_fused_module and self.strategy == QuantizationStrategy.TENSOR: + ws_channelwise = convert_to_channelwise(layer.weight_scale, + self.logical_widths) + layer.weight_scale = Parameter(ws_channelwise, requires_grad=False) + else: + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) + # INPUT SCALE + if self.is_static_input_scheme: + if self.input_symmetric: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + layer.input_zero_point = None + else: + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = layer.input_zero_point.to(dtype=torch.int32) + range_max = (layer.input_scale * + (int8_traits.max - azps)).max() + range_min = (layer.input_scale * + (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + layer.input_scale = Parameter(scale, requires_grad=False) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - + range_min / scale).to(dtype=torch.int32) + layer.input_zero_point = Parameter(azp, requires_grad=False) + + else: + layer.input_scale = None + layer.input_zero_point = None + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. + # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + if not self.input_symmetric: + layer.azp_adj = layer.weight.sum(dim=0, + keepdim=True, + dtype=torch.int32) + else: + layer.azp_adj = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + self.logical_widths = output_partition_sizes + + # WEIGHT + if input_size_per_partition % 64 != 0: + pad_input_size_per_partition = (input_size_per_partition // 64 + 1) * 64 + else: + pad_input_size_per_partition = input_size_per_partition + w_pad = torch.zeros( + sum(output_partition_sizes), + pad_input_size_per_partition, + dtype=torch.int8) + w = w_pad[:, :input_size_per_partition] + weight = ModelWeightParameter(data=w, + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.strategy == QuantizationStrategy.CHANNEL: + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.strategy == QuantizationStrategy.TENSOR + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + if not self.input_symmetric: + # Note: compressed-tensors stores the zp using the same dtype + # as the weights + # AZP loaded as int8 but used as int32 + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("input_zero_point", input_zero_point) + + def apply_weights(self, layer: torch.nn.Module, x: 
torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + + return apply_int8_linear(input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + input_zero_point=layer.input_zero_point, + azp_adj=layer.azp_adj, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py new file mode 100644 index 0000000..cb65557 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -0,0 +1,163 @@ +from typing import Callable, List, Optional, Set + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + ActivationOrdering) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + marlin_repeat_scales_on_all_ranks) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +__all__ = ["CompressedTensorsWNA16"] +WNA16_SUPPORTED_TYPES_MAP = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128 +} +WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) + + +class CompressedTensorsWNA16(CompressedTensorsScheme): + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, + strategy: str, + num_bits: int, + group_size: Optional[int] = None, + actorder: Optional[ActivationOrdering] = None): + + self.pack_factor = 32 // num_bits + self.strategy = strategy + self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP + + if self.group_size == -1 and self.strategy != "channel": + raise ValueError("Marlin kernels require group quantization or " + "channelwise quantization, but found no group " + "size and strategy is not channelwise.") + + if num_bits not in WNA16_SUPPORTED_TYPES_MAP: + raise ValueError( + f"Unsupported num_bits = {num_bits}. " + f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}") + + self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] + + @classmethod + def get_min_capability(cls) -> int: + # ampere and up + return 80 + + def create_weights(self, layer: torch.nn.Module, output_size: int, + input_size: int, output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + + output_size_per_partition = sum(output_partition_sizes) + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_type, + act_type=params_dtype, + group_size=self.group_size, + zero_points=False, + has_g_idx=self.has_g_idx + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for CompressedTensorsWNA16", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # If group_size is -1, we are in channelwise case. 
+ group_size = self.group_size if self.group_size != -1 else input_size + row_parallel = (input_size != input_size_per_partition) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) + + scales_and_zp_size = input_size // group_size + + if partition_scales: + assert input_size_per_partition % group_size == 0 + scales_and_zp_size = input_size_per_partition // group_size + + weight = PackedvLLMParameter(input_dim=1, + output_dim=0, + weight_loader=weight_loader, + packed_factor=self.pack_factor, + packed_dim=1, + data=torch.empty( + output_size_per_partition, + input_size_per_partition // + self.pack_factor, + dtype=torch.int32, + )) + + weight_scale_args = { + "weight_loader": + weight_loader, + "data": + torch.empty( + output_size_per_partition, + scales_and_zp_size, + dtype=params_dtype, + ) + } + if not partition_scales: + weight_scale = ChannelQuantScaleParameter(output_dim=0, + **weight_scale_args) + else: + weight_scale = GroupQuantScaleParameter(output_dim=0, + input_dim=1, + **weight_scale_args) + + # A 2D array defining the original shape of the weights + # before packing + weight_shape = BasevLLMParameter(data=torch.empty(2, + dtype=torch.int64), + weight_loader=weight_loader) + + layer.register_parameter("weight_packed", weight) + layer.register_parameter("weight_scale", weight_scale) + layer.register_parameter("weight_shape", weight_shape) + + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="weight_packed", + w_s_param_name="weight_scale", + w_zp_param_name=None, + w_gidx_param_name="weight_g_idx") + + # Checkpoints are serialized in compressed-tensors format, which is + # different from the format the kernel may want. Handle repacking here. 
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py new file mode 100644 index 0000000..fc531b9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -0,0 +1,269 @@ +import re +from enum import Enum +from typing import Any, Dict, Iterable, Optional, Union + +from pydantic import BaseModel, Field, field_validator +from torch.nn import Module + +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + FUSED_LAYER_NAME_MAPPING) + + +class CompressionFormat(Enum): + dense = "dense" + sparse_bitmask = "sparse-bitmask" + naive_quantized = "naive-quantized" + float_quantized = "float-quantized" + int_quantized = "int-quantized" + pack_quantized = "pack-quantized" + marlin_24 = "marlin-24" + + +class QuantizationType(str, Enum): + """ + Enum storing quantization type options + """ + + INT = "int" + FLOAT = "float" + + +class QuantizationStrategy(str, Enum): + """ + Enum storing quantization strategy options + """ + + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + BLOCK = "block" + TOKEN = "token" + + +class ActivationOrdering(str, Enum): + """ + Enum storing strategies for activation ordering + + Group: reorder groups and weight\n + Weight: only reorder weight, not groups. Slightly lower latency and + accuracy compared to group actorder\n + """ + + GROUP = "group" + WEIGHT = "weight" + + +class QuantizationArgs(BaseModel): + """ + User facing arguments used to define a quantization config + for weights or activations + + :param num_bits: quantization bit depth + :param type: dtype to quantized to, either int or float + :param symmetric: whether or not quantization scale is symmetric + :param strategy: string determining the scope of scale/zero-point to apply + :param group_size: group length to use for the group strategy + :param block_structure: 2d block structure to use for the block + strategy, must be of the format "2x4", "8x16", etc. + :param dynamic: set True to perform dynamic quantization - + values will not be calibrated during calibration phase, + instead during inference new quantization ranges will be + observed with every sample. Defaults to False for static + quantization. Note that enabling dynamic quantization + will change the default observer to a memoryless one + :param actorder: whether to apply group quantization in decreasing order of + activation. 
Defaults to None for arbitrary ordering + """ + + num_bits: int = 8 + type: QuantizationType = QuantizationType.INT + symmetric: bool = True + group_size: Optional[int] = None + strategy: Optional[QuantizationStrategy] = None + block_structure: Optional[str] = None + dynamic: bool = False + actorder: Union[ActivationOrdering, bool, None] = None + observer: str = Field( + default="minmax", + description=("The class to use to compute the quantization param - " + "scale and zero-point'"), + ) + observer_kwargs: Dict[str, Any] = Field( + default_factory=dict, + description= + ("optional dict of kwargs to be passed directly to torch quantization " + "Observers constructor excluding quantization range or symmetry"), + ) + + @field_validator("actorder", mode="before") + def validate_actorder(cls, value) -> Optional[ActivationOrdering]: + if isinstance(value, bool): + return ActivationOrdering.GROUP if value else None + + if isinstance(value, str): + return ActivationOrdering(value.lower()) + + return value + + +def is_activation_quantization_format(format: str) -> bool: + _ACTIVATION_QUANTIZATION_FORMATS = [ + CompressionFormat.naive_quantized.value, + CompressionFormat.int_quantized.value, + CompressionFormat.float_quantized.value + ] + return format in _ACTIVATION_QUANTIZATION_FORMATS + + +def should_ignore_layer(layer_name: Optional[str], + ignore: Iterable[str]) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def find_matched_target(layer_name: Optional[str], module: Module, + targets: Iterable[str]) -> str: + """ + Helper function to look up which "target" in the compressed-tensors + config that a layer corresponds to. 
+ + Recall that a compressed-tensors configs has a concept of + config_groups, where each layer can be quantized with with a different + scheme. + + targets in each config_group will be a list of either layer names + (or regexes corresponding to layer names) or names of torch Modules. + + First, we try to match the layer_name with a target + Second, we try to match the module's name with a target + + :param layer_name: layer name + :param module: torch.nn.Module + :param targets: list of targets to match the layer against + """ + + if layer_name is None: + layer_name = "" + + matched_target = (_find_first_match(layer_name, targets) + or _find_first_match(module.__class__.__name__, targets, + True)) + + if matched_target is None: + raise ValueError(f"Unable to find matching target for {module} in the " + "compressed-tensors config.") + + return matched_target + + +def _find_first_match(value: str, + targets: Iterable[str], + check_contains: bool = False) -> Optional[str]: + """ + Returns first element of target that matches value either + exactly or as a regex after 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + + :param value: string to compare the list of targets against + :param targets: list of targets to match the layer against + :param check_contains: whether or not to do a substring match + """ + + for target in targets: + if _is_equal_or_regex_match(value, + target, + check_contains=check_contains): + return target + return None + + +def get_compressed_tensors_cache_scale(name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + # If no matches, return None + return None + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. + """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py new file mode 100644 index 0000000..2948480 --- /dev/null +++ b/vllm/model_executor/layers/quantization/deepspeedfp.py @@ -0,0 +1,193 @@ +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs + + +class DeepSpeedFPConfig(QuantizationConfig): + """Config for DeepSpeed FP quantizer. It supports fp6 and fp8. + + Args: + weight_bits: the target quantization bits, 6 or 8. + group_size: group size for quantizaiton, default to 128. 
+ """ + + def __init__( + self, + weight_bits: int = 8, + group_size: int = 512, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.valid_types = [torch.bfloat16, torch.float16] + + if self.weight_bits not in (6, 8): + raise ValueError( + "Currently, only 6-bit or 8-bit weight quantization are " + f"supported for DeepSpeed FP quantizaiton, but got " + f"{self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), " + f"group_size={self.group_size}") + + @classmethod + def get_name(cls) -> str: + return "DeepSpeedFP" + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits=weight_bits, group_size=group_size) + + def get_linear_method(self) -> "DeepSpeedFPLinearMethod": + return DeepSpeedFPLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["DeepSpeedFPLinearMethod"]: + if isinstance(layer, LinearBase): + return DeepSpeedFPLinearMethod(self) + return None + + +class DeepSpeedFPLinearMethod(LinearMethodBase): + """Linear method for DeepSpeedFP quantizer. + + Args: + quant_config: the DeepSpeedFP quantization config. + """ + + def __init__(self, quant_config: DeepSpeedFPConfig): + self.quant_config = quant_config + self.weight = None + + def create_weights(self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_loader=None, + **extra_weight_attrs): + del output_size + del input_size + output_size_per_partition = sum(output_partition_sizes) + weight = DeepSpeedFPParameter( + torch.Size((output_size_per_partition, input_size_per_partition)), + params_dtype=params_dtype, + quant_config=self.quant_config, + ) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + }) + layer.register_parameter("weight", weight) + + def quant_weight_loader(param, loaded_weight, *args, **kwargs): + # Calls the original weight loader (if any), quantizes the result, + # and then loads the quantized parameter. + if weight_loader is not None: + orig_param_data = param.data + param.data = param.ds_dequantize() + weight_loader(param, loaded_weight, *args, **kwargs) + param.data, loaded_weight = orig_param_data, param.data + param.ds_quantize_(loaded_weight.cuda()) + + extra_weight_attrs["weight_loader"] = quant_weight_loader + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = layer.weight + y = weight.ds_dequantize() + return F.linear(x, y, bias) + + +class DeepSpeedFPParameter(nn.Parameter): + """ + DeepSpeedFP quantized parameter class that implements fp8/fp6 + quantization deepspeed. Weights are stored in quantized form on + GPUs, and can be dequantized on-the-fly when needed by the model. 
+ """ + + def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype, + quant_config: DeepSpeedFPConfig): + try: + import deepspeed + if deepspeed.__version__ < "0.14.2": + raise ImportError("deepspeed version is wrong. Please " + "install deepspeed>=0.14.2.") + from deepspeed.ops.fp_quantizer import FP_Quantize + except ImportError as err: + raise ImportError("Please install deepspeed>=0.14.2 via " + "`pip install deepspeed>=0.14.2` to use " + "deepspeedfp quantizer.") from err + data = torch.empty(( + orig_shape.numel() // quant_config.group_size, + quant_config.group_size * quant_config.weight_bits // 8 + 4, + ), + dtype=torch.int8) + self = torch.Tensor._make_subclass(cls, data, data.requires_grad) + self.orig_shape = orig_shape + self.quant_config = quant_config + self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size) + self.fp_quantizer.orig_shape = orig_shape + self.fp_quantizer.orig_dtype = params_dtype + return self + + def ds_quantize_(self, tensor: torch.Tensor): + assert tensor.device.type == "cuda" and tensor.dtype != torch.int8 + return self.data.copy_( + self.fp_quantizer.quantize( + tensor.data, + q_bits=self.quant_config.weight_bits, + )) + + def ds_dequantize(self, fp_out=None) -> torch.Tensor: + """ + Return a tensor containing the dequantized weights of this parameter. + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.dequantize( + self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits) + + def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor: + """ + Return a tensor where only the weights at `indices` are dequantized + (to save HBM -> SRAM bandwidth). + """ + assert self.data.device.type == "cuda" and self.data.dtype == torch.int8 + return self.fp_quantizer.selective_dequantize( + self.data, + indices, + fp_out=fp_out, + q_bits=self.quant_config.weight_bits) diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py new file mode 100644 index 0000000..116a4ea --- /dev/null +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -0,0 +1,179 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch + +from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs + + +class ExpertsInt8Config(QuantizationConfig): + """Config class for Int8 experts quantization.""" + + def __init__(self) -> None: + pass + + @classmethod + def get_name(cls) -> str: + return "experts_int8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ExpertsInt8Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return UnquantizedLinearMethod() + elif isinstance(layer, FusedMoE): + return ExpertsInt8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + 
+ +class ExpertsInt8MoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: ExpertsInt8Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + int8_dtype = torch.int8 + + assert 'weight_loader' in extra_weight_attrs + weight_loader = extra_weight_attrs['weight_loader'] + wrapped_weight_loader = ExpertsInt8MoEMethod.quantizing_weight_loader( + layer, weight_loader) + extra_weight_attrs['weight_loader'] = wrapped_weight_loader + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=int8_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=int8_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w13_scale = torch.nn.Parameter(torch.zeros(num_experts, + 2 * intermediate_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter(torch.zeros(num_experts, + hidden_size, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_scale", w2_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_int8_w8a16=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale) + + @staticmethod + def quantizing_weight_loader(layer, weight_loader): + + def quantize_and_call_weight_loader(param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, shard_id: int, + expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + shard_size = layer.intermediate_size_per_partition + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + device = get_tp_group().device + loaded_weight = loaded_weight.to(device) + # w1, gate_proj case: Load into first shard of w13. + if shard_id == "w1": + scales = quantize_in_place_and_get_scales( + loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, 0:shard_size].copy_(scales[:, + 0]) + # w3, up_proj case: Load into second shard of w13. + elif shard_id == "w3": + scales = quantize_in_place_and_get_scales( + loaded_weight[shard, :]) + layer.w13_scale.data[expert_id, shard_size:2 * + shard_size].copy_(scales[:, 0]) + # w2, down_proj case: Load into only shard of w2. 
+ elif shard_id == "w2": + scales = quantize_in_place_and_get_scales(loaded_weight[:, + shard]) + layer.w2_scale.data[expert_id, :].copy_(scales[:, 0]) + else: + raise ValueError( + f"Shard id must be in [0,1,2] but got {shard_id}") + weight_loader(param, loaded_weight, weight_name, shard_id, + expert_id) + + return quantize_and_call_weight_loader + + +def quantize_in_place_and_get_scales(weight: torch.Tensor) -> torch.Tensor: + vmax = torch.iinfo(torch.int8).max + scales = (torch.max(torch.abs(weight), dim=1, keepdim=True)[0] / vmax) + + weight.div_(scales) + weight.round_() + weight.clamp_(-vmax, vmax) + + return scales diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py new file mode 100644 index 0000000..f269071 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -0,0 +1,169 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.fp8 import cutlass_fp8_supported +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, normalize_e4m3fn_to_e4m3fnuz) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter) +from vllm.platforms import current_platform +from vllm.utils import is_hip + +logger = init_logger(__name__) + + +class FBGEMMFp8Config(QuantizationConfig): + """Config class for FBGEMM Fp8.""" + + def __init__(self, ignore_list: List[str], input_scale_ub: float): + self.ignore_list = ignore_list if ignore_list else [] + self.input_scale_ub = input_scale_ub + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = not current_platform.has_device_capability(89) + + @classmethod + def get_name(cls) -> str: + return "fbgemm_fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.float16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config": + ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"]) + input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"]) + return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignore_list): + return UnquantizedLinearMethod() + return FBGEMMFp8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class FBGEMMFp8LinearMethod(LinearMethodBase): + + def __init__(self, quant_config: FBGEMMFp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: 
torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs.get("weight_loader") + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + weight_scale = ChannelQuantScaleParameter(data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE UPPER BOUND + input_scale_ub = torch.nn.Parameter(torch.tensor( + (self.quant_config.input_scale_ub), dtype=torch.float32), + requires_grad=False) + layer.input_scale_ub = input_scale_ub + + def process_weights_after_loading(self, layer: Module) -> None: + # required by torch.compile + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + + weight = layer.weight + + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=None) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + if self.quant_config.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. 
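+            # The input-scale upper bound is only used by the FP8 activation
+            # path, so it can be dropped once the layer is repacked for Marlin.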
+ del layer.input_scale_ub + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.quant_config.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=None, + input_scale_ub=layer.input_scale_ub, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=True) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py new file mode 100644 index 0000000..b5feb55 --- /dev/null +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -0,0 +1,514 @@ +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +import vllm.envs as envs +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, apply_fp8_linear, convert_to_channelwise, + cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize, + requantize_with_max_scale) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import is_hip, print_warning_once + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = init_logger(__name__) + + +class Fp8Config(QuantizationConfig): + """Config class for FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected fp8 checkpoint. 
Please note that the " + "format is experimental and subject to change.") + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + + @classmethod + def get_name(cls) -> str: + return "fp8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_fp8_serialized = ("fp8" in quant_method) + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return Fp8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return Fp8MoEMethod(self) + elif isinstance(layer, Attention): + return Fp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class Fp8LinearMethod(LinearMethodBase): + """Linear method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Limitations: + 1. Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn data type due to the limitation of + torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856) + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + # For GPUs that lack FP8 hardware support, we can leverage the Marlin + # kernel for fast weight-only FP8 quantization + self.use_marlin = (not current_platform.has_device_capability(89) + or envs.VLLM_TEST_FORCE_FP8_MARLIN) + # Disable marlin for rocm + if is_hip(): + self.use_marlin = False + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # If checkpoint is serialized fp8, load them. + # Otherwise, wait until process_weights_after_loading. + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) + + # INPUT ACTIVATION SCALE + if self.quant_config.activation_scheme == "static": + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + else: + layer.register_parameter("input_scale", None) + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = torch.nn.Parameter(layer.weight.data, + requires_grad=False) + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight, + scale=None) + + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. + if self.use_marlin: + assert weight_scale.numel() == 1 + weight_scale = convert_to_channelwise( + weight_scale.expand(len(layer.logical_widths)), + layer.logical_widths) + + # Update the layer with the new values. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module + else: + layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, + requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = torch.nn.Parameter(layer.input_scale.data, + requires_grad=False) + # If using marlin (w8a16), kernel uses channelwise weights, + # so extend the weight scales to be channelwise. 
+ if self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise(layer.weight_scale, + layer.logical_widths) + + # If using w8a8, torch._scaled_mm needs per tensor, so + # requantize the logical shards as a single weight. + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight = layer.weight + weight_scale = layer.weight_scale + + # If rocm, use float8_e4m3fnuz. + if is_hip(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + weight_scale, weight = requantize_with_max_scale( + weight=weight, + weight_scale=weight_scale, + logical_widths=layer.logical_widths, + ) + + # Update layer with new values. + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if self.quant_config.activation_scheme == "static": + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + if self.use_marlin: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. + del layer.input_scale + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + if self.use_marlin: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias) + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=False) + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights(self, layer: Module, num_experts: int, hidden_size: int, + intermediate_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + 2 * intermediate_size, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size, + intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
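+        # Resulting shapes: w13_weight_scale is (num_experts, 2) and
+        # w2_weight_scale is (num_experts,), one per-tensor scale per expert.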
+ w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8.") + + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + # If rocm, use float8_e4m3fnuz as dtype + fp8_dtype = torch.float8_e4m3fnuz \ + if is_hip() else torch.float8_e4m3fn + w13_weight = torch.empty_like(layer.w13_weight.data, + dtype=fp8_dtype) + w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. + layer.w13_weight_scale = torch.nn.Parameter(torch.ones( + layer.num_experts, + dtype=torch.float32, + device=w13_weight.device), + requires_grad=False) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w13_weight.data[expert, :, :]) + w2_weight[expert, :, :], layer.w2_weight_scale[ + expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :]) + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if (layer.w13_input_scale is None + or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. 
") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + # If rocm, normalize the weights and scales to e4m3fnuz + if is_hip(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter( + w13_weight_scale, requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter( + w13_input_scale, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter( + w2_input_scale, requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) + + +class Fp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. 
+ """ + + def __init__(self, quant_config: Fp8Config): + super().__init__(quant_config) diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py new file mode 100644 index 0000000..d73b9f6 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -0,0 +1,178 @@ +from typing import Any, Dict, List, Optional + +import gguf +import torch +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.utils import set_weight_attrs + + +class GGUFConfig(QuantizationConfig): + """Config class for GGUF.""" + + def __init__(self, ) -> None: + pass + + def __repr__(self) -> str: + return ("GGUFConfig()") + + def get_name(self) -> str: + return "gguf" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] # no extra configs. + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GGUFConfig": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return GGUFLinearMethod(self) + elif isinstance(layer, VocabParallelEmbedding): + return GGUFEmbeddingMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, + qweight_type: int) -> torch.Tensor: + # use dequantize mulmat for IQmatrix, mmq for k-quants + if x.shape[0] == 1: + # enable mmvq in contiguous batching + y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0]) + elif qweight_type >= 16: + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) + weight = ops.ggml_dequantize(qweight, qweight_type, *shape) + y = x @ weight.T + else: + y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0]) + return y + + +class GGUFLinearMethod(LinearMethodBase): + """Linear method for GGUF. + + Args: + quant_config: The GGUF quantization config. 
+ """ + + def __init__(self, quant_config: GGUFConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + + tensor_shape = (output_size_per_partition, input_size_per_partition) + qweight = GGUFUninitializedParameter(requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 1, + "output_dim": 0, + "tensor_shape": tensor_shape, + "is_gguf_weight": True, + "data_container": [], + "shard_id": [], + "shard_id_map": {}, + }) + set_weight_attrs(qweight, extra_weight_attrs) + layer.register_parameter("qweight", qweight) + + qweight_type = Parameter(torch.empty(len(output_partition_sizes), + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight_type, { + "is_gguf_weight_type": True, + "weight_type": 0, + "shard_weight_type": {}, + "ignore_warning": True + }) + set_weight_attrs(qweight_type, extra_weight_attrs) + layer.register_parameter("qweight_type", qweight_type) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + shard_id = getattr(layer.qweight, "shard_id", None) + + if shard_id: + # dequantize shard weights respectively + shard_id = ["q", "k", "v"] if "q" in shard_id else shard_id + qweight = layer.qweight.unbind(0) + result = [] + for id in shard_id: + q_idx = layer.qweight.shard_id_map[id] + qweight_type = layer.qweight_type.shard_weight_type[id] + result.append(_fuse_mul_mat(x, qweight[q_idx], qweight_type)) + out = torch.cat(result, axis=1) + else: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + out = _fuse_mul_mat(x, qweight, qweight_type) + if bias is not None: + out.add_(bias) + return out + + +class GGUFEmbeddingMethod(GGUFLinearMethod): + """Embedding method for GGUF. + + Args: + quant_config: The GGUF quantization config. 
+ """ + + def embedding(self, layer: torch.nn.Module, + x: torch.Tensor) -> torch.Tensor: + qweight = layer.qweight + qweight_type = layer.qweight_type.weight_type + + block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] + hidden_size = qweight.shape[1] // type_size * block_size + if qweight_type < 2: + return torch.embedding(qweight, x) + x_flat = x.flatten() + quant = torch.index_select(qweight, dim=0, index=x_flat) + dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, + x_flat.shape[0]) + return dequant.view(*x.shape, hidden_size) + + +class GGUFUninitializedParameter(UninitializedParameter): + cls_to_become = Parameter + data_container: List[torch.Tensor] + + def materialize_nested(self) -> Parameter: + nested_data = torch.nested.nested_tensor(self.data_container, + device=self.device, + dtype=torch.uint8) + self.data_container.clear() + param = torch.Tensor._make_subclass(self.cls_to_become, + nested_data, + require_grad=False) + for k, v in self.__dict__.items(): + setattr(param, k, v) + return param diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py new file mode 100644 index 0000000..bdd99d7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -0,0 +1,248 @@ +import enum +from enum import Enum +from fractions import Fraction +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. 
+ + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + lm_head_quantized: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + self.pack_factor = Fraction(32, self.weight_bits) + if self.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is " + f"supported for GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})," + f"lm_head_quantized={self.lm_head_quantized}") + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, lm_head_quantized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return GPTQLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() + + +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. + + Args: + quant_config: The GPTQ quantization config. + """ + + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + output_size_per_partition = sum(output_partition_sizes) + if (output_size_per_partition % self.quant_config.pack_factor.numerator + != 0): + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExllamaState.UNINITIALIZED + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if (input_size != input_size_per_partition + and self.quant_config.group_size != -1): + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + exllama_state = ExllamaState.UNUSED + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + g_idx = RowvLLMParameter(data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + qzeros_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("qzeros", qzeros) + layer.register_parameter("scales", scales) + + layer.exllama_state = exllama_state + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) + + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if layer.exllama_state == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int) + else: + layer.g_idx.data = torch.empty((0, ), + dtype=torch.int, + device=layer.g_idx.device) + layer.exllama_state = ExllamaState.READY + ops.gptq_shuffle(layer.qweight, layer.g_idx, + self.quant_config.weight_bits) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out_shape = x.shape[:-1] + (layer.qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + + output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros, + layer.scales, layer.g_idx, + layer.exllama_state == 
ExllamaState.READY, + self.quant_config.weight_bits) + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py new file mode 100644 index 0000000..e771917 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -0,0 +1,570 @@ +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kernels import ( + MPLinearLayerConfig, choose_mp_linear_kernel) +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + check_marlin_supported, marlin_moe_permute_scales, + marlin_repeat_scales_on_all_ranks, verify_marlin_supported) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + # (num_bits, is_sym) -> quant_type + TYPE_MAP = { + (4, True): scalar_types.uint4b8, + (8, True): scalar_types.uint8b128, + } + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.pack_factor = 32 // weight_bits # packed into int32 + self.group_size = group_size + self.desc_act = desc_act + self.lm_head_quantized = lm_head_quantized + + if (weight_bits, is_sym) not in self.TYPE_MAP: + raise ValueError("Unsupported quantization config: " + f"bits={weight_bits}, sym={is_sym}") + + self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] + + def __repr__(self) -> str: + return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act}, " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(weight_bits, group_size, desc_act, is_sym, + lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + 
user_quant) -> Optional[str]: + can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) + + is_valid_user_quant = (user_quant is None or user_quant == "marlin" + or user_quant == "gptq_marlin") + + if can_convert and is_valid_user_quant: + msg = ("The model is convertible to {} during runtime." + " Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + if can_convert and user_quant == "gptq": + logger.info("Detected that the model can run with gptq_marlin" + ", however you specified quantization=gptq explicitly," + " so forcing gptq. Use quantization=gptq_marlin for" + " faster inference") + return None + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): + return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + quant_method = quant_config.get("quant_method", "").lower() + num_bits = quant_config.get("bits") + group_size = quant_config.get("group_size") + sym = quant_config.get("sym") + desc_act = quant_config.get("desc_act") + + if quant_method != "gptq": + return False + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + if (num_bits, sym) not in cls.TYPE_MAP: + return False + + return check_marlin_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)], + group_size=group_size) + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. + """ + + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + # Verify supported on platform. 
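+        # (raises if the current device cannot run Marlin with this
+        # quant_type / group_size combination)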
+ verify_marlin_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + output_size_per_partition = sum(output_partition_sizes) + is_row_parallel = input_size != input_size_per_partition + weight_loader = extra_weight_attrs.get("weight_loader") + + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=\ + (input_size_per_partition, output_size_per_partition), + weight_type=self.quant_config.quant_type, + act_type=params_dtype, + group_size=self.quant_config.group_size, + zero_points=False, + has_g_idx=self.quant_config.desc_act + ) + + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for GPTQMarlinLinearMethod", + kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Determine sharding + if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, + self.quant_config.group_size, + is_row_parallel): + # By setting scale_dim == None, weight_loader will + # repeat the scales on each GPU in TP>1 case. + scales_and_zp_input_dim = None + scales_and_zp_size = input_size // group_size + else: + # By setting scale_dim == 0, weight_loader will + # shard the scales in TP>1 case. + scales_and_zp_input_dim = 0 + scales_and_zp_size = input_size_per_partition // group_size + + # Quantized weights + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + # Activation order + g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + + qzeros_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + + if scales_and_zp_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + + self.kernel = kernel_type(mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + w_gidx_param_name="g_idx") + + def process_weights_after_loading(self, layer: 
torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + strategy = FusedMoeWeightScaleSupported.GROUP.value + else: + scales_size13 = 1 + scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + 
hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx) + replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx) + replace_parameter(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_parameter(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_parameter(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_parameter(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + 
use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + # The input must currently be float16 + orig_dtype = x.dtype + x = x.half() + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=None) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + layer.w13_scales, + layer.w2_scales, + router_logits, + topk_weights, + topk_ids, + g_idx1=layer.w13_g_idx, + g_idx2=layer.w2_g_idx, + sort_indices1=layer.w13_g_idx_sort_indices, + sort_indices2=layer.w2_g_idx_sort_indices, + num_bits=self.quant_config.quant_type.size_bits, + ).to(orig_dtype) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py new file mode 100644 index 0000000..a66f12e --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -0,0 +1,295 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.scalar_type import scalar_types + +logger = init_logger(__name__) + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [ + scalar_types.uint4b8, scalar_types.uint8b128 +] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + + +class GPTQMarlin24Config(QuantizationConfig): + """Config class for Marlin24. + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + ) -> None: + quant_type = { + 4: scalar_types.uint4b8, + 8: scalar_types.uint8b128, + }.get(weight_bits) + + self.group_size = group_size + + # Verify + if quant_type is None or \ + quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES: + raise ValueError( + f"Marlin_24 does not support quant_type = {quant_type}. " + f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin_24 does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} " + "are supported.") + + self.quant_type = quant_type + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.quant_type.size_bits + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL + + # Permutation length used by the marlin kernels. 
+ self.perm_len = 1024 + + def __repr__(self) -> str: + return "Marlin24Config(quant_type={}, group_size={})".format( + self.quant_type, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin_24" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + is_marlin_24_format = ( + hf_quant_cfg.get("checkpoint_format") == "marlin_24") + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "gptq_marlin_24") + + if is_marlin_24_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. " + "Using {} kernel.".format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["GPTQMarlin24LinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlin24LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class GPTQMarlin24LinearMethod(LinearMethodBase): + """Linear method for Marlin24. + + Args: + quant_config: The Marlin24 quantization config. + """ + + def __init__(self, quant_config: GPTQMarlin24Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. 
+ weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size // 2, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + # Meta + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B_24", qweight) + layer.register_parameter("B_meta", meta) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B_24 = 
Parameter(layer.B_24.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B_24 + meta = layer.B_meta + scales = layer.s + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales, + workspace, + self.quant_config.quant_type, + size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py new file mode 100644 index 0000000..e540526 --- /dev/null +++ b/vllm/model_executor/layers/quantization/ipex_quant.py @@ -0,0 +1,166 @@ +from typing import Any, Dict, List, Optional + +import torch + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.awq import AWQLinearMethod +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.platforms import current_platform + + +class IPEXConfig(QuantizationConfig): + """INT8 quantization config class using IPEX for the CPU backend, + including AWQ. + """ + + IPEX_QUANT_METHOD_MAP = { + "awq": 1, + "gptq": 2, + } + + def __init__( + self, + method: str, + weight_bits: int, + group_size: int, + ) -> None: + self.method = method + self.weight_bits = weight_bits + self.group_size = group_size + self.pack_factor = 32 // self.weight_bits + + if self.weight_bits not in [4]: + raise ValueError(f"IPEX quantization supports weight bits [4], " + f"but got {self.weight_bits}.") + + if self.method == "awq": + self.quant_method = IPEXAWQLinearMethod + else: + raise ValueError(f"IPEX quantization supports [awq], " + f"but got {self.method}.") + + def __repr__(self) -> str: + return (f"IPEXConfig(method={self.method}" + f"weight_bits={self.weight_bits}, " + f"group_size={self.group_size}") + + def get_ipex_quant_method_id(self) -> int: + return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method] + + @classmethod + def get_name(cls) -> str: + return "ipex" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return -1 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig": + method = cls.get_from_keys(config, ["quant_method"]).lower() + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + return cls(method, weight_bits, group_size) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + if not current_platform.is_cpu(): + return None + + quant_method = hf_quant_cfg.get("quant_method", "").lower() + + if quant_method in ["awq"]: + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["LinearMethodBase"]: + if isinstance(layer, LinearBase): + 
return self.quant_method(self) + return None + + def get_scaled_act_names(self) -> List[str]: + if self.method == "awq": + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + else: + return [] + + +class IPEXAWQLinearMethod(AWQLinearMethod): + """AWQ linear method using IPEX for the CPU backend. + """ + + def __init__(self, quant_config: IPEXConfig): + self.quant_config = quant_config # type: ignore + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer=layer) + + bias = layer.bias if not layer.skip_bias_add else None + + try: + import intel_extension_for_pytorch as ipex + if ipex.__version__ < "2.4.0": + raise ImportError("intel_extension_for_pytorch version is " + "wrong. Please install " + "intel_extension_for_pytorch>=2.4.0.") + except ImportError as err: + raise ImportError( + "Please install " + "intel_extension_for_pytorch>=2.4.0 via " + "`pip install intel_extension_for_pytorch>=2.4.0`" + " to use IPEX-AWQ linear method.") from err + + # Using the compute dtype (lowp_mode) as INT8 to leverage instructions + # with better performance. + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + # The weight will be de-packed from INT4 to INT8. + weight_dtype = ipex.quantization.WoqWeightDtype.INT4 + # The float activation will be quantized (dynamic, per-token) to INT8. + act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping( + weight_dtype=weight_dtype, + lowp_mode=lowp_mode, + act_quant_mode=act_quant_mode, + group_size=self.quant_config.group_size, + ) + + layer.ipex_output_size = layer.qweight.size( + 1) * self.quant_config.pack_factor + layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\ + WeightOnlyQuantizedLinear.from_weight( + layer.qweight, + layer.scales, + layer.qzeros, + layer.qweight.size(0), + layer.ipex_output_size, + qconfig=qconfig, + bias=bias, + group_size=self.quant_config.group_size, + quant_method= + self.quant_config.get_ipex_quant_method_id() # type: ignore + ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + out = layer.ipex_qlinear(reshaped_x) + + return out.reshape(x.shape[:-1] + (layer.ipex_output_size, )) diff --git a/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py new file mode 100644 index 0000000..fe50c49 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, 
+ w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module + ) -> Tuple[torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/vllm/model_executor/layers/quantization/kernels/__init__.py b/vllm/model_executor/layers/quantization/kernels/__init__.py new file mode 100644 index 0000000..47591c2 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/__init__.py @@ -0,0 +1,72 @@ +import os +from typing import List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.machete import ( + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.marlin import ( + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import ( + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. 
+ """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/vllm/model_executor/layers/quantization/kernels/__pycache__/MPLinearKernel.cpython-310.pyc b/vllm/model_executor/layers/quantization/kernels/__pycache__/MPLinearKernel.cpython-310.pyc new file mode 100644 index 0000000..caa2d19 Binary files /dev/null and b/vllm/model_executor/layers/quantization/kernels/__pycache__/MPLinearKernel.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/kernels/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/quantization/kernels/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..72e7450 Binary files /dev/null and b/vllm/model_executor/layers/quantization/kernels/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/kernels/__pycache__/machete.cpython-310.pyc b/vllm/model_executor/layers/quantization/kernels/__pycache__/machete.cpython-310.pyc new file mode 100644 index 0000000..6a67b05 Binary files /dev/null and b/vllm/model_executor/layers/quantization/kernels/__pycache__/machete.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/kernels/__pycache__/marlin.cpython-310.pyc b/vllm/model_executor/layers/quantization/kernels/__pycache__/marlin.cpython-310.pyc new file mode 100644 index 0000000..478759a Binary files /dev/null and b/vllm/model_executor/layers/quantization/kernels/__pycache__/marlin.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py new file mode 100644 index 0000000..fa39cb5 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -0,0 +1,118 @@ +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_weights_into_int32, unpack_weights_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + 
c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_weights_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_weights_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + self.config.weight_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_gemm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_zeros=None, + b_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/kernels/marlin.py b/vllm/model_executor/layers/quantization/kernels/marlin.py new file mode 100644 index 0000000..6969583 --- /dev/null +++ b/vllm/model_executor/layers/quantization/kernels/marlin.py @@ -0,0 +1,133 @@ +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, 
marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. + self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, 
transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py new file mode 100644 index 0000000..d79536d --- /dev/null +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -0,0 +1,76 @@ +import torch + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.utils import print_warning_once + + +class BaseKVCacheMethod(QuantizeMethodBase): + """ + Quant method that adds `_k_scale` and `_v_scale` attributes to the + Attention layer to support loading those scaling factors from checkpoints. + The k/v_scale will be used to: + - quantize k/v_cache entries before saving them to the cache + - dequantize k/v_cache entries before fetching them from the cache + + :param quant_config: the appropriate QuantizationConfig + """ + + def __init__(self, quant_config: QuantizationConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """ + Create "weight" (aka k_scale and v_scale) for an attention layer. + """ + # Initialize the KV cache scales to -1.0, which is an invalid value. + # If the k/v_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError( + f"{self.__class__.__name__}.apply should not be called.") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. 
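        # k_scale / v_scale were initialized to the sentinel -1.0 in
        # create_weights(), so the branches below resolve as follows:
        #   * both scales loaded from the checkpoint -> use them directly
        #   * neither loaded (both still negative)   -> default to 1.0
        #   * only a single kv_scale found           -> it was remapped to
        #     k_scale during weight loading and is duplicated into v_scale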
+ if layer.kv_cache_dtype != "auto": + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + + if not isinstance(k_scale, float) or not isinstance( + v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + # These are used in the final Attention.forward() + layer._k_scale = k_scale + layer._v_scale = v_scale + if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + and "e5m2" not in layer.kv_cache_dtype): + print_warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + del layer.k_scale + del layer.v_scale diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py new file mode 100644 index 0000000..c655a76 --- /dev/null +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -0,0 +1,260 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) + +logger = init_logger(__name__) + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + lm_head_quantized: bool, + ) -> None: + # Group size for the quantization. + self.group_size = group_size + self.lm_head_quantized = lm_head_quantized + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) " + "is supported for Marlin, but got group_size of " + f"{self.group_size}") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. 
+ self.perm_len = 1024 + + def __repr__(self) -> str: + return (f"MarlinConfig(group_size={self.group_size}, " + f"lm_head_quantized={self.lm_head_quantized})") + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], + default=False) + return cls(group_size, lm_head_quantized) + + @classmethod + def override_quantization_method(cls, hf_quant_cfg, + user_quant) -> Optional[str]: + # compat: autogptq >=0.8.0 use checkpoint_format: str + # compat: autogptq <=0.7.1 is_marlin_format: bool + is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" + or hf_quant_cfg.get("is_marlin_format", False)) + + is_valid_user_quant = (user_quant is None or user_quant == "gptq" + or user_quant == "marlin") + + if is_marlin_format and is_valid_user_quant: + msg = ("The model is serialized in {} format. Using {} kernel.". + format(cls.get_name(), cls.get_name())) + logger.info(msg) + return cls.get_name() + + return None + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["MarlinLinearMethod"]: + if (isinstance(layer, LinearBase) or + (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + return MarlinLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused. 
+ weight_loader = extra_weight_attrs["weight_loader"] + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + # Determine if channelwise or not + input_groups = (1 if self.quant_config.group_size == -1 else + input_size_per_partition // + self.quant_config.group_size) + + weight_scale_args = { + "data": + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B", qweight) + layer.register_parameter("s", scales) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + scales = layer.s + workspace = layer.workspace 
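        # Flatten the activation to 2-D for the kernel call:
        #   size_m = number of tokens in the flattened batch (x_2d.shape[0])
        #   size_k = input features per partition            (x_2d.shape[1])
        #   size_n = output features per partition           (scales.shape[1])
        # ops.marlin_gemm returns a [size_m, size_n] matrix that is reshaped
        # back to the leading dimensions of x before the optional bias add.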
+ + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py new file mode 100644 index 0000000..dc5f47e --- /dev/null +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -0,0 +1,163 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale) +from vllm.model_executor.parameter import (ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + +ACTIVATION_SCHEMES = ["static"] + + +class ModelOptFp8Config(QuantizationConfig): + """Config class for ModelOpt FP8.""" + + def __init__( + self, + is_checkpoint_fp8_serialized: bool = False, + ) -> None: + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning("Detected ModelOpt fp8 checkpoint. Please note that" + " the format is experimental and could change.") + + @classmethod + def get_name(cls) -> str: + return "modelopt" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + is_checkpoint_fp8_serialized = ("FP8" in quant_method) + if not is_checkpoint_fp8_serialized: + raise ValueError("ModelOpt currently only supports static FP8" + "quantization in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + return cls(is_checkpoint_fp8_serialized) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + return ModelOptFp8LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + super().__init__(quant_config) + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer static quantization. + Supports loading FP8 checkpoints with static weight scale and + activation scale. Future support might be added for dynamic + scales. + + Limitations: + 1. 
Only support per-tensor quantization due to torch._scaled_mm support. + 2. Only support float8_e4m3fn datatype + Args: quant_config: The ModelOpt quantization config. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del input_size, output_size + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + weight_dtype = (torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized else + params_dtype) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + if self.quant_config.is_checkpoint_fp8_serialized: + # WEIGHT SCALE + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + # INPUT SCALE + scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", scale) + + def process_weights_after_loading(self, layer: Module) -> None: + max_w_scale, weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths) + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported) diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py new file mode 100644 index 0000000..2624981 --- /dev/null +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -0,0 +1,67 @@ +import os +from importlib.util import find_spec +from typing import Any, Dict, List, Optional + +from torch.nn import Module + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] + + +class NeuronQuantConfig(QuantizationConfig): + """Int8 Quantization Config class for Neuron Backend.""" + + def __init__( + self, + dequant_dtype: str = "f16", + quantize_method: str = "vector_dynamic", + ) -> None: + self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") + if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: + raise ValueError( + f"Neuron quantization datatype {self.quant_dtype} is not valid," + f"the quantization datatype should match one of the below types" + f"{SUPPORTED_QUANT_DTYPE_LIST}") + 
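        # Note: quant_dtype is read from the NEURON_QUANT_DTYPE environment
        # variable (default "s8") rather than from the constructor arguments;
        # dequant_dtype and quantize_method come from the model's quant config
        # via from_config() and are forwarded to
        # transformers_neuronx.config.QuantizationConfig in
        # get_quantization_config().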
self.dequant_dtype = dequant_dtype + self.quantize_method = quantize_method + + def get_name(self) -> str: + return "neuron_quant" + + def get_supported_act_dtypes(self) -> List[str]: + return SUPPORTED_QUANT_DTYPE_LIST + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with Neuron Backend") + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": + quantize_method = cls.get_from_keys(config, ["quantize_method"]) + dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) + return cls(dequant_dtype=dequant_dtype, + quantize_method=quantize_method) + + def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: + if find_spec("transformers_neuronx") is not None: + return self.get_quantization_config() + else: + raise NotImplementedError( + "Neuron Quantization is only supported through" + " transformers_neuronx.") + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_quantization_config(self): + from transformers_neuronx.config import QuantizationConfig + return QuantizationConfig(quant_dtype=self.quant_dtype, + dequant_dtype=self.dequant_dtype, + quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py new file mode 100644 index 0000000..b88ecac --- /dev/null +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -0,0 +1,273 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) + +logger = init_logger(__name__) + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + + +class QQQConfig(QuantizationConfig): + """Config class for QQQ + + Reference: https://arxiv.org/pdf/2406.09904 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + is_sym: bool = True, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.is_sym = is_sym + + # Verify + if self.weight_bits not in MARLIN_QQQ_SUPPORTED_NUM_BITS: + raise ValueError( + f"QQQ does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {MARLIN_QQQ_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in MARLIN_QQQ_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"QQQ does not support group_size = {self.group_size}. " + f"Only group_sizes = {MARLIN_QQQ_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in MARLIN_QQQ_SUPPORTED_SYM: + raise ValueError( + f"QQQ does not support is_sym = {self.is_sym}. " + f"Only sym = {MARLIN_QQQ_SUPPORTED_SYM} are supported.") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // self.weight_bits + + # Tile size used by QQQ kernels. 
+ self.tile_size = MARLIN_QQQ_TILE + + # Min out_features dim + self.min_n_threads = MARLIN_QQQ_MIN_THREAD_N + + # Min in_features dim + self.min_k_threads = MARLIN_QQQ_MIN_THREAD_K + + # Max parallel problems to solve at once (improves large + # batch performance) + self.max_parallel = MARLIN_QQQ_MAX_PARALLEL + + # Permutation length used by the QQQ kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return "QQQConfig(weight_bits={}, group_size={})".format( + self.weight_bits, self.group_size) + + @classmethod + def get_name(cls) -> str: + return "qqq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QQQConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(weight_bits, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QQQLinearMethod"]: + if isinstance(layer, LinearBase): + return QQQLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class QQQLinearMethod(LinearMethodBase): + """Linear method for QQQ. + + Args: + quant_config: The QQQ quantization config. + """ + + def __init__(self, quant_config: QQQConfig): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + weight_loader = extra_weight_attrs["weight_loader"] + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"min_n_threads = {self.quant_config.min_n_threads}.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f"pack_factor = {self.quant_config.pack_factor}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"min_k_threads = {self.quant_config.min_k_threads}.") + if (self.quant_config.group_size != -1 and + input_size_per_partition % self.quant_config.group_size != 0): + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"group_size = {self.quant_config.group_size}.") + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. 
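        # Illustrative shapes for a hypothetical 4096x4096 shard with
        # group_size = 128 (tile_size = 16, pack_factor = 32 // 4 = 8):
        #   qweight   : [4096 // 16, 4096 * 16 // 8] = [256, 8192]   (int32)
        #   s_channel : [1, 4096]                                    (float32)
        #   s_group   : [4096 // 128, 4096]          = [32, 4096]    (float16)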
+ qweight = PackedvLLMParameter( + data=torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + s_channel = ChannelQuantScaleParameter(data=torch.empty( + 1, + output_size_per_partition, + device="cuda", + dtype=torch.float, + ), + weight_loader=weight_loader, + output_dim=1) + + if self.quant_config.group_size == -1: + s_group_data = torch.tensor( + [], + device="cuda", + dtype=torch.half, + ) + else: + s_group_data = torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=torch.half, + ) + + s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} + + if self.quant_config.group_size == -1: + s_group = BasevLLMParameter(**s_group_attr) + else: + s_group = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **s_group_attr) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) + + layer.register_parameter("B", qweight) + layer.register_parameter("s_channel", s_channel) + layer.register_parameter("s_group", s_group) + layer.register_parameter("workspace", workspace) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) + layer.s_group = Parameter(layer.s_group.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = layer.B + s_ch = layer.s_channel + s_group = layer.s_group + workspace = layer.workspace + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = s_ch.shape[1] + + x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d) + + output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group, + workspace, size_m, size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/schema.py b/vllm/model_executor/layers/quantization/schema.py new file mode 100644 index 0000000..a26c524 --- /dev/null +++ b/vllm/model_executor/layers/quantization/schema.py @@ -0,0 +1,84 @@ +""" +This file contains the Pydantic schemas for various quantization-related +parameters. When a relevant quantization technique is specified, these +parameters are loaded in the form of a JSON alongside the model weights +and augment the model with additional information needed for use of that +technique. The format of this JSON should be specified by one or more +schemas contained here. + +For example, when the KV cache is quantized to FP8-E4M3 (currently only +possible on ROCm), the model can be optionally augmented with KV cache +scaling factors. 
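A hypothetical example of such a JSON (the field names follow the schemas
below; the model type and numeric values are made up for illustration):

    {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            "scaling_factor": {
                "0": {"0": 0.0123, "1": 0.0456}
            }
        }
    }

The outer keys of "scaling_factor" are tensor-parallel ranks and the inner
keys are layer indices, as validated by KVCacheQuantSchema.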
+""" + +from typing import Dict, Optional + +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator + + +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!") + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}.") + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}.") + for i in range(tp_size): + assert i in self.scaling_factor, ( + f"KV cache scales map for TP rank {i} not found.") + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}.") + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. 
weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!") + return self diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py new file mode 100644 index 0000000..be8235b --- /dev/null +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -0,0 +1,119 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch.nn import Module +from torch.nn.parameter import Parameter + +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import ModelWeightParameter + +ACTIVATION_SCHEMES = ["none"] + + +class Int8TpuConfig(QuantizationConfig): + """Int8 Quantization Config class for TPU Backend.""" + + def __init__( + self, + activation_scheme: str = "none", + ) -> None: + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError( + f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + + def get_name(self) -> str: + return "tpu_int8" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with TPU Backend") + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(activation_scheme=activation_scheme) + + def get_quant_method(self, layer: Module, + prefix: str) -> Optional["TPUInt8LinearMethod"]: + if isinstance(layer, LinearBase): + return TPUInt8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class TPUInt8LinearMethod(LinearMethodBase): + """Int8 Linear method for TPU Quant. 
""" + + def __init__(self, quant_config: Int8TpuConfig): + self.quant_config = quant_config + + def create_weights(self, layer: Module, input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + + weight_loader = extra_weight_attrs.get("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + def _quantize_weight( + self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + weight_dtype = weight.dtype + weight = weight.cpu().to(torch.float32) + n_bit = 8 + eps = 1e-5 + max_int = 2**(n_bit - 1) - 1 + min_int = -(2**(n_bit - 1)) + max_val = weight.abs().amax(dim=-1, keepdim=True) + max_val = max_val.clamp(min=eps) + qscale = max_val / max_int + qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int, + max_int).to(torch.int8) + qscale = qscale.squeeze().to(weight_dtype) + return qweight, qscale + + def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.data, requires_grad=False) + device = layer.weight.device + qweight, qscale = self._quantize_weight(layer.weight) + qweight = qweight.to(device) + qscale = qscale.to(device) + layer.weight = Parameter(qweight, requires_grad=False) + layer.scale = Parameter(qscale, requires_grad=False) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + try: + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + except ImportError as err: + raise ImportError( + "Please install torch_xla by following the instructions at " + "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501 + "to run vLLM on TPU.") from err + weight = layer.weight + scale = layer.scale + out = torch.ops.xla.quantized_matmul(x, weight, scale) + if bias is not None: + out = out + bias + return out diff --git a/vllm/model_executor/layers/quantization/utils/__init__.py b/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000..e60f0c7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,3 @@ +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..283737e Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-310.pyc new file mode 100644 index 0000000..9c7dfc0 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/machete_utils.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/machete_utils.cpython-310.pyc new file mode 100644 index 0000000..9ec857a Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/machete_utils.cpython-310.pyc 
differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils.cpython-310.pyc new file mode 100644 index 0000000..338b186 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_fp8.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_fp8.cpython-310.pyc new file mode 100644 index 0000000..cc9b2e8 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_fp8.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test.cpython-310.pyc new file mode 100644 index 0000000..f8cb1bc Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test_24.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test_24.cpython-310.pyc new file mode 100644 index 0000000..16f3a63 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test_24.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test_qqq.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test_qqq.cpython-310.pyc new file mode 100644 index 0000000..119fea0 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils_test_qqq.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/quant_utils.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/quant_utils.cpython-310.pyc new file mode 100644 index 0000000..076fe51 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/quant_utils.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/__pycache__/w8a8_utils.cpython-310.pyc b/vllm/model_executor/layers/quantization/utils/__pycache__/w8a8_utils.cpython-310.pyc new file mode 100644 index 0000000..14fcc99 Binary files /dev/null and b/vllm/model_executor/layers/quantization/utils/__pycache__/w8a8_utils.cpython-310.pyc differ diff --git a/vllm/model_executor/layers/quantization/utils/layer_utils.py b/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000..edce6d1 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,37 @@ +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if type(old) is type(new) and old.dtype 
== new.dtype and \ + old.untyped_storage().nbytes() == new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) diff --git a/vllm/model_executor/layers/quantization/utils/machete_utils.py b/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000..18e1332 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,30 @@ +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000..9a1defa --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,348 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_marlin_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." 
+ "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp 
= pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 0000000..8b3dfaa --- /dev/null +++ 
b/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,106 @@ +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + return current_platform.has_device_capability(80) + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, + strategy: str = "tensor") -> None: + print_warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( + layer.weight), + perm=torch.empty(0, + dtype=torch.int, + device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales(s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = (byte_tensor[:, 0].to(torch.int32) | + (byte_tensor[:, 1].to(torch.int32) << 8) | + (byte_tensor[:, 2].to(torch.int32) << 16) | + (byte_tensor[:, 3].to(torch.int32) << 24)) + + return packed.view(fp8_tensor.shape[0] // 4, + *fp8_tensor.shape[1:]).contiguous() diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000..4a06c5d --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -0,0 +1,163 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils 
import (GPTQ_MARLIN_TILE, marlin_permute_scales, + marlin_zero_points) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): + size_k, size_n = 
w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000..17d0905 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,463 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
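The conversion described here (and its inverse, defined further down in this file) is lossless for inputs that already satisfy the 2:4 pattern, so a small round trip makes a convenient sanity check. The sketch below is illustrative only: it assumes this file is importable at the path introduced by this diff, and it uses an fp16 CPU tensor whose shape meets the stated divisibility constraints (rows divisible by 32, columns by 16 for fp16 inputs).

```python
# Illustrative round-trip check for the 2:4 compression helpers in this file.
# The import path assumes the location this patch creates; runs on CPU.
import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass)

# fp16 input: rows % 32 == 0 and columns % 16 == 0 are required.
dense = torch.randn(32, 64, dtype=torch.half)
# Keep two of every four contiguous values along K to obtain a 2:4 pattern.
dense = dense * torch.tensor([1, 1, 0, 0], dtype=torch.half).repeat(32, 16)

sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
assert sparse.shape == (32, 32) and meta.shape == (32, 4)

# Lossless for inputs that already obey the 2:4 pattern.
restored = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(restored, dense)
```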
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
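# Worked instance of the table above (illustrative): for the quadruple
# [False, True, True, False] -> 0b1001, the lower two bits (0b01) point at the
# first True value (index 1) and the upper two bits (0b10) point at the second
# True value (index 2), so the compressed pair keeps the elements at indices
# 1 and 2 and the metadata nibble stores 0b1001.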
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + quant_type.size_bits, weight_perm) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000..cb58eb9 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000..833d000 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,451 @@ +"""This file is used for /tests and /benchmarks""" +from typing import List, Optional + +import numpy +import torch + +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] +} + + +def pack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_weights_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. 
All shards of fused layers " + "to have the same precision.") + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Reshape to [groupsize, -1] + if group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + maybe_w_zp = None + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and zero_points: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + + w_s = w_s.reshape((-1, size_n)).contiguous() + + if zero_points: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s.to(device=orig_device), + maybe_w_zp, + ) + + +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
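A minimal illustration of the two schemes handled by the helper defined below. This sketch is an assumption-laden example, not part of the patch: it presumes vLLM (and therefore the `MARLIN_QQQ_SUPPORTED_NUM_BITS` list this module imports) is available, that 4-bit QQQ is supported, and it runs on CPU with random data.

```python
# Illustrative sketch of qqq_quantize_weights (defined below); assumes the
# module path created by this patch and 4-bit QQQ support.
import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    qqq_quantize_weights)

w = torch.randn(256, 64, dtype=torch.half)

# Per-group (group_size < size_k): returns fused group scales plus
# per-channel int8 scales.
w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4,
                                                      group_size=128)
assert s_group.shape == (256 // 128, 64)
assert s_channel.shape == (1, 64)

# Per-channel (group_size == -1): group scales are empty; only channel
# scales are produced.
w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4,
                                                      group_size=-1)
assert s_group.numel() == 0
assert s_channel.shape == (1, 64)
```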
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / + s_channel).to(dtype=torch.half) + else: + max_q_val = 2**(num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= (2**(8 - num_bits)) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + 
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, size_n // pack_factor + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py new file mode 100644 index 0000000..411af92 --- /dev/null +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -0,0 +1,246 @@ +from typing import List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.utils import is_hip + +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None + + +def cutlass_fp8_supported() -> bool: + # cutlass is not supported on Rocm + if is_hip(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, + torch.Tensor]) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def convert_to_channelwise( + weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + # Create channelwise buffer + weight_scale_channel = torch.empty((sum(logical_widths), 1), + dtype=torch.float32, + device=weight_scale.device) + + # Expand each scale to match the size of each logical matrix. 
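pack_cols above packs pack_factor quantized values per int32 along the column dimension by shifting each value into its own num_bits-wide slot, and unpack_cols masks them back out. A quick round-trip check using those helpers (this assumes get_pack_factor(num_bits) == 32 // num_bits, which is defined elsewhere in the tree):

import torch

# pack_cols / unpack_cols round trip for 4-bit values.
num_bits = 4
size_k, size_n = 4, 16             # size_n must be divisible by 32 // num_bits

q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)
packed = pack_cols(q_w, num_bits, size_k, size_n)       # [size_k, size_n // 8]
restored = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(restored, q_w)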
+ start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_scale_channel[start:end, :] = weight_scale[idx] + start = end + + return weight_scale_channel + + +def requantize_with_max_scale( + weight: torch.Tensor, weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requanitzation. + max_w_scale = weight_scale.max() + + # QKV / MLP is fused in the on disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. Skip requantization in this case (since) + # we already are quantized with the single scale. + # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = (weight_scale[-1] > torch.finfo( + torch.float8_e4m3fn).min) + + # If unfused checkpoint, need requanize with the single scale. + if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(weight[start:end, :], + weight_scale[idx]) + weight[start:end, :], _ = ops.scaled_fp8_quant( + weight_dq, max_w_scale) + start = end + + return max_w_scale, weight + + +def apply_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_fp8_supported: bool = True, + use_per_token_if_dynamic: bool = False, +) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant. + # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if cutlass_fp8_supported: + qinput, x_scale = ops.scaled_fp8_quant( + input, + input_scale, + scale_ub=input_scale_ub, + use_per_token_if_dynamic=use_per_token_if_dynamic) + + # Fused GEMM_DQ + return ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + else: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + qinput, x_scale = ops.scaled_fp8_quant( + input, + input_scale, + num_token_padding=17, + use_per_token_if_dynamic=use_per_token_if_dynamic) + + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + return torch.narrow(output[0], 0, 0, input.shape[0]) + return torch.narrow(output, 0, 0, input.shape[0]) + + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. 
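requantize_with_max_scale above collapses the per-shard scales of a fused QKV/MLP weight into one scale: each shard is dequantized with its own scale and requantized with the maximum. A minimal emulation of that idea in plain PyTorch (integer rounding stands in for the fp8 kernel, so ops.scaled_fp8_quant is not needed):

import torch

# Re-express shards quantized with different scales under a single max scale.
logical_widths = [4, 4]
w = torch.randn(sum(logical_widths), 8)

# Pretend each shard was quantized independently.
scales = torch.stack([w[0:4].abs().max() / 127, w[4:8].abs().max() / 127])
q = torch.cat([torch.round(w[0:4] / scales[0]), torch.round(w[4:8] / scales[1])])

max_w_scale = scales.max()
start = 0
for idx, logical_width in enumerate(logical_widths):
    end = start + logical_width
    shard_dq = q[start:end] * scales[idx]               # per-tensor dequantize
    q[start:end] = torch.round(shard_dq / max_w_scale)  # requantize with max scale
    start = end
# All shards now share max_w_scale, so the fused weight needs only one scale.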
+ # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if (TORCH_DEVICE_IDENTITY is not None + and TORCH_DEVICE_IDENTITY.device != weight.device): + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype) + + +def apply_int8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_zero_point: Optional[torch.Tensor] = None, + azp_adj: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +): + # ops.scaled_int8_quant supports both dynamic and static quant. + # * dynamic, layer.input_scale is None and x_scale computed from x. + # * static, layer.input_scale is scalar and x_scale is input_scale. + symmetric = azp_adj is None + x_q, x_scale, x_zp = ops.scaled_int8_quant(input, + input_scale, + input_zero_point, + symmetric=symmetric) + + if x_zp is not None: + return ops.cutlass_scaled_mm_azp(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + azp_adj=azp_adj, + azp=x_zp, + bias=bias) + return ops.cutlass_scaled_mm(x_q, + weight, + scale_a=x_scale, + scale_b=weight_scale, + out_dtype=input.dtype, + bias=bias) + + +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. 
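The unfused fallback above works because the scales factor out of the GEMM: (s_x * X_q) @ (W_q * s_w) + b equals (X_q @ W_q) * s_x * s_w.T + b for per-token s_x and per-channel s_w, so the kernel can run with identity scales and the scaling can be applied afterwards. A small check of that identity with plain float tensors:

import torch

# Dequantize-then-GEMM vs GEMM-then-scale, as used by the _scaled_mm fallback.
M, K, N = 4, 8, 6
x_q = torch.randint(-128, 128, (M, K)).float()
w_q = torch.randint(-128, 128, (K, N)).float()
s_x = torch.rand(M, 1)             # per-token activation scale
s_w = torch.rand(N, 1)             # per-output-channel weight scale
bias = torch.randn(N)

dequant_first = (x_q * s_x) @ (w_q * s_w.t()) + bias
gemm_first = (x_q @ w_q) * s_x * s_w.t() + bias
assert torch.allclose(dequant_first, gemm_first, rtol=1e-4, atol=1e-2)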
+ # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale diff --git a/vllm/model_executor/layers/quantization/w8a16.py b/vllm/model_executor/layers/quantization/w8a16.py new file mode 100644 index 0000000..6c42ce7 --- /dev/null +++ b/vllm/model_executor/layers/quantization/w8a16.py @@ -0,0 +1,114 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.parameter import (GroupQuantScaleParameter, + PackedvLLMParameter) +from vllm.model_executor.utils import set_weight_attrs + + +class W8a16Config(QuantizationConfig): + """Config class for W8a16. + + """ + + def __init__( + self, + ) -> None: + pass + + def __repr__(self) -> str: + return ("W8a16Config") + + def get_name(self) -> str: + return "w8a16" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + def get_min_capability(self) -> int: + return 75 + + @staticmethod + def get_config_filenames(): + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8a16Config": + return cls() + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["W8a16LinearMethod"]: + if isinstance(layer, LinearBase): + return W8a16LinearMethod(self) + return None + + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8a16LinearMethod(LinearMethodBase): + """Linear method for w8a16. + + """ + + def __init__(self, quant_config: W8a16Config): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + output_size_per_partition = sum(output_partition_sizes) + weight = Parameter( + torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + set_weight_attrs( + weight, { + "input_dim": 1, + "output_dim": 0, + }) + + scales = Parameter( + torch.empty( + 1, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": None, + "output_dim": 1, + }) + + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = layer.weight + scales = layer.scales + out_shape = (x.shape[:-1] + (qweight.shape[-2],)) + reshaped_x = x.reshape(-1, x.shape[-1]) + out = ops.linear_w8a16(reshaped_x, qweight, scales, format="TN") + if bias is not None: + out = out + bias + return out.reshape(out_shape) \ No newline at end of file diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 0000000..1137b51 --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,401 @@ +from functools import cached_property +from importlib.util import find_spec +from typing import Dict, List, Optional, Tuple + +import torch +import 
torch.jit + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeStochasticBaseSampler) + +logger = init_logger(__name__) + +if find_spec("flashinfer"): + """ + Consider utilizing the FlashInfer rejection sampling kernel initially, + as it employs a dedicated kernel rather than relying on + Torch tensor operations. This design choice helps to fuse operations, + reduce memory I/O, and consequently enhances performance. + """ + from flashinfer.sampling import chain_speculative_sampling +else: + chain_speculative_sampling = None + + +class RejectionSampler(SpecDecodeStochasticBaseSampler): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, + strict_mode: bool = False, + use_flashinfer: Optional[bool] = None): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + use_falshinfer: We will use this parameter to determine whether + to use the FlashInfer rejection sampling kernel or not. If it's + None, we will use the default value from the environment variable. + This parameter is only used for testing purposes. + """ + super().__init__(strict_mode=strict_mode) + if use_flashinfer is None: + self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( + chain_speculative_sampling is not None) + else: + self.use_flashinfer = use_flashinfer + + if self.use_flashinfer: + logger.info("Use flashinfer for rejection sampling.") + else: + logger.info("Use pytorch for rejection sampling.") + + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + seeded_seqs: Optional[Dict[int, torch.Generator]] = None, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_with_bonus_probs: The probability distribution + over token ids given context according to the target model. + shape = [batch_size, num_speculative_tokens + 1, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + seeded_seqs: Dict of batch row index to torch generator, for + sequences using seeded generation. + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. 
+ if self._strict_mode: + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + + batch_size, k, _ = draft_probs.shape + + # batch_size = 0 when all requests in the batch are + # non_spec requests. In this case, output_token_ids is + # just an empty tensor. + if batch_size == 0: + return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) + + # If use Flashinfer chain_speculative_sampling kernel + # for rejection sampling + if self.use_flashinfer: + batch_size, k, _ = draft_probs.shape + uniform_samples = self._create_uniform_samples( + seeded_seqs, batch_size, k, draft_probs.device) + output_token_ids, accepted_token_num, emitted_token_num \ + = chain_speculative_sampling( + draft_probs, draft_token_ids, uniform_samples, + target_with_bonus_probs) + + # num_emitted_tokens returned by flashinfer + # does not include the bonus token + # Flashinfer stops at the first token that violates + # the condition p >= q and does not include recovery/bonus token. + # Therefore, we need to add batch_size here. + self.num_accepted_tokens += accepted_token_num.sum() + self.num_emitted_tokens += emitted_token_num.sum() + batch_size + self.num_draft_tokens += batch_size * k + else: + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_with_bonus_probs[:, :-1], + draft_probs, + draft_token_ids, + seeded_seqs, + )) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + seeded_seqs: Optional[Dict[int, torch.Generator]], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. + shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. + shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids, seeded_seqs) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + # NOTE: the recovered_probs are overwritten by this method. + recovered_token_ids = _multinomial( + recovered_probs, + num_samples=1, + k=k, + seeded_seqs=seeded_seqs or {}, + ).reshape(batch_size, k) + + return accepted, recovered_token_ids + + def _create_uniform_samples(self, + seeded_seqs: Optional[Dict[int, + torch.Generator]], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. + + This method creates a tensor of shape `(batch_size, k + 1)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. If `None`, all samples are + generated without a seed. 
+ batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k + 1)` containing uniform + random values in the range [0, 1). + """ + if not seeded_seqs: + return torch.rand(batch_size, k + 1, device=device) + + uniform_rand = torch.empty(batch_size, k + 1, device=device) + + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand(1, + k + 1, + dtype=self.probs_dtype, + device=device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), + k + 1, + dtype=self.probs_dtype, + device=device) + return uniform_rand + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + seeded_seqs: Optional[Dict[int, torch.Generator]], + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of + :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according + to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + same conditional probability according to the draft model, the token + is accepted with probability: + + .. math:: + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. + """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indicies = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indicies, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indicies, + draft_token_ids] + + uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, + k - 1, target_probs.device) + + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of + :math:`x` given context :math:`x_1, \dots, x_n` according to the target + model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability + according to the draft model: + + .. 
math:: + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + + where :math:`(f(x))_+` is defined as: + + .. math:: + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + + See https://github.com/vllm-project/vllm/pull/2336 for a visualization + of the draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. + + Note: This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +# @torch.jit.script +def _multinomial( + probs: torch.Tensor, + num_samples: int, + k: int, + seeded_seqs: Dict[int, torch.Generator], +) -> torch.Tensor: + + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs) + if not seeded_seqs: + q.exponential_(1.0) + else: + non_seeded_indices: List[int] = [] + start = 0 + for idx in range(len(q) // k): + end = start + k + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.extend(list(range(start, end))) + else: + q[start:end].exponential_(1.0, generator=generator) + start = end + q[non_seeded_indices].exponential_(1.0) + + return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py new file mode 100644 index 0000000..1a70d0f --- /dev/null +++ b/vllm/model_executor/layers/resampler.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +# +# Copyright 2023 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
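The math in RejectionSampler._get_accepted and _get_recovered_probs above is the standard speculative-sampling step: accept a draft token with probability min(1, target_p / draft_p) and, on rejection, resample from the positive part of (target_probs - draft_probs), renormalized. A toy single-token sketch (plain PyTorch, no batching; the Exp(1) argmax trick from _multinomial is replaced by torch.multinomial for clarity):

import torch

# One speculative-sampling step for a single draft token.
torch.manual_seed(0)
vocab_size = 6
draft_probs = torch.softmax(torch.randn(vocab_size), dim=-1)
target_probs = torch.softmax(torch.randn(vocab_size), dim=-1)

draft_token = torch.multinomial(draft_probs, 1).item()
accept_prob = torch.clamp(
    target_probs[draft_token] / draft_probs[draft_token], max=1.0)

if torch.rand(()) < accept_prob:
    token = draft_token                          # draft token is kept
else:
    recovered = torch.clamp(target_probs - draft_probs, min=0)
    recovered = recovered / recovered.sum()      # (q - p)_+ renormalized
    token = torch.multinomial(recovered, 1).item()
print(token)

Iterating this over the k draft positions, stopping at the first rejection and appending the bonus token when everything is accepted, reproduces the behaviour of forward() above.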
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared resampler perceiver network used in multimodal models and +related helpers for sincos positional embeddings. + +Example models: Qwen (Qwen-VL), Minicpmv2.0 +""" +import math +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import normal_ + +from vllm.model_executor.layers.linear import ReplicatedLinear + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, + int]) -> torch.Tensor: + # abs_pos: L, C + # tgt_size: (H, W) + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + dtype = abs_pos.dtype + if isinstance(tgt_size, int): + tgt_size = (tgt_size, tgt_size) + if (src_size == tgt_size[0] and src_size == tgt_size[1]): + return abs_pos + return (F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + 
else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and \ + grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + ) -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + normal_(self.query, std=0.02) + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = nn.Parameter( + (embed_dim**-0.5) * + torch.randn(embed_dim, embed_dim)) if do_post_projection else None + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + """Resampler-perceiver network to be used for a variety of model types, + e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the + do_post_projection arg, which indicates whether or not there should be + a post layer normalization and projector after the attention. This is + present in minicpmv2.0, but not qwen-vl. 
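The sincos helpers above follow the standard fixed positional-embedding recipe: frequencies omega_j = 1 / 10000**(2j / D), with the embedding [sin(pos * omega), cos(pos * omega)] computed per axis and concatenated for the 2D grid. A compact rendering of the 1D case in the version (2, 0) layout:

import numpy as np

# 1D sincos positional embedding, matching get_1d_sincos_pos_embed_from_grid
# for version (2, 0).
embed_dim = 8
pos = np.arange(4, dtype=np.float32)                   # positions 0..3

omega = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0)
omega = 1.0 / 10000**omega                             # (D/2,)

out = np.einsum("m,d->md", pos, omega)                 # (M, D/2)
emb = np.concatenate([np.sin(out), np.cos(out)], axis=1)   # (M, D)
print(emb.shape)                                       # (4, 8)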
+ """ + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + ) -> None: + super().__init__(grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection) + + self.adaptive = adaptive + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).requires_grad_(False)) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if tgt_sizes is None: + tgt_sizes = int(math.sqrt(x.size(1))) + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) + else: + pos_embed = get_abs_pos(self.pos_embed, + tgt_sizes).to(device=x.device, + dtype=x.dtype) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + if self.do_post_projection: + x = self.ln_post(x) + x = x @ self.proj + return x diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py new file mode 100644 index 0000000..847e834 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -0,0 +1,1006 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
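A minimal usage sketch of Resampler2 above (toy sizes, random weights, kv_dim left at its default so no input projection is needed): grid_size**2 learnable queries cross-attend over the image features and return a fixed-length sequence per image.

import torch

# Toy forward pass: 2 images, 16 feature tokens each, embed_dim 32.
resampler = Resampler2(grid_size=4, embed_dim=32, num_heads=4)
image_feats = torch.randn(2, 16, 32)
out = resampler(image_feats)
print(out.shape)            # torch.Size([2, 16, 32]): 4**2 queries per image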
+"""Rotary Positional Embeddings.""" +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.model_executor.custom_op import CustomOp + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
+ inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. 
In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factors: Union[List[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: List[float] = scaling_factors # noqa + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
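LinearScalingRotaryEmbedding above stacks one cos/sin cache per scaling factor along dim 0 and records where each one starts; a position plus its factor's offset then selects the right rows. A small sketch of the offset bookkeeping:

# Offsets into the concatenated cache, mirroring _compute_cos_sin_cache above.
max_position_embeddings = 2048
scaling_factors = [1.0, 2.0, 4.0]

offsets, cache_rows = [], 0
for factor in scaling_factors:
    offsets.append(cache_rows)        # this factor's sub-cache starts here
    cache_rows += int(max_position_embeddings * factor)

print(dict(zip(scaling_factors, offsets)))   # {1.0: 0, 2.0: 2048, 4.0: 6144}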
+ max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ + rotary_dim != head_size ({rotary_dim}!={head_size}).") + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
+ ) + + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(dtype) + self.register_buffer("short_cos_sin_cache", + short_cache, + persistent=False) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(dtype) + self.register_buffer("long_cos_sin_cache", + long_cache, + persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.head_size, 2, dtype=torch.float) / self.head_size))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + idx = (torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None else positions) + self.long_short_cos_sin_cache: torch.Tensor = ( + self.long_short_cos_sin_cache.to(idx.device)) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
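YaRNScalingRotaryEmbedding and DeepseekScalingRotaryEmbedding above build inv_freq the same way: frequencies of fast-rotating dims are kept (extrapolation), slow dims are divided by the scaling factor (interpolation), and a linear ramp over the correction range blends the two. A compact sketch reusing the module-level helpers defined above (extrapolation_factor = 1 and the default beta_fast / beta_slow):

import torch

# YaRN frequency blend, mirroring _compute_inv_freq in the two classes above.
base, rotary_dim = 10000.0, 64
max_position_embeddings, scaling_factor = 2048, 4.0

pos_freqs = base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = _yarn_find_correction_range(32, 1, rotary_dim, base,
                                        max_position_embeddings)
inv_freq_mask = 1 - _yarn_linear_ramp_mask(low, high, rotary_dim // 2,
                                           dtype=torch.float)
inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) +
            inv_freq_extrapolation * inv_freq_mask)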
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). 
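A small standalone sketch of the wavelength rule in `Llama3RotaryEmbedding._compute_inv_freq` above, with Llama-3.1-style settings assumed purely for illustration:

import math

orig_max_position = 8192
low_freq_factor, high_freq_factor, scaling_factor = 1.0, 4.0, 8.0
low_freq_wavelen = orig_max_position / low_freq_factor    # 8192
high_freq_wavelen = orig_max_position / high_freq_factor  # 2048

for inv_freq in (1e-1, 1e-3, 1e-4):
    wave_len = 2 * math.pi / inv_freq
    if wave_len < high_freq_wavelen:
        new = inv_freq                   # short wavelength: keep as-is
    elif wave_len > low_freq_wavelen:
        new = inv_freq / scaling_factor  # long wavelength: fully rescaled
    else:
        # In-between band: linear blend, exactly as in the code above.
        smooth = (orig_max_position / wave_len - low_freq_factor) / (
            high_freq_factor - low_freq_factor)
        new = (1 - smooth) * inv_freq / scaling_factor + smooth * inv_freq
    print(f"{wave_len:10.1f} -> {new:.6f}")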
+ + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @staticmethod + def get_input_positions( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + ) -> Tuple[List[List[int]], int]: + """Get mrope input positions and delta value.""" + + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, 
llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> List[List[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = (head_size, rotary_dim, max_position, base, is_neox_style, + rope_scaling_args, dtype) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) + else: + scaling_type = rope_scaling[ + "type"] if "type" in rope_scaling else rope_scaling["rope_type"] + scaling_type = "mrope" if scaling_type == "default" else scaling_type + # The correct one should be "longrope" but keep "su" here + # for backward compatible + if scaling_type not in {"su", "longrope"}: + scaling_factor = rope_scaling.get("factor", 1.0) + if scaling_type == "llama3": + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "linear": + rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype) + elif scaling_type == "dynamic": + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor, dtype) + elif scaling_type == "yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, dtype, + **extra_kwargs) 
+ elif scaling_type == "deepseek_yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow", "mscale", "mscale_all_dim") + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) + # The correct one should be "longrope" but keep "su" here + # for backward compatible + elif scaling_type == "su" or scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) + elif scaling_type == "mrope": + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py new file mode 100644 index 0000000..42a6a0e --- /dev/null +++ b/vllm/model_executor/layers/sampler.py @@ -0,0 +1,1317 @@ +"""A layer that samples the next tokens from the model's outputs.""" +import itertools +import warnings +from dataclasses import dataclass +from importlib.util import find_spec +from math import inf +from typing import Dict, List, Optional, Tuple, Union + +import msgspec +import torch +import torch.nn as nn + +import vllm.envs as envs +from vllm.model_executor.sampling_metadata import (SamplingMetadata, + SamplingTensors, + SequenceGroupToSample) +from vllm.sampling_params import SamplingType +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SampleLogprobs, SequenceOutput) +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + +if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): + import flashinfer.sampling + # yapf: disable + from flashinfer.sampling import ( + top_k_top_p_sampling_from_probs as flashinfer_top_k_top_p_sampling) + + # yapf: enable +else: + flashinfer_top_k_top_p_sampling = None + +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = List[Tuple[List[int], List[int]]] + +# Types of temporary data structures used for +# computing sample_result +SampleMetadataType = Dict[SamplingType, Tuple[List[int], + List[SequenceGroupToSample]]] +MultinomialSamplesType = Dict[SamplingType, torch.Tensor] +SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] + + +# Encapsulates temporary data structures for computing +# sample_result. +# +# * For multi-step scheduling: must be returned +# by `Sampler.forward()` and used later to compute the pythonized +# sample_result +# +# * For single-step scheduling: consumed immediately +# inside `Sampler.forward()` to compute pythonized sample_result. 
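For orientation, a toy value of `SampleResultType` as defined above, one `(next_token_ids, parent_ids)` pair per sequence group (numbers are made up):

from typing import List, Tuple

sample_results: List[Tuple[List[int], List[int]]] = [
    ([42], [0]),        # greedy/random: one next token, sampled from parent seq 0
    ([7, 99], [0, 1]),  # beam search: two candidates, each with its parent index
    ([], []),           # prompt-only group (do_sample=False): nothing sampled
]
for next_token_ids, parent_ids in sample_results:
    print(next_token_ids, parent_ids)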
+@dataclass +class SampleResultArgsType: + sample_metadata: SampleMetadataType + multinomial_samples: MultinomialSamplesType + sample_results_dict: SampleResultsDictType + sampling_metadata: SamplingMetadata + greedy_samples: Optional[torch.Tensor] + beam_search_logprobs: Optional[torch.Tensor] + + +# Union of non-deferred (single-step scheduling) +# vs deferred (multi-step scheduling) +# sample result types +MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] + +# Abbreviation of the _sample() return type +SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] + + +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """For each sequence group, we generate a list of SequenceOutput object, + each of which contains one possible candidate for the next token. + + This data structure implements methods, so it can be used like a list, but + also has optional fields for device tensors. + """ + + outputs: List[CompletionSequenceGroupOutput] + + # On-device tensor containing probabilities of each token. + sampled_token_probs: Optional[torch.Tensor] = None + + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + + # Holds either (1) the pythonized sampler result (single-step scheduling) + # or (2) what will be arguments for later deferred pythonization of the + # sampler result (muliti-step scheduling) + deferred_sample_results_args: Optional[SampleResultArgsType] = None + + # On-device tensor containing the sampled token ids. + sampled_token_ids: Optional[torch.Tensor] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. + model_execute_time: Optional[float] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + + +class Sampler(nn.Module): + """Samples the next tokens from the model's outputs. 
+ + This layer does the following: + 1. Discard the hidden states that are not used for sampling (i.e., all + tokens except the final one in each prompt). + 2. Compute the logits for the next tokens. + 3. Apply presence, frequency and repetition penalties. + 4. Apply temperature scaling. + 5. Apply top-p and top-k truncation. + 6. Sample the next tokens. + Here, each sequence group within the batch can have different sampling + parameters (e.g., sampling method, temperature, top-p, top-k, etc.). + + The structure of the logits tensor is coupled with the seq_groups in + sampling_metadata. Typically, each sequence in each seq_group has one row in + logits for the next token to be sampled; however, for a seq_group with a + prompt request with the prompt_logprobs sampling parameter, there are rows + in logits for each token in the input prompt. + """ + + def __init__(self): + super().__init__() + + # Whether or not the SamplerOutput should have on-device tensors + # containing the sampled token ids and probabilities. This is used by + # speculative decoding. + self.include_gpu_probs_tensor = False + self.should_modify_greedy_probs_inplace = False + + def _init_sampling_tensors( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ): + """The goal here is to reuse sampling tensors between similar decode + runs. This is possible because sampling logic does not change between + decodes of the same sequences. + """ + _, vocab_size = logits.shape + + # First free any existing stored sampling tensors. + # This is necessary because some sampling tensors may + # have pinned memory. + self._sampling_tensors = None + + # Initialize new sampling tensors + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + self._sampling_tensors = sampling_tensors + self._do_penalties = do_penalties + self._do_top_p_top_k = do_top_p_top_k + self._do_min_p = do_min_p + + def forward( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + """ + Single-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor + + Multi-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the :class:`SamplerOutput` structure + + Args: + logits: (num_tokens, vocab_size). + sampling_metadata: Metadata for sampling. + """ + assert logits is not None + _, vocab_size = logits.shape + + # Prepare sampling tensors with pinned memory to avoid blocking. + if not sampling_metadata.reuse_sampling_tensors: + self._init_sampling_tensors(logits, sampling_metadata) + elif self._do_penalties: + # In this case, the sampling tensors logic depends on + # "output_tokens" of a sequence. As a result, we cannot + # reuse sampling tensors, since "output_tokens" changes + # between decode runs. + self._init_sampling_tensors(logits, sampling_metadata) + + assert self._sampling_tensors is not None + sampling_tensors = self._sampling_tensors + do_penalties = self._do_penalties + do_top_p_top_k = self._do_top_p_top_k + do_min_p = self._do_min_p + + logits = _apply_min_tokens_penalty(logits, sampling_metadata) + + # Apply presence and frequency penalties. 
+ if do_penalties: + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + # Use float32 to apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits = logits.to(torch.float) + logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1)) + + if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + if do_min_p: + logits = _apply_min_p(logits, sampling_tensors.min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=self.include_gpu_probs_tensor, + modify_greedy_probs=self._should_modify_greedy_probs_inplace, + ) + + if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs + assert maybe_sampled_tokens_tensor is not None + on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) + else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. + on_device_tensors = None + + # Get the logprobs query results. + prompt_logprobs = None + sample_logprobs = None + if not sampling_metadata.skip_sampler_cpu_output: + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) + + return _build_sampler_output( + maybe_deferred_sample_results, + sampling_metadata, + prompt_logprobs, + sample_logprobs, + on_device_tensors=on_device_tensors, + skip_sampler_cpu_output=sampling_metadata.skip_sampler_cpu_output) + + @property + def _should_modify_greedy_probs_inplace(self) -> bool: + """Whether or not the sampler should modify the probability distribution + of greedily-sampled tokens such that multinomial sampling would sample + the greedily-sampled token. + + In other words, if True then we set the probability of the greedily- + sampled token to 1. + + This is used by speculative decoding, which requires that the sampling + method be encoded into the probability distribution. + """ + return self.should_modify_greedy_probs_inplace + + +def _get_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. 
+ bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def _apply_min_tokens_penalty( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + """Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + have not been generated yet + """ + # list of indices in logits that will be set to -inf + logits_to_penalize: List[Tuple[int, int]] = [] + logits_applied = 0 + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + + sample_indices = seq_group.sample_indices + logits_applied += len(sample_indices) + len( + seq_group.prompt_logprob_indices) + if not seq_group.do_sample: + continue + + start_idx = sample_indices[0] + min_tokens = sampling_params.min_tokens + token_ids_to_penalize = sampling_params.all_stop_token_ids + if min_tokens > 0 and token_ids_to_penalize: + seqs_to_penalize: List[int] = [] + for j, seq_id in enumerate(seq_ids): + seq_data = seq_group.seq_data[seq_id] + if len(seq_data.output_token_ids_array) < min_tokens: + seqs_to_penalize.append(j) + + if seqs_to_penalize: + # convert to the index into logits + seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] + # itertools.product pairs each seq index with every token id + logits_to_penalize.extend( + itertools.product(seqs_to_penalize, token_ids_to_penalize)) + + if logits_to_penalize: + # use zip and * to group indices along each dimension + # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) + logits[tuple(zip(*logits_to_penalize))] = -float("inf") + + # verifies that no rows in logits were missed unexpectedly + assert logits_applied == logits.shape[0] + return logits + + +def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + num_seqs, vocab_size = logits.shape + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) + output_bin_counts, output_mask = _get_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) + repetition_penalties[~(prompt_mask | output_mask)] = 1.0 + logits = torch.where(logits > 0, logits / repetition_penalties, + logits * repetition_penalties) + + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + return logits + + +def _apply_top_k_top_p( + logits: torch.Tensor, + p: torch.Tensor, + k: torch.Tensor, +) -> torch.Tensor: + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + # Apply top-p. 
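A self-contained sketch of the penalty math in `_apply_penalties` above, reduced to one sequence and a four-token vocabulary; the prompt mask is omitted for brevity and all numbers are toy values:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 0.0]])
output_counts = torch.tensor([[3, 0, 1, 0]])  # how often each token was generated
seen = output_counts > 0
repetition, frequency, presence = 1.2, 0.1, 0.2

# Repetition penalty only touches tokens that already appeared.
penalized = torch.where(logits > 0, logits / repetition, logits * repetition)
penalized = torch.where(seen, penalized, logits)
# OpenAI-style frequency and presence penalties.
penalized = penalized - frequency * output_counts - presence * seen.float()
print(penalized)  # tensor([[ 1.1667,  0.5000, -1.5000,  0.0000]])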
+ probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. + logits = torch.empty_like(logits_sort).scatter_(dim=-1, + index=logits_idx, + src=logits_sort) + return logits + + +def _apply_min_p( + logits: torch.Tensor, + min_p: torch.Tensor, +) -> torch.Tensor: + """ + Adapted from + https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 + """ + probs = torch.softmax(logits, dim=-1) + top_probs, _ = probs.max(dim=-1, keepdim=True) + scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs + tokens_to_remove = probs < scaled_min_p + logits = logits.masked_fill_(tokens_to_remove, -float("inf")) + + return logits + + +def _greedy_sample( + selected_seq_groups: List[SequenceGroupToSample], + samples: torch.Tensor, +) -> SampleResultType: + """Run greedy sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + samples: (num_selected_samples,) A tensor of samples. The length of + samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ + samples_lst = samples.tolist() + sample_idx = 0 + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + num_parent_seqs = len(seq_ids) + assert num_parent_seqs == 1, ( + "Greedy sampling should have only one seq.") + parent_ids = list(range(num_parent_seqs)) + next_token_ids = [samples_lst[sample_idx]] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +def _random_sample( + selected_seq_groups: List[SequenceGroupToSample], + random_samples: torch.Tensor, +) -> SampleResultType: + """Run random sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + random_samples: (num_selected_samples,) A tensor of samples. The + length of samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ + # Find the maximum n value of the prompt phase requests. + random_samples = random_samples.cpu() + sample_idx = 0 + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + num_parent_seqs = len(seq_ids) + if is_prompt: + # Prompt phase. + parent_ids = [0] * sampling_params.n + next_token_ids = random_samples[ + sample_idx, :sampling_params.n].tolist() + else: + # Generation phase. 
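A single-row walk-through of the sort-based top-k / top-p masking above, with toy logits and k=2, p=0.9:

import torch

logits = torch.tensor([[1.0, 3.0, 2.0, 0.0]])
k, p = 2, 0.9

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# Top-k: everything below the k-th largest value is removed.
kth_value = logits_sort[:, -k].unsqueeze(1)
logits_sort = logits_sort.masked_fill(logits_sort < kth_value, -float("inf"))
# Top-p: drop the low-probability prefix of the sorted row.
probs_sort = logits_sort.softmax(dim=-1)
top_p_mask = probs_sort.cumsum(dim=-1) <= 1 - p
top_p_mask[:, -1] = False  # always keep at least one token
logits_sort = logits_sort.masked_fill(top_p_mask, -float("inf"))
# Scatter the masked values back to the original token order.
out = torch.empty_like(logits_sort).scatter_(-1, logits_idx, logits_sort)
print(out)  # tensor([[-inf, 3., 2., -inf]])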
+ parent_ids = list(range(num_parent_seqs)) + next_token_ids = random_samples[sample_idx:sample_idx + + num_parent_seqs, 0].tolist() + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +def _beam_search_sample( + selected_seq_groups: List[SequenceGroupToSample], + logprobs: torch.Tensor, +) -> SampleResultType: + """Run beam sampling on a given samples. + + Args: + selected_seq_groups: A list of sequence groups batched. + logprobs: (num_selected_samples, vocab_size,) A tensor of logprob + on selected sample indices. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + """ + # We sample 2 * beam_width candidates to make sure that with high + # probability we can get `beam_width` candidates in addition to + # the finished sequences for the next iteration. See + # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 + # for details. See also HF reference: + # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 + # + # NOTE: Beam search is not vectorized, so its speed can be slower than + # other sampling methods. + sample_idx = 0 + results: SampleResultType = [] + for seq_group in selected_seq_groups: + if not seq_group.do_sample: + results.append(([], [])) + continue + + is_prompt = seq_group.is_prompt + seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params + num_parent_seqs = len(seq_ids) + beam_width = sampling_params.n + seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] + if is_prompt: + # Prompt phase. + assert num_parent_seqs == 1, ( + "Prompt input should have only one seq.") + parent_ids = [0] * (2 * beam_width) + _, next_token_ids = torch.topk(seq_group_logprobs[0], + 2 * beam_width) + next_token_ids = next_token_ids.tolist() + else: + # Generation phase. + cumulative_logprobs: List[float] = [ + seq_group.seq_data[seq_id].cumulative_logprob + for seq_id in seq_ids + ] + cumulative_logprobs_tensor = torch.tensor( + cumulative_logprobs, + dtype=torch.float, + device=seq_group_logprobs.device) + seq_group_logprobs = (seq_group_logprobs + + cumulative_logprobs_tensor.unsqueeze(dim=1)) + _, topk_ids = torch.topk(seq_group_logprobs.flatten(), + 2 * beam_width) + topk_ids = topk_ids.tolist() + vocab_size = seq_group_logprobs.size(-1) + parent_ids = [i // vocab_size for i in topk_ids] + next_token_ids = [i % vocab_size for i in topk_ids] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + assert sample_idx == logprobs.size(0) + return results + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. 
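The comment above refers to the exponential-race trick used by `_multinomial` below: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax draws index `i` with probability `p_i`. A quick empirical check with a toy distribution:

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])
n = 200_000

# argmax(p_i / E_i) with E_i ~ Exp(1) is an exact multinomial draw,
# so torch.multinomial (and its device sync) is not needed.
q = torch.empty(n, probs.numel()).exponential_()
samples = (probs / q).argmax(dim=1)
freq = torch.bincount(samples, minlength=probs.numel()).float() / n
print(freq)  # approximately tensor([0.1, 0.6, 0.3])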
+def _multinomial( + probs: torch.Tensor, + num_samples: int, + seq_groups: Optional[List[SequenceGroupToSample]] = None, +) -> torch.Tensor: + if num_samples > 1: + probs = probs.repeat_interleave(num_samples, dim=0) + q = torch.empty_like(probs) + if seq_groups is None: + q.exponential_() + else: + sample_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids + stride = len(seq_ids) * num_samples + assert seq_group.generator is not None + q[sample_idx:sample_idx + + stride].exponential_(generator=seq_group.generator) + sample_idx += stride + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + +def _top_k_top_p_multinomial_with_flashinfer( + probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, + num_samples: int, seq_groups: Optional[List[SequenceGroupToSample]]): + max_top_k_round = 32 + if num_samples > 1: + probs = probs.repeat_interleave(num_samples, dim=0) + top_ks = top_ks.repeat_interleave(num_samples) + top_ps = top_ps.repeat_interleave(num_samples) + batch_size = probs.shape[0] + uniform_samples = torch.empty((max_top_k_round, batch_size), + device=probs.device) + if seq_groups is None: + uniform_samples.uniform_() + else: + sample_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group.seq_ids + stride = len(seq_ids) * num_samples + assert seq_group.generator is not None + uniform_samples[:, sample_idx:sample_idx + + stride].uniform_(generator=seq_group.generator) + sample_idx += stride + batch_next_token_ids, success = flashinfer_top_k_top_p_sampling( + probs, + uniform_samples, + top_ks, + top_ps, + ) + if not success.all(): + warnings.warn("FlashInfer rejection sampling failed, fallback.", + stacklevel=1) + probs = flashinfer.sampling.top_k_renorm_prob(probs, top_ks) + probs = flashinfer.sampling.top_p_renorm_prob(probs, top_ps) + batch_next_token_ids = flashinfer.sampling.sampling_from_probs( + probs, uniform_samples[0]) + return batch_next_token_ids.view(-1, num_samples) + + +def get_pythonized_sample_results( + sample_result_args: SampleResultArgsType) -> SampleResultType: + '''This function consumes GPU-side sampler results and computes + Pythonized CPU-side sampler results (GPU -> CPU sync.) + + Single-step scheduling: this function is invoked at sampling-time + for immediate Pythonization. + + Multi-step scheduling: Pythonization is deferred until after multiple + GPU-side steps have been completed. 
+ + Args: + sample_result_args: GPU-side inputs to the Pythonization process + + Returns: + Pythonized sampler results + ''' + + ( + sample_metadata, + sampling_metadata, + greedy_samples, + multinomial_samples, + beam_search_logprobs, + sample_results_dict, + ) = ( + sample_result_args.sample_metadata, + sample_result_args.sampling_metadata, + sample_result_args.greedy_samples, + sample_result_args.multinomial_samples, + sample_result_args.beam_search_logprobs, + sample_result_args.sample_results_dict, + ) + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, + multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + return [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + + +def _sample_with_torch( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. + + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + sampling_params = seq_group.sampling_params + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + beam_search_logprobs: Optional[torch.Tensor] = None + + # Create output tensor for sampled token ids. + if include_gpu_probs_tensor: + sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1), + VLLM_INVALID_TOKEN_ID, + dtype=torch.long, + device=logprobs.device) + else: + sampled_token_ids_tensor = None + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. + for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + + seq_group_id = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] + sample_metadata[sampling_type] = (seq_group_id, seq_groups) + long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: + greedy_samples = torch.argmax(logprobs[long_sample_indices], + dim=-1) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. 
+ sampled_token_ids_tensor[ + long_sample_indices] = greedy_samples.unsqueeze(-1) + + if modify_greedy_probs: + # If required, modify the probabilities such that sampling from + # the modified distribution would always sample the argmax + # token id. + _modify_greedy_probs_inplace(logprobs, probs, + long_sample_indices, + greedy_samples) + + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + max_n_in_batch = 1 + for seq_group in seq_groups: + if seq_group.is_prompt: + sampling_params = seq_group.sampling_params + max_n_in_batch = max(max_n_in_batch, sampling_params.n) + seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else + seq_groups) + + if flashinfer_top_k_top_p_sampling is not None: + multinomial_samples[ + sampling_type] = _top_k_top_p_multinomial_with_flashinfer( + probs[long_sample_indices], + sampling_tensors.top_ks[long_sample_indices], + sampling_tensors.top_ps[long_sample_indices], + max_n_in_batch, + seq_groups_arg, + ) + else: + multinomial_samples[sampling_type] = _multinomial( + probs[long_sample_indices], + max_n_in_batch, + seq_groups=seq_groups_arg) + + if sampled_token_ids_tensor is not None: + # Store sampled tokens in output tensor. + sampled_token_ids_tensor[long_sample_indices] = \ + multinomial_samples[sampling_type].to(torch.long) + + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + beam_search_logprobs=beam_search_logprobs, + sample_results_dict=sample_results_dict) + + if not sampling_metadata.skip_sampler_cpu_output: + # GPU<->CPU sync happens here. + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor + else: + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sampling_tensors: SamplingTensors, + include_gpu_probs_tensor: bool, + modify_greedy_probs: bool, +) -> SampleReturnType: + """ + Args: + probs: (num_query_tokens_in_batch, num_vocab) + logprobs: (num_query_tokens_in_batch, num_vocab) + sampling_metadata: The metadata for a batch for sampling. + sampling_tensors: Tensors that include sampling related metadata. + + Returns: + (next_token_ids, parent_seq_ids) for each seq group in a batch. + If sampling is skipped, it returns ([], []) + sampled_token_ids_tensor: A tensor of sampled token ids. + """ + return _sample_with_torch( + probs, + logprobs, + sampling_metadata, + sampling_tensors, + include_gpu_probs_tensor=include_gpu_probs_tensor, + modify_greedy_probs=modify_greedy_probs, + ) + + +def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + """ + This function calculates the ranks of the chosen tokens in a logprob tensor. + + Args: + x (torch.Tensor): 2D logprob tensor of shape (N, M) + where N is the no. of tokens and M is the vocab dim. + indices (torch.Tensor): List of chosen token indices. 
+ + Returns: + torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens. + Each element in the returned tensor represents the rank + of the chosen token in the input logprob tensor. + """ + vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), + indices] + result = (x > vals[:, None]) + del vals + return result.sum(1).add_(1) + + +def get_logprobs( + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sample_results: SampleResultType, +) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: + """Return sample logprobs and prompt logprobs. + + The logic consists of 3 parts. + - Select indices to compute logprob from, ranks of token ids, and + the top k token ids from logprobs. + - Compute prompt logprobs if required. + - Compute sample logprobs if required. + + Args: + logprobs: (num_query_tokens_across_batch, num_vocab). Each query token's + logprob per vocab. Sequence groups' query tokens are batched in a + single flattened tensor. For example, assuming there are N + seq groups, it is sorted by prefill tokens for seq_group_1 (if + prompt logprob is enabled), decode tokens for seq_group_1 (if + sampling is required), prefill tokens for seq_group_2, ... + sampling_metadata: The sampling metadata. + sample_results: (num_seq_groups) The tuple of (next_token_ids, + parent_ids) for each sequence group. When beam search is enabled, + sample_results can contain different number of seq_ids from + sampling_metadata.seq_groups. It is because beam search creates + 2 * BEAM_WIDTH number of samples (whereas there are only up to + BEAM_WIDTH number of seq_ids). + + Returns: + A tuple of prompt and sample logprobs per sequence group in a batch. + """ + # The index of query token to calculate logprobs. It includes both + # prompt and sample logprob indices. + query_indices: List[int] = [] + # The next token ids to get the logprob value from. + next_token_ids: List[int] = [] + # The largest requested number of logprobs. We find logprobs as many as the + # largest num logprobs in this API. If every logprobs is None, it will be + # set to -1. + largest_num_logprobs = -1 + + # Select indices to compute logprob from, ranks of token ids, and the top + # k token ids from logprobs. + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + sample_results): + sampling_params = seq_group.sampling_params + + # Update indices and tokens for prompt logprobs. + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.prompt_logprobs) + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + query_indices.extend(seq_group.prompt_logprob_indices) + next_token_ids.extend(next_prompt_tokens) + + # Update indices and next tokenes for sample logprob. + if seq_group.do_sample: + token_ids, parent_seq_ids = sample_result + # NOTE: We cannot directly use sample_indices because + # sample_indices only contain parent seq_ids of a previous step. + # The current step may have different number of seq_ids, and + # we can obtain it from `sample_result[1]`. 
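A tiny check of the rank definition used by `_get_ranks` above, where rank 1 means the chosen token had the highest logprob in its row:

import torch

logprobs = torch.tensor([[0.1, 0.7, 0.2],
                         [0.9, 0.05, 0.05]]).log()
chosen = torch.tensor([2, 0])  # token 2 in row 0, token 0 in row 1

# rank = 1 + number of vocab entries with a strictly larger logprob.
vals = logprobs[torch.arange(len(logprobs)), chosen]
ranks = (logprobs > vals[:, None]).sum(1) + 1
print(ranks)  # tensor([2, 1])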
+ query_idx = seq_group.sample_indices[0] + query_indices.extend( + [query_idx + parent_id for parent_id in parent_seq_ids]) + next_token_ids.extend(token_ids) + + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.logprobs) + + assert len(next_token_ids) == len(query_indices) + + if len(query_indices) == 0: + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None + return [empty_prompt_logprob], [empty_sampled_logprob] + + selected_logprobs, ranks = None, None + top_logprobs, top_token_ids = None, None + + # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can + # skip the whole logprob calculation. + if largest_num_logprobs >= 0: + query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) + next_token_ids_gpu = torch.tensor(next_token_ids, + device=logprobs.device) + + # (num_selected_query_tokens, num_logprobs). Note that query_indices can + # contain duplicates if beam search is enabled. + selected_logprobs = logprobs[[ + query_indices_gpu, + next_token_ids_gpu, + ]] + ranks = _get_ranks( + logprobs[query_indices_gpu], + next_token_ids_gpu, + ) + assert selected_logprobs.shape[0] == ranks.shape[0] + + # We need to compute top k only if there exists logprobs > 0. + if largest_num_logprobs > 0: + # Logprobs of topk tokens for a batch of sequence groups. + # (num_query_tokens_across_batch). + top_logprobs, top_token_ids = torch.topk(logprobs, + largest_num_logprobs, + dim=-1) + top_logprobs = top_logprobs.to('cpu') + top_token_ids = top_token_ids.to('cpu') + + selected_logprobs = selected_logprobs.to('cpu') + ranks = ranks.to('cpu') + + # Find prompt/sample logprobs. + prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = [] + sample_logprobs_per_seq_group: List[SampleLogprobs] = [] + top_logprob_idx = 0 + selected_logprobs_idx = 0 + + for seq_group, sample_result in zip(sampling_metadata.seq_groups, + sample_results): + (prompt_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_prompt_logprob_if_needed( + seq_group, selected_logprobs, ranks, top_token_ids, top_logprobs, + selected_logprobs_idx, top_logprob_idx) + prompt_logprobs_per_seq_group.append(prompt_logprobs) + + (sampled_logprobs, top_logprob_idx, + selected_logprobs_idx) = _get_sampled_logprob_if_needed( + seq_group, sample_result, selected_logprobs, ranks, top_token_ids, + top_logprobs, selected_logprobs_idx, top_logprob_idx) + sample_logprobs_per_seq_group.append(sampled_logprobs) + + return prompt_logprobs_per_seq_group, sample_logprobs_per_seq_group + + +def _get_prompt_logprob_if_needed( + seq_group: SequenceGroupToSample, + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the prompt logprob from a sequence group if needed.""" + sampling_params = seq_group.sampling_params + is_prompt = seq_group.is_prompt + + # Find prompt logprobs + prompt_logprobs: Optional[PromptLogprobs] = None + if is_prompt and sampling_params.prompt_logprobs is not None: + prompt_logprobs = [] + num_logprobs = sampling_params.prompt_logprobs + next_prompt_tokens = _get_next_prompt_tokens(seq_group) + # Pre-select indexes and create a list. It is faster than calling .item + # repetitively. 
+ selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_prompt_tokens)].tolist() + + for idx, token_id in enumerate(next_prompt_tokens): + # Calculate the prompt logprob of the real prompt tokens. + # {token_id: (logprob, rank_from_vocab)} + prompt_logprobs_dict: Dict[int, Tuple[float, int]] = { + token_id: (selected_logprob_items[idx], rank_items[idx]) + } + + # Add top K prompt logprobs along with its rank. + if num_logprobs > 0: + top_ids = top_token_ids[ + top_logprob_idx, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + prompt_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip(top_ids, top_probs, + top_ranks) + }) + prompt_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in prompt_logprobs_dict.items() + }) + # + 1 to go to the next prompt token. + top_logprob_idx += 1 + + # + len(next_prompt_tokens) to go to the next prompt. + selected_logprobs_idx += len(next_prompt_tokens) + return prompt_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _get_sampled_logprob_if_needed( + seq_group: SequenceGroupToSample, + sample_result: Tuple[List[int], List[int]], + selected_logprobs: torch.Tensor, + ranks: torch.Tensor, + top_token_ids: torch.Tensor, + top_logprobs: torch.Tensor, + selected_logprobs_idx: int, + top_logprob_idx: int, +): + """Compute the sample logprob if needed.""" + seq_ids = seq_group.seq_ids + num_logprobs = seq_group.sampling_params.logprobs + sampled_logprobs: SampleLogprobs = [] + next_token_ids, parent_seq_ids = sample_result + + if seq_group.do_sample: + assert len(next_token_ids) > 0 + if num_logprobs is None: + for next_token_id in next_token_ids: + # Use a dummy logprob + sampled_logprobs.append({next_token_id: Logprob(inf)}) + else: + # Pre-select items from tensor. tolist() is faster than repetitive + # `.item()` calls. + selected_logprob_items = selected_logprobs[ + selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx + + len(next_token_ids)].tolist() + for idx, (next_token_id, parent_id) in enumerate( + zip(next_token_ids, parent_seq_ids)): + # Get the logprob of a sampled token. + sampled_logprobs_dict = { + next_token_id: + (selected_logprob_items[idx], rank_items[idx]) + } + if num_logprobs is not None and num_logprobs > 0: + # Get top K logprobs. + top_ids = top_token_ids[top_logprob_idx + + parent_id, :num_logprobs].tolist() + top_probs = top_logprobs[ + top_logprob_idx + parent_id, :num_logprobs].tolist() + # Top K is already sorted by rank, so we can use 1 ~ + # num_logprobs + 1 for rank. + top_ranks = range(1, num_logprobs + 1) + sampled_logprobs_dict.update({ + top_id: (top_prob, rank) + for top_id, top_prob, rank in zip( + top_ids, top_probs, top_ranks) + }) + + sampled_logprobs.append({ + token_id: Logprob(*logprob_and_rank) + for token_id, logprob_and_rank in + sampled_logprobs_dict.items() + }) + + # NOTE: This part of code is not intuitive. `selected_logprobs` include + # logprobs for the current step, which has len(next_token_ids) tokens + # per sequence group. 
`logprobs` includes logprobs from the previous + # steps, which has len(seq_ids) tokens per sequence group. + + # Iterate to the next sequence group in a batch. + selected_logprobs_idx += len(next_token_ids) + # Iterate to the next sequence group in a batch. + top_logprob_idx += len(seq_ids) + return sampled_logprobs, top_logprob_idx, selected_logprobs_idx + + +def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, + sample_indices: torch.Tensor, + greedy_samples: torch.Tensor) -> None: + """Modify the probability distributions of the greedily-sampled tokens such + that each sampled token has a "probability" of 1.0. This is required by + speculative decoding, which depends on the sampling method being encoded + within the probability distribution for correctness. + + # Why do we only need to do this for greedy sampling? + + vLLM's sampler performs the following steps for greedy or multinomial + (random) sampling: + 1. Get logits from model. + 2. Modify logits according to per-sequence sampling parameters. + - Multiply by temperature, top-k and top-p masking, penalize tokens + according to their frequency, etc. + 3. Sample a token. + - Random sampling simply samples from the modified probability + distribution. + - Greedy sampling performs `argmax` to obtain the token with the + highest likelihood. + + Ignoring greedy sampling for a moment, we find that the computed probability + distribution has the following property: we can sample from it independently + and find that the token sampled by the Sampler has a frequency corresponding + to how often we see it in our sampling. In other words, for tokens sampled + with vLLM's random SamplingType, the computed probability distribution + encodes the sampling methodology completely. + + Greedy sampling does not normally have this property. vLLM modifies logits + according to sampling params, then performs `argmax`, then returns the + sampled token and the computed probability distribution. If we sample from + the distribution, we'll find the likelihood of the greedily-sampled token + is not always 1.0. + + Since lossless speculative decoding requires that the sampling methodology + be encoded within the probability distribution, we are motivated to modify + the probability distribution such that the sampled token has probability 1 + when speculative decoding is used. + + NOTE: Alternatively, we could use an extremely low temperature to achieve + greedy sampling using multinomial computation and unite the codepaths. This + has implications on the overall design of the sampler, e.g. how to record + accurate logprobs for the user, so this improvement is deferred to later. + """ + # NOTE: logprobs are not modified so they can be returned to the user. + probs[sample_indices, :] = 0 + probs[sample_indices, greedy_samples] = 1.0 + + +def _build_sampler_output( + maybe_deferred_sample_results: MaybeDeferredSampleResultType, + sampling_metadata: SamplingMetadata, + prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], + sample_logprobs: Optional[List[SampleLogprobs]], + on_device_tensors: Optional[Tuple[torch.Tensor, torch.Tensor, + torch.Tensor]], + skip_sampler_cpu_output: bool = False, +) -> SamplerOutput: + """Construct Python objects with the output of sampling. + + Args: + on_device_tensors: Tuple containing on-device tensors with the + probabilities used in sampling and the sampled token ids. This + allows post-processing without copies to CPU/serialization, e.g. in + speculative decoding rejection sampling. 
+ """ + sampler_output: List[CompletionSequenceGroupOutput] = [] + + if skip_sampler_cpu_output: + assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) + deferred_sample_results_args = maybe_deferred_sample_results + else: + assert prompt_logprobs is not None + assert sample_logprobs is not None + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + deferred_sample_results_args = None + + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + maybe_deferred_sample_results, + prompt_logprobs, sample_logprobs): + seq_ids = seq_group.seq_ids + next_token_ids, parent_ids = sample_result + seq_outputs: List[SequenceOutput] = [] + for parent_id, next_token_id, logprobs in zip( + parent_ids, next_token_ids, group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + logprobs)) + sampler_output.append( + CompletionSequenceGroupOutput(seq_outputs, + group_prompt_logprobs)) + + # If not specified, store None values in SamplerOutput. + if on_device_tensors is not None: + (sampled_token_probs, logprobs_tensor, + sampled_token_ids) = on_device_tensors + else: + sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, + None) + + return SamplerOutput( + outputs=sampler_output, + sampled_token_probs=sampled_token_probs, + sampled_token_ids=sampled_token_ids, + logprobs=logprobs_tensor, + deferred_sample_results_args=deferred_sample_results_args) + + +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: + """Get a list of next prompt tokens to compute logprob from a + given sequence group. + + It is used to compute prompt logprob. Imagine you have logprob for each + query token. Query token needs to know the next prompt token id to compute + prompt logprob. This is a helper to obtain next prompt token ids. + + This API has to be used only when the caller knows seq_group is in prefill + stage. + + Returns: + A list of next prompt tokens to compute logprob. + """ + assert seq_group.is_prompt, ( + "Caller should ensure the sequence group is in a prefill stage.") + seq_ids = seq_group.seq_ids + query_len = seq_group.query_len + assert query_len is not None + # prompt has only 1 seq id. + assert len(seq_ids) == 1 + seq_data = seq_group.seq_data[seq_ids[0]] + computed_len = seq_data.get_num_computed_tokens() + prompt_tokens = seq_data.prompt_token_ids + # +1 because we are looking for a next prompt token. + next_token_index_start = computed_len + 1 + next_token_index_end = min(computed_len + query_len + 1, + len(prompt_tokens)) + next_prompt_tokens = prompt_tokens[ + next_token_index_start:next_token_index_end] + return next_prompt_tokens diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py new file mode 100644 index 0000000..7e750a7 --- /dev/null +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -0,0 +1,239 @@ +from abc import abstractmethod +from typing import Dict, Optional, Union + +import torch +import torch.jit +import torch.nn as nn + + +class SpecDecodeBaseSampler(nn.Module): + """Base class for samplers used for Speculative Decoding verification + step. + """ + + def __init__(self, strict_mode: bool = False): + """Base class constructor. + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. 
+ """ + super().__init__() + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, device: Union[int, str]) -> None: + assert self.num_accepted_tokens is None + if isinstance(device, int): + device = f"cuda:{device}" + elif not isinstance(device, str): + raise ValueError(f"Device must be int or str, get {type(device)}") + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + @property + def probs_dtype(self): + return torch.float32 + + @property + def token_id_dtype(self): + return torch.int64 + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + substitute_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. When + a token is rejected via sampling, all subsequent token ids are + set to -1 for the sequence. + + Args: + accepted: A boolean tensor indicating if the corresponding + draft token in draft_token_ids should be accepted or not. + substitute_token_ids: A tensor of token_ids that can be used + as substitutes for the draft token ids if the proposed token + is rejected. + draft_token_ids: A tensor of token ids speculated by the + draft model. + bonus_token_ids: Token ids to use as the bonus token if + all the draft tokens are accepted. + Returns: + A tensor containing the accepted token ids. The shape of the + tensor is [batch_size, k + num_bonus_tokens] + """ + batch_size, k = substitute_token_ids.shape + bonus_token_ids = bonus_token_ids.squeeze() + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. 
+ output.mul_(~after_false_mask).add_( + substitute_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_input( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + self._raise_if_incorrect_shape(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_incorrect_dtype(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_inconsistent_device(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], + draft_token_ids, bonus_token_ids) + + def _raise_if_incorrect_shape( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_with_bonus_probs.shape + + # Does not count the extra token + num_target_probs -= 1 + + # validate the shape of draft token ids. + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + assert draft_token_ids_batch_size == target_batch_size + assert num_draft_token_ids == num_target_probs + + # validate the shape of bonus token ids + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + # validate the shape of draft probs if it is set + if draft_probs is not None: + (draft_batch_size, num_draft_probs, + draft_vocab_size) = draft_probs.shape + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + def _raise_if_incorrect_dtype( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + assert target_with_bonus_probs.dtype == self.probs_dtype + assert draft_token_ids.dtype == self.token_id_dtype + assert bonus_token_ids.dtype == self.token_id_dtype + if draft_probs is not None: + assert draft_probs.dtype == self.probs_dtype + + def _raise_if_inconsistent_device( + self, + target_with_bonus_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: Optional[torch.Tensor] = None, + ) -> None: + devices = [ + t.device for t in [ + target_with_bonus_probs, bonus_token_ids, draft_probs, + draft_token_ids + ] if t is not None + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + draft_token_ids: torch.Tensor, + bonus_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + + +class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): + """Base class for samplers used for Speculative Decoding verification + step which are deterministic. 
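The accept/substitute/bonus masking in _create_output above, replayed standalone on a toy batch. All numbers are invented; the real method writes into a preallocated tensor, and this sketch only mirrors the resulting values:

# Toy re-run of the masking in _create_output (batch_size=2, k=3).
import torch

accepted   = torch.tensor([[True, True, True],    # row 0: all draft tokens accepted
                           [True, False, True]])  # row 1: rejected at position 1
draft      = torch.tensor([[10, 11, 12], [20, 21, 22]])
substitute = torch.tensor([[50, 51, 52], [60, 61, 62]])
bonus      = torch.tensor([7, 8])
k = draft.shape[1]

limits = (accepted == 0).max(1).indices            # first rejected position per row
limits[~(accepted == 0).any(1)] = k                # no rejection -> limit = k
idx = torch.arange(k).unsqueeze(0)
accepted_mask = idx < limits.unsqueeze(1)
after_false_mask = idx == limits.unsqueeze(1)

out = torch.where(accepted_mask, draft, torch.full_like(draft, -1))
out = out * ~after_false_mask + substitute * after_false_mask
bonus_col = torch.where(out[:, -1] != -1, bonus, torch.full_like(bonus, -1))
print(torch.cat([out, bonus_col.unsqueeze(1)], dim=1))
# tensor([[10, 11, 12,  7],    row 0: all accepted -> bonus token emitted
#         [20, 61, -1, -1]])   row 1: substitute at the rejection, -1 afterwards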
+ """ + + @abstractmethod + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): + """Base class for samplers used for Speculative Decoding verification + step which are stochastic + """ + + @abstractmethod + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + seeded_seqs: Optional[Dict[int, torch.Generator]] = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py new file mode 100644 index 0000000..584cf97 --- /dev/null +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -0,0 +1,170 @@ +import torch +import torch.jit + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeDeterministicBaseSampler) + + +class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler): + """Apply typical acceptance sampling as described in section 3.3.1 in + "MEDUSA: Simple LLM Inference Acceleration Framework with + Multiple Decoding Heads" + https://arxiv.org/pdf/2401.10774 + """ + + def __init__( + self, + posterior_threshold: float, + posterior_alpha: float, + strict_mode: bool = False, + ): + """Create a Typical Acceptance Sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + posterior_threshold : A threshold value that sets a lower bound + on the posterior probability of a token in target model for it + to be accepted. + posterior_alpha : A scaling factor for the entropy-based + threshold in typical acceptance sampling. + """ + self._posterior_threshold = posterior_threshold + self._posterior_alpha = posterior_alpha + super().__init__(strict_mode=strict_mode) + + def forward( + self, + target_with_bonus_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using typical acceptance sampling. This accepts + or rejects tokens proposed by the draft model using the probability + of each token according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one token will be emitted. + + In the case where all draft tokens are accepted, the bonus token will be + accepted. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: This parameter is unused by the acceptance sampler. + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. 
+ if self._strict_mode: + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids) + target_probs = target_with_bonus_probs[:, :-1] + accepted = self._evaluate_accepted_tokens(target_probs, + draft_token_ids) + recovered_token_ids = self._get_recovered_token_ids(target_probs) + output_token_ids = self._create_output(accepted, recovered_token_ids, + draft_token_ids, + bonus_token_ids) + return output_token_ids + + def _evaluate_accepted_tokens(self, target_probs, draft_token_ids): + r""" + Evaluates and returns a mask of accepted tokens based on the + posterior probabilities. + + Parameters: + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) representing + the probabilities of each token in the vocabulary for each + position in the proposed sequence. This is the distribution + generated by the target model. + draft_token_ids : torch.Tensor + A tensor of shape (batch_size, k) representing the proposed + token ids. + + A draft token_id x_{n+k} is accepted if it satisfies the + following condition + + .. math:: + p_{\text{original}}(x_{n+k} | x_1, x_2, \dots, x_{n+k-1}) > + \min \left( \epsilon, \delta * \exp \left( + -H(p_{\text{original}}( + \cdot | x_1, x_2, \ldots, x_{n+k-1})) \right) \right) + + where :math:`p_{\text{original}}` corresponds to target_probs + and :math:`\epsilon` and :math:`\delta` correspond to hyperparameters + specified using self._posterior_threshold and self._posterior_alpha + + This method computes the posterior probabilities for the given + draft token ids based on the provided target probabilities. It + calculates the entropy of the posterior distribution and determines + a dynamic threshold for each token position using the provided + posterior_threshold and posterior_alpha values. The method then + returns a boolean mask indicating which tokens can be accepted. + + Returns: + ------- + torch.Tensor + A boolean tensor of shape (batch_size, k) where each element + indicates whether the corresponding draft token has been accepted + or rejected. True indicates acceptance and false indicates + rejection. + + """ + device = target_probs.device + candidates_prob = torch.gather( + target_probs, dim=-1, + index=draft_token_ids.unsqueeze(-1)).squeeze(-1) + # A small constant added to prevent computing the logarithm of zero, + # which can lead to undefined values. + epsilon = 1e-5 + posterior_entropy = -torch.sum( + target_probs * torch.log(target_probs + epsilon), dim=-1) + threshold = torch.minimum( + torch.ones_like(posterior_entropy, device=device) * + self._posterior_threshold, + torch.exp(-posterior_entropy) * self._posterior_alpha, + ) + accepted_mask = candidates_prob > threshold + return accepted_mask + + def _get_recovered_token_ids(self, target_probs): + """ + The recovered token ids will fill the first unmatched token + by the target token. + + Parameters + ---------- + target_probs : torch.Tensor + A tensor of shape (batch_size, k, vocab_size) containing + the target probability distribution + + Returns + ------- + torch.Tensor + A tensor of shape (batch_size, k) with the recovered token + ids which are selected from target probs. 
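The acceptance rule above, evaluated once by hand; a standalone sketch with made-up probabilities and hyperparameters:

# Toy illustration of typical acceptance for a single position.
import torch

posterior_threshold = 0.09   # epsilon in the formula above
posterior_alpha = 0.3        # delta in the formula above

target_probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])   # [k=1 position, vocab=4]
draft_token_ids = torch.tensor([1])                   # draft proposed token id 1

candidate_prob = target_probs[0, draft_token_ids[0]]  # p_target(draft token) = 0.3
entropy = -(target_probs * torch.log(target_probs + 1e-5)).sum(-1)
threshold = torch.minimum(
    torch.tensor(posterior_threshold),
    posterior_alpha * torch.exp(-entropy))
accepted = candidate_prob > threshold
print(entropy.item(), threshold.item(), accepted.item())
# entropy ~1.17, threshold = min(0.09, 0.3*exp(-1.17)) ~= 0.09 -> accepted (0.3 > 0.09)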
+ """ + max_indices = torch.argmax(target_probs, dim=-1) + + return max_indices diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000..dff7a63 --- /dev/null +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -0,0 +1,471 @@ +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn.functional as F +# import ixformer.inference.functions as IXF +import ixformer.functions as IXF +from torch.nn.parameter import Parameter, UninitializedParameter + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.parameter import BasevLLMParameter +from vllm.model_executor.utils import set_weight_attrs + +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +class UnquantizedEmbeddingMethod(QuantizeMethodBase): + """Unquantized method for embeddings.""" + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """Create weights for embedding layer.""" + weight = Parameter(torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return IXF.linear(x, layer.weight, bias) + + def embedding(self, layer: torch.nn.Module, + input_: torch.Tensor) -> torch.Tensor: + return F.embedding(input_, layer.weight) + + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, + rank: int, + offset: int = 0) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f + offset, index_l + offset + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, + rank: int, + world_size: int, + offset: int = 0) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank, + offset=offset) + + +@dataclass +class VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return self.added_vocab_end_index - self.added_vocab_start_index + + @property + def num_org_elements_padded(self) -> int: + return (self.padded_org_vocab_end_index - + self.padded_org_vocab_start_index) + + @property + 
def num_added_elements_padded(self) -> int: + return (self.padded_added_vocab_end_index - + self.padded_added_vocab_start_index) + + @property + def num_org_vocab_padding(self) -> int: + return self.num_org_elements_padded - self.num_org_elements + + @property + def num_added_vocab_padding(self) -> int: + return self.num_added_elements_padded - self.num_added_elements + + @property + def num_elements_padded(self) -> int: + return self.num_org_elements_padded + self.num_added_elements_padded + + def __post_init__(self): + # sanity checks + assert (self.padded_org_vocab_start_index <= + self.padded_org_vocab_end_index) + assert (self.padded_added_vocab_start_index <= + self.padded_added_vocab_end_index) + + assert self.org_vocab_start_index <= self.org_vocab_end_index + assert self.added_vocab_start_index <= self.added_vocab_end_index + + assert self.org_vocab_start_index <= self.padded_org_vocab_start_index + assert (self.added_vocab_start_index <= + self.padded_added_vocab_start_index) + assert self.org_vocab_end_index <= self.padded_org_vocab_end_index + assert self.added_vocab_end_index <= self.padded_added_vocab_end_index + + assert self.num_org_elements <= self.num_org_elements_padded + assert self.num_added_elements <= self.num_added_elements_padded + + +# @torch.jit.script +def get_masked_input_and_mask( + input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + # torch.jit.script will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < + org_vocab_end_index) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + Adapted from torch.nn.Embedding, note that we pad the vocabulary size to + make sure it is divisible by the number of model parallel GPUs. + + In order to support various loading methods, we ensure that LoRA-added + embeddings are always at the end of TP-sharded tensors. In other words, + we shard base embeddings and LoRA embeddings separately (both padded), + and place them in the same tensor. + In this example, we will have the original vocab size = 1010, + added vocab size = 16 and padding to 64. Therefore, the total + vocab size with padding will be 1088 (because we first pad 1010 to + 1024, add 16, and then pad to 1088). + Therefore, the tensor format looks like the following: + TP1, rank 0 (no sharding): + |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >| + corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 | + + TP2, rank 0: + |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| + corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... 
| 511 | 512 | ... | 527 | 520 | ... | 543 | + TP2, rank 1: + |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| + corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + quant_config: quant config for the layer + prefix: full name of the layer in the state dict + """ # noqa: E501 + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + # Keep the input dimensions. + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_embeddings = num_embeddings + self.padding_size = padding_size + self.org_vocab_size = org_num_embeddings or num_embeddings + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + self.embedding_dim = embedding_dim + + linear_method = None + if quant_config is not None: + linear_method = quant_config.get_quant_method(self, prefix=prefix) + if linear_method is None: + linear_method = UnquantizedEmbeddingMethod() + + # If we are making an embedding layer, then our quantization linear + # method must implement the embedding operation. If we are another + # layer type like ParallelLMHead, this is not important. + is_embedding_layer = type(self.__class__) is VocabParallelEmbedding + linear_method_implements_embedding = method_has_implemented_embedding( + type(linear_method)) + if is_embedding_layer and not linear_method_implements_embedding: + raise NotImplementedError( + f"The class {type(linear_method).__name__} must implement " + "the 'embedding' method, see UnquantizedEmbeddingMethod.") + + self.linear_method: QuantizeMethodBase = linear_method + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + # Divide the weight matrix along the vocaburaly dimension. 
+ self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) + + self.linear_method.create_weights(self, + self.embedding_dim, + [self.num_embeddings_per_partition], + self.embedding_dim, + self.num_embeddings_padded, + params_dtype=params_dtype, + weight_loader=self.weight_loader) + + @classmethod + def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, + vocab_size: int, org_vocab_size: int, tp_rank: int, + tp_size: int) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = ( + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, + tp_size)) + padded_added_vocab_start_index, padded_added_vocab_end_index = ( + vocab_range_from_global_vocab_size(num_added_embeddings_padded, + tp_rank, + tp_size, + offset=org_vocab_size)) + # remove padding + org_vocab_start_index = min(padded_org_vocab_start_index, + org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, + vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, padded_org_vocab_end_index, + padded_added_vocab_start_index, padded_added_vocab_end_index, + org_vocab_start_index, org_vocab_end_index, + added_vocab_start_index, added_vocab_end_index) + + def get_sharded_to_full_mapping(self) -> Optional[List[int]]: + """Get a mapping that can be used to reindex the gathered + logits for sampling. + + During sampling, we gather logits from all ranks. The relationship + of index->token_id will follow the same format as outlined in the class + docstring. However, after the gather, we want to reindex the final + logits tensor to map index->token_id one-to-one (the index is always + equal the token_id it corresponds to). The indices returned by this + method allow us to do that. 
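The docstring example (1010 base tokens, 16 LoRA-added tokens, padding 64, tensor parallel size 2) can be re-derived with the padding helpers defined earlier in this file; the sketch below re-implements pad_vocab_size locally so it is self-contained:

# Self-contained re-derivation of the docstring example.
def pad_vocab_size(vocab_size, pad_to=64):
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

org_vocab_size, added, tp_size = 1010, 16, 2
org_vocab_size_padded = pad_vocab_size(org_vocab_size)                 # 1024
num_embeddings_padded = pad_vocab_size(org_vocab_size_padded + added)  # 1088
per_rank = num_embeddings_padded // tp_size                            # 544 rows per GPU

for rank in range(tp_size):
    base_lo = rank * org_vocab_size_padded // tp_size                  # padded base shard
    base_hi = (rank + 1) * org_vocab_size_padded // tp_size
    print(rank, (base_lo, base_hi), per_rank)
# 0 (0, 512) 544
# 1 (512, 1024) 544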
+ """ + if self.tp_size < 2: + return None + + base_embeddings: List[int] = [] + added_embeddings: List[int] = [] + padding: List[int] = [] + for tp_rank in range(self.tp_size): + shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + range_start = self.num_embeddings_per_partition * tp_rank + range_end = self.num_embeddings_per_partition * (tp_rank + 1) + base_embeddings.extend( + range(range_start, + range_start + shard_indices.num_org_elements)) + padding.extend( + range(range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded)) + added_embeddings.extend( + range( + range_start + shard_indices.num_org_elements_padded, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements)) + padding.extend( + range( + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded)) + assert (range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded == range_end) + ret = base_embeddings + added_embeddings + padding + assert len(ret) == self.num_embeddings_padded + return ret + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + output_dim = getattr(param, "output_dim", None) + packed_dim = getattr(param, "packed_dim", None) + + # If the parameter is a gguf weight, then load it directly. + if getattr(param, "is_gguf_weight_type", None): + param.data.copy_(loaded_weight) + param.weight_type = loaded_weight.item() + return + elif isinstance(param, UninitializedParameter): + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = shape[output_dim] // self.tp_size + param.materialize(tuple(shape), dtype=loaded_weight.dtype) + + # If parameter does not have output dim, then it should + # be copied onto all gpus (e.g. g_idx for act_order gptq). + if output_dim is None: + assert param.data.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + return + + # Shard indexes for loading the weight + start_idx = self.shard_indices.org_vocab_start_index + shard_size = self.shard_indices.org_vocab_end_index - start_idx + + # If param packed on the same dim we are sharding on, then + # need to adjust offsets of loaded weight by pack_factor. + if packed_dim is not None and packed_dim == output_dim: + packed_factor = param.packed_factor if isinstance( + param, BasevLLMParameter) else param.pack_factor + assert loaded_weight.shape[output_dim] == (self.org_vocab_size // + param.packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor + else: + assert loaded_weight.shape[output_dim] == self.org_vocab_size + + # Copy the data. + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + masked_input, input_mask = get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) + else: + masked_input = input_ + # Get the embeddings. 
+ output_parallel = self.linear_method.embedding(self, + masked_input.long()) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0) + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output + + def extra_repr(self) -> str: + s = f"num_embeddings={self.num_embeddings_per_partition}" + s += f", embedding_dim={self.embedding_dim}" + s += f", org_vocab_size={self.org_vocab_size}" + s += f', num_embeddings_padded={self.num_embeddings_padded}' + s += f', tp_size={self.tp_size}' + return s + + +class ParallelLMHead(VocabParallelEmbedding): + """Parallelized LM head. + + Output logits weight matrices used in the Sampler. The weight and bias + tensors are padded to make sure they are divisible by the number of + model parallel GPUs. + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + bias: whether to use bias. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size, quant_config, + prefix) + self.quant_config = quant_config + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def tie_weights(self, embed_tokens: VocabParallelEmbedding): + """Tie the weights with word embeddings.""" + # GGUF quantized embed_tokens. 
+ if self.quant_config and self.quant_config.get_name() == "gguf": + return embed_tokens + else: + self.weight = embed_tokens.weight + return self + + def forward(self, input_): + del input_ + raise RuntimeError("LMHead's weights should be used in the sampler.") diff --git a/vllm/model_executor/model_loader/__init__.py b/vllm/model_executor/model_loader/__init__.py new file mode 100644 index 0000000..d1ec171 --- /dev/null +++ b/vllm/model_executor/model_loader/__init__.py @@ -0,0 +1,30 @@ +from typing import Optional + +from torch import nn + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.model_executor.model_loader.loader import (BaseModelLoader, + get_model_loader) +from vllm.model_executor.model_loader.utils import ( + get_architecture_class_name, get_model_architecture) + + +def get_model(*, model_config: ModelConfig, load_config: LoadConfig, + device_config: DeviceConfig, parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig) -> nn.Module: + loader = get_model_loader(load_config) + return loader.load_model(model_config=model_config, + device_config=device_config, + lora_config=lora_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + cache_config=cache_config) + + +__all__ = [ + "get_model", "get_model_loader", "BaseModelLoader", + "get_architecture_class_name", "get_model_architecture" +] diff --git a/vllm/model_executor/model_loader/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..3a5a9b4 Binary files /dev/null and b/vllm/model_executor/model_loader/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/__pycache__/loader.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/loader.cpython-310.pyc new file mode 100644 index 0000000..82e3897 Binary files /dev/null and b/vllm/model_executor/model_loader/__pycache__/loader.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/__pycache__/neuron.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/neuron.cpython-310.pyc new file mode 100644 index 0000000..c6314a9 Binary files /dev/null and b/vllm/model_executor/model_loader/__pycache__/neuron.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/__pycache__/openvino.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/openvino.cpython-310.pyc new file mode 100644 index 0000000..d2e597f Binary files /dev/null and b/vllm/model_executor/model_loader/__pycache__/openvino.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-310.pyc new file mode 100644 index 0000000..3bd15bb Binary files /dev/null and b/vllm/model_executor/model_loader/__pycache__/tensorizer.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/__pycache__/utils.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..059d9cd Binary files /dev/null and b/vllm/model_executor/model_loader/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-310.pyc b/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-310.pyc new file mode 100644 index 0000000..ccaa569 Binary files 
/dev/null and b/vllm/model_executor/model_loader/__pycache__/weight_utils.cpython-310.pyc differ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py new file mode 100644 index 0000000..813f583 --- /dev/null +++ b/vllm/model_executor/model_loader/loader.py @@ -0,0 +1,1259 @@ +# ruff: noqa: SIM117 +import collections +import copy +import dataclasses +import fnmatch +import glob +import json +import math +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import (Any, Dict, Generator, Iterable, List, Optional, Tuple, + Type, cast) + +import gguf +import huggingface_hub +import numpy as np +import torch +from huggingface_hub import HfApi, hf_hub_download +from torch import nn +from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, MultiModalConfig, + ParallelConfig, SchedulerConfig) +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.tensorizer import ( + TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, + serialize_vllm_model, tensorizer_weights_iterator) +from vllm.model_executor.model_loader.utils import (get_model_architecture, + set_default_torch_dtype) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, + initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, + safetensors_weights_iterator) +from vllm.model_executor.models import (has_inner_state, supports_lora, + supports_multimodal) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform +from vllm.utils import is_pin_memory_available + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = 
init_logger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = get_quant_config(model_config, load_config) + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} " + "is not supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + +def _get_model_initialization_kwargs( + model_class: Type[nn.Module], + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig] = None) -> Dict[str, Any]: + """Get extra kwargs for model initialization.""" + extra_kwargs: Dict[str, Any] = {} + + if supports_lora(model_class): + # lora_config=None is used to disable LoRA + extra_kwargs["lora_config"] = lora_config + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + + if supports_multimodal(model_class): + assert multimodal_config is not None + + extra_kwargs["multimodal_config"] = multimodal_config + + if has_inner_state(model_class) and scheduler_config: + extra_kwargs["scheduler_config"] = scheduler_config + + return extra_kwargs + + +def build_model(model_class: Type[nn.Module], hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], *, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + scheduler_config: Optional[SchedulerConfig]) -> nn.Module: + extra_kwargs = _get_model_initialization_kwargs(model_class, lora_config, + multimodal_config, + scheduler_config) + + return model_class(config=hf_config, + cache_config=cache_config, + quant_config=quant_config, + **extra_kwargs) + + +def _initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig, + scheduler_config: Optional[SchedulerConfig] = None) -> nn.Module: + """Initialize a model with the given configurations.""" + model_class, _ = get_model_architecture(model_config) + + return build_model( + model_class, + model_config.hf_config, + cache_config=cache_config, + quant_config=_get_quantization_config(model_config, load_config), + lora_config=lora_config, + multimodal_config=model_config.multimodal_config, + scheduler_config=scheduler_config, + ) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + + @abstractmethod + def load_model(self, *, model_config: ModelConfig, + 
device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + """Load a model with the given configurations.""" + raise NotImplementedError + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + @dataclasses.dataclass + class Source: + """A source for weights.""" + + model_or_path: str + """The model ID or path.""" + + revision: Optional[str] + """The optional model revision.""" + + prefix: str = "" + """A prefix to prepend to all weights.""" + + fall_back_to_pt: bool = True + """Whether .pt weights can be used.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. + + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. 
+ if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, + self.load_config.download_dir, revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, source: "Source" + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + source.model_or_path, source.revision, source.fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + source.model_or_path, self.load_config.download_dir, hf_folder, + hf_weights_files) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + if current_platform.is_tpu(): + # In PyTorch XLA, we should call `xm.mark_step` frequently so that + # not too many ops are accumulated in the XLA program. + import torch_xla.core.xla_model as xm + + def _xla_weights_iterator(iterator: Generator): + for weights in iterator: + yield weights + xm.mark_step() + + weights_iterator = _xla_weights_iterator(weights_iterator) + + # Apply the prefix. 
+ return ((source.prefix + name, tensor) + for (name, tensor) in weights_iterator) + + def _get_all_weights( + self, + model_config: ModelConfig, + model: nn.Module, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + + primary_weights = DefaultModelLoader.Source( + model_config.model, + model_config.revision, + prefix="", + fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", + True)) + yield from self._get_weights_iterator(primary_weights) + + secondary_weights = cast(Iterable[DefaultModelLoader.Source], + getattr(model, "secondary_weights", ())) + for source in secondary_weights: + yield from self._get_weights_iterator(source) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config, + scheduler_config) + + model.load_weights(self._get_all_weights(model_config, model)) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config, + scheduler_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. 
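load_model above materializes parameters directly on the target device in the target dtype by nesting a dtype context and a device context; a minimal PyTorch sketch of that pattern (torch.set_default_dtype stands in for vLLM's set_default_torch_dtype helper, and torch.device as a context manager needs PyTorch 2.0+):

# Minimal sketch of the "build directly on device, in the target dtype" pattern.
import torch
import torch.nn as nn

prev = torch.get_default_dtype()
torch.set_default_dtype(torch.float16)           # stand-in for set_default_torch_dtype
try:
    with torch.device("cpu"):                    # would be e.g. "cuda:0" in practice
        model = nn.Linear(8, 8)                  # params created here, no later .to()
finally:
    torch.set_default_dtype(prev)

print(model.weight.dtype, model.weight.device)   # torch.float16 cpu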
+ with device_loading_context( + module, torch.device(device_config.device)): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class TensorizerLoader(BaseModelLoader): + """Model loader using CoreWeave's tensorizer library.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if isinstance(load_config.model_loader_extra_config, TensorizerConfig): + self.tensorizer_config = load_config.model_loader_extra_config + else: + self.tensorizer_config = TensorizerConfig( + **load_config.model_loader_extra_config) + + def _verify_config(self, model_config: ModelConfig, + parallel_config: ParallelConfig): + self.tensorizer_config.verify_with_model_config(model_config) + self.tensorizer_config.verify_with_parallel_config(parallel_config) + + def _get_weights_iterator( + self) -> Generator[Tuple[str, torch.Tensor], None, None]: + tensorizer_args = self.tensorizer_config._construct_tensorizer_args() + return tensorizer_weights_iterator(tensorizer_args) + + def _load_model_serialized_cpu( + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer to the CPU. + + This is only necessary when the model isn't vLLM-tensorized (see + examples/tensorize_vllm_model.py) This should still be faster than + default HuggingFace loading, but will be slower than loading a + vLLM-tensorized model. + """ + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config) + + model.load_weights(self._get_weights_iterator()) + return model.eval() + + def _load_model_serialized( + self, + model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + cache_config: CacheConfig, + ) -> nn.Module: + """Load a serialized model with tensorizer. + + Expects a vLLM-tensorized model. 
See the + examples/tensorize_vllm_model.py example script + for serializing vLLM models.""" + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config( + model_config, self.load_config) + extra_kwargs = _get_model_initialization_kwargs( + model_class, lora_config, model_config.multimodal_config) + extra_kwargs["quant_config"] = quant_config + extra_kwargs["cache_config"] = cache_config + + tensorizer_config = copy.copy(self.tensorizer_config) + tensorizer_config.model_class = model_class + tensorizer_config.hf_config = model_config.hf_config + tensorizer_config.dtype = model_config.dtype + + model = load_with_tensorizer(tensorizer_config, **extra_kwargs) + return model.eval() + + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + self._verify_config(model_config, parallel_config) + + if parallel_config.tensor_parallel_size > 1: + from vllm.distributed import get_tensor_model_parallel_rank + self.tensorizer_config.tensorizer_uri = \ + self.tensorizer_config.tensorizer_uri \ + % get_tensor_model_parallel_rank() + + if is_vllm_tensorized(self.tensorizer_config): + return self._load_model_serialized(model_config, device_config, + lora_config, cache_config) + return self._load_model_serialized_cpu(model_config, device_config, + lora_config, cache_config) + + @staticmethod + def save_model( + model: torch.nn.Module, + tensorizer_config: TensorizerConfig, + ) -> None: + serialize_vllm_model( + model=model, + tensorizer_config=tensorizer_config, + ) + + +class ShardedStateLoader(BaseModelLoader): + """ + Model loader that directly loads each worker's model state dict, which + enables a fast load path for large tensor-parallel models where each worker + only needs to read its own shard rather than the entire checkpoint. See + `examples/save_sharded_state.py` for creating a sharded checkpoint. + """ + + DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + extra_config = ({} if load_config.model_loader_extra_config is None + else load_config.model_loader_extra_config.copy()) + self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) + if extra_config: + raise ValueError(f"Unexpected extra config keys for load format " + f"{load_config.load_format}: " + f"{load_config.model_loader_extra_config.keys()}") + + @staticmethod + def _filter_subtensors( + tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """ + Filter out all tensors that share the same memory or a subset of the + memory of another tensor. 
+ """ + same_storage_groups: Dict[Any, List[Tuple[ + str, torch.Tensor]]] = collections.defaultdict(list) + for key, tensor in tensors.items(): + if tensor.numel(): + ptr = tensor.untyped_storage().data_ptr() + same_storage_groups[tensor.device, ptr].append((key, tensor)) + + def get_end_ptr(tensor: torch.Tensor) -> int: + return tensor.view(-1)[-1].data_ptr() + tensor.element_size() + + result: Dict[str, torch.Tensor] = {} + for group in same_storage_groups.values(): + for k, t in group: + a, b = t.data_ptr(), get_end_ptr(t) + for k2, t2 in group: + if not t2.is_contiguous(): + continue + a2, b2 = t2.data_ptr(), get_end_ptr(t2) + if a < a2 or b2 < b: + continue + if a2 < a or b < b2 or not t.is_contiguous(): + break # t2 covers strictly more memory than t. + if k2 < k: + # Same tensors, keep the one with the smaller key. + break + else: + result[k] = t + return result + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]): + if os.path.isdir(model_name_or_path): + return model_name_or_path + else: + allow_patterns = ["*.safetensors"] + return download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + from safetensors.torch import safe_open + + from vllm.distributed import get_tensor_model_parallel_rank + + local_model_path = self._prepare_weights(model_config.model, + model_config.revision) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + rank = get_tensor_model_parallel_rank() + pattern = os.path.join( + local_model_path, + self.pattern.format(rank=rank, part="*"), + ) + filepaths = glob.glob(pattern) + if not filepaths: + # TODO: support un-sharded checkpoints too + raise ValueError( + f"Could not find checkpoint files '{pattern}', only " + f"pre-sharded checkpoints are currently supported!") + state_dict = self._filter_subtensors(model.state_dict()) + for path in filepaths: + with safe_open(path, framework="pt") as f: + for key in f.keys(): # noqa: SIM118 + tensor = f.get_tensor(key) + # If loading with LoRA enabled, additional padding may + # be added to certain parameters. We only load into a + # narrowed view of the parameter data. 
+ param_data = state_dict[key].data + param_shape = state_dict[key].shape + for dim, size in enumerate(tensor.shape): + if size < param_shape[dim]: + param_data = param_data.narrow(dim, 0, size) + if tensor.shape != param_shape: + logger.warning( + "loading tensor of shape %s into " + "parameter '%s' of shape %s", tensor.shape, + key, param_shape) + param_data.copy_(tensor) + state_dict.pop(key) + if state_dict: + raise ValueError( + f"Missing keys {tuple(state_dict)} in loaded state!") + return model.eval() + + @staticmethod + def save_model( + model: torch.nn.Module, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from safetensors.torch import save_file + + from vllm.distributed import get_tensor_model_parallel_rank + if pattern is None: + pattern = ShardedStateLoader.DEFAULT_PATTERN + rank = get_tensor_model_parallel_rank() + part_idx = 0 + total_size = 0 + state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) + state_dict_part: Dict[str, torch.Tensor] = {} + for key, tensor in state_dict.items(): + param_size = tensor.nelement() * tensor.element_size() + if max_size is not None and total_size + param_size > max_size: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + part_idx += 1 + total_size = 0 + state_dict_part = {} + state_dict_part[key] = tensor + total_size += param_size + if len(state_dict_part) > 0: + filename = pattern.format(rank=rank, part=part_idx) + save_file( + state_dict_part, + os.path.join(path, filename), + ) + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + possible_config_file_names = ["adapter_config.json"] + + default_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + '.fc1.', + '.fc2.', + '.dense.', + '.query_key_value.', + '.qkv_proj.', + '.dense_h_to_4h.', + '.dense_4h_to_h.', + '.out_proj.', + ] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. 
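+        # A PEFT-style adapter_config.json typically lists the modules
+        # directly, e.g. (illustrative):
+        #   {"target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"], ...}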
+ if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + self.target_modules = [] + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download(repo_id=qlora_adapter, + filename=file) + break + + if not config_file_path: + raise ValueError( + f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. + + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + + def _get_quantized_weights_iterator( + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + if bitsandbytes.__version__ < "0.44.0": + raise ImportError("bitsandbytes version is wrong. 
Please " + "install bitsandbytes>=0.44.0.") + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.44.0 via " + "`pip install bitsandbytes>=0.44.0` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict: Dict[str, Any] = {} + + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if not weight_name.lower().endswith(".scb"): + continue + + weight_key = weight_name.lower().replace(".scb", ".qweight") + quant_state_dict[weight_key] = weight_tensor + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + + if not weight_name.endswith((".weight", ".bias")): + continue + + qweight_name = weight_name.replace(".weight", ".qweight") + + if qweight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield qweight_name, weight_tensor + else: + yield weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith((".weight", ".bias")): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in weight_name: + temp_state_dict[weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + + if not weight_name.endswith((".weight", ".bias")): + continue + + if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ + in temp_state_dict) or \ + (f"{weight_name}.quant_state.bitsandbytes__fp4" \ + in temp_state_dict): + quant_state = _parse_quant_state(weight_name, temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + + if any(target_module in weight_name for target_module in + self.target_modules) and weight_name.endswith(".weight"): + weight_name = weight_name.replace(".weight", ".qweight") + + if any(module in weight_name + for module in self.column_parallel_weights_modules): + + total_size = weight_tensor.size(-1) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[..., + start_index:end_index] + + else: + total_size = weight_tensor.size(0) + start_index = total_size // tp_size * tp_rank + end_index = total_size // tp_size * (tp_rank + 1) + weight_sub_tensor = weight_tensor[start_index:end_index, + ...] + + # bitsandbytes requires data in GPU + if weight_sub_tensor.is_cuda: + loaded_weight = weight_sub_tensor + else: + loaded_weight = weight_sub_tensor.cuda() + + # remove the following after the issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 + if loaded_weight.is_contiguous() is False: + loaded_weight = loaded_weight.contiguous() + + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, 'load_weights'): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(model).__name__}.") + + if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + raise AttributeError( + f"Model {type(model).__name__} does not support BitsAndBytes " + "quantization yet.") + + if len(self.target_modules) == 0: + if hasattr(model, 'default_bitsandbytes_target_modules'): + self.target_modules = model.default_bitsandbytes_target_modules + else: + self.target_modules = self.default_target_modules + + if hasattr(model, 'column_parallel_weights_modules'): + self.column_parallel_weights_modules = \ + model.column_parallel_weights_modules + else: + self.column_parallel_weights_modules = [] + + self.model_type = type(model).__name__ + + logger.info("Loading weights with BitsAndBytes quantization. 
" + " May take a while ...") + + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get('quant_method') + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + # The quant_states in pre_quantized models cannot work with a split + # weight tensor. So TP does not work with pre_quantized bnb models. + if pre_quant and get_tensor_model_parallel_world_size() > 1: + raise ValueError( + "Prequant BitsAndBytes models with TP is not supported." + "Please try with PP.") + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get('load_in_8bit', False) + + qweight_iterator, quant_state_dict = \ + self._get_quantized_weights_iterator( + model_config.model, model_config.revision, pre_quant, load_8bit) + + model.load_weights(qweight_iterator) + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, index + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace( + shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model.") + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = math.prod( + quant_state.shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config) + + self._load_weights(model_config, model) + + return model.eval() + + +class GGUFModelLoader(BaseModelLoader): + """ + Model loader that can load GGUF files. This is useful for loading models + that are quantized with GGUF and saved in the GGUF format. This loader + supports loading both full models and sharded models. 
+ """ + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _prepare_weights(self, model_name_or_path: str): + if os.path.isfile(model_name_or_path): + return model_name_or_path + else: + raise ValueError(f"{model_name_or_path} is not a file.") + + def _get_gguf_weights_map(self, model_config: ModelConfig): + """ + GGUF uses this naming convention for their tensors from HF checkpoint: + `blk.N.BB.weight` and `blk.N.BB.bias` + where N signifies the block number of a layer, and BB signifies the + attention/mlp layer components. + See "Standardized tensor names" in + https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. + """ + config = model_config.hf_config + model_type = config.model_type + # hack: ggufs have a different name than transformers + if model_type == "cohere": + model_type = "command-r" + arch = None + for key, value in gguf.MODEL_ARCH_NAMES.items(): + if value == model_type: + arch = key + break + if arch is None: + raise RuntimeError(f"Unknown gguf model_type: {model_type}") + num_layers = config.num_hidden_layers + name_map = gguf.get_tensor_name_map(arch, num_layers) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config(config) + state_dict = dummy_model.state_dict() + + gguf_to_hf_name_map = {} + for hf_name in state_dict: + name, suffix = hf_name.rsplit(".", 1) + gguf_name = name_map.get_name(name) + gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name + return gguf_to_hf_name_map + + def _get_weights_iterator( + self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + return gguf_quant_weights_iterator(model_name_or_path, + gguf_to_hf_name_map) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + + def load_model(self, *, model_config: ModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig) -> nn.Module: + + local_model_path = self._prepare_weights(model_config.model) + gguf_weights_map = self._get_gguf_weights_map(model_config) + # we can only know if tie word embeddings after mapping weights + if "lm_head.weight" in get_gguf_extra_tensor_names( + local_model_path, gguf_weights_map): + model_config.hf_config.update({"tie_word_embeddings": True}) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config) + model.load_weights( + self._get_weights_iterator(local_model_path, gguf_weights_map)) + return model + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.TENSORIZER: + return TensorizerLoader(load_config) + + if load_config.load_format == LoadFormat.SHARDED_STATE: + return ShardedStateLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + if load_config.load_format == LoadFormat.GGUF: + return 
GGUFModelLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py new file mode 100644 index 0000000..00c82fb --- /dev/null +++ b/vllm/model_executor/model_loader/neuron.py @@ -0,0 +1,240 @@ +"""Utilities for selecting and loading neuron models.""" +import copy +import importlib +import os +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import transformers +from transformers import PretrainedConfig + +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + SequenceOutput) + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "f32", + "half": "f16", + "float16": "f16", + "bfloat16": "bf16", + "float": "f32", + "float32": "f32", + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", +} + +# Models supported by Neuron. +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { + "LlamaForCausalLM": ("transformers_neuronx.llama.model", + "LlamaForSampling", "LlamaForCausalLM"), + "MistralForCausalLM": ("transformers_neuronx.mistral.model", + "MistralForSampling", "MistralForCausalLM") +} + + +class NeuronCasualLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + on_device_sampling_disabled: bool = False) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + + self.on_device_sampling_disabled = on_device_sampling_disabled + if self.on_device_sampling_disabled: + # Use default sampler + self.sampler = Sampler() + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + ) -> torch.Tensor: + logits = self.model(input_ids, + cache_ids=positions, + start_ids=input_block_ids) + return logits + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + if self.on_device_sampling_disabled: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + # On-device sampling outputs the token ids directly. 
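+        # e.g. (illustrative) with two sequence groups of one sequence each,
+        # `logits` here is a flat tensor holding two token ids, consumed one
+        # per sequence, in order, by the loop below.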
+ sampled_token_ids = logits.flatten() + next_tokens = [] + sample_idx = 0 + for seq_group in sampling_metadata.seq_groups: + samples = [] + for seq_id in seq_group.seq_ids: + token_id = sampled_token_ids[sample_idx].item() + samples.append( + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs={token_id: Logprob(token_id)})) + sample_idx += 1 + next_tokens.append( + CompletionSequenceGroupOutput(samples=samples, + prompt_logprobs=None)) + + return SamplerOutput(outputs=next_tokens) + + def load_weights(self, model_name_or_path: str, **kwargs): + arch = _get_model_architecture(self.config) + neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( + _NEURON_SUPPORTED_MODELS[arch]) + neuronx_module = importlib.import_module(neuronx_module_path) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) + + split_model_dir = f"{model_name_or_path}-split" + if _is_pretrained_neuron_checkpoint(model_name_or_path): + split_model_dir = model_name_or_path + elif not os.path.exists(f"{model_name_or_path}-split"): + hf_model_cls = getattr(transformers, hf_model_cls_name) + from transformers_neuronx.module import save_pretrained_split + + hf_model = hf_model_cls.from_pretrained(model_name_or_path, + low_cpu_mem_usage=True) + save_pretrained_split(hf_model, f"{model_name_or_path}-split") + + self.model = neuronx_model_cls.from_pretrained(split_model_dir, + **kwargs) + self.model.to_neuron() + + +def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool: + # Checking if the neuron checkpoint is saved in the old format. + if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): + return True + # Checking if the neuron checkpoint is saved in the new format. + pretrained_split_files = ["config.json", "generation_config.json"] + pretrained_split_format = ".safetensors" + for file in pretrained_split_files: + file_path = os.path.join(model_name_or_path, file) + if not os.path.isfile(file_path): + return False + for file in os.listdir(model_name_or_path): + if file.endswith(pretrained_split_format): + return True + return False + + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _NEURON_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on Neuron " + f"for now. 
Supported architectures: " + f"{list(_NEURON_SUPPORTED_MODELS.keys())}") + + +def _get_buckets(env: str, default_value: List[int]) -> List[int]: + env_value = os.getenv(env) + if env_value is None: + return default_value + buckets_remove_empty = filter( + lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) + buckets_int = map(int, buckets_remove_empty) + buckets_list = list(buckets_int) + return buckets_list + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + quant_config = dict( + dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + quantize_method="vector_dynamic") + neuron_quantization_config_builder = lambda quant: get_quantization_config( + quant).from_config(quant_config).get_quant_method(None, "") + # TODO: Add Paged attention config to the default neuron arguments. + default_neuron_args = dict( + collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + quant=neuron_quantization_config_builder(model_config.quantization) + if model_config.quantization else None, + continuous_batching=continuous_batching_config, + weight_tiling=bool(model_config.quantization), + on_device_generation=_get_neuron_on_device_generation_config( + model_config)) + return default_neuron_args + + +def _get_neuron_on_device_generation_config(model_config: ModelConfig): + if not _is_neuron_on_device_sampling_disabled(model_config): + return copy.deepcopy(model_config.neuron_sampling_params) + return None + + +def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool: + return not getattr(model_config, "neuron_sampling_params", None) + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + from transformers_neuronx.config import NeuronConfig + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return NeuronConfig(**default_neuron_config) + + +def get_neuron_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig) -> nn.Module: + + # Create a model instance. + model = NeuronCasualLM( + model_config.hf_config, + _is_neuron_on_device_sampling_disabled(model_config)) + + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) + + # Load the weights from the cached or downloaded files. 
+ model.load_weights(model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) + + return model.eval() diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py new file mode 100644 index 0000000..88b7ac4 --- /dev/null +++ b/vllm/model_executor/model_loader/openvino.py @@ -0,0 +1,203 @@ +# ruff: noqa: SIM117 +from pathlib import Path +from typing import List, Optional, Tuple + +import openvino as ov +import torch +from huggingface_hub import HfApi +from openvino._offline_transformations import paged_attention_transformation +from optimum.intel import OVModelForCausalLM +from torch import nn + +import vllm.envs as envs +from vllm.attention.backends.openvino import OpenVINOAttentionMetadata +from vllm.config import DeviceConfig, ModelConfig +from vllm.executor.openvino_executor import is_openvino_cpu +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import (LogitsProcessor, + _prune_hidden_states) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata + +logger = init_logger(__name__) + + +def _flattenize_inputs(inputs): + """ + Helper function for making nested inputs flattens + """ + flatten_inputs = [] + for input_data in inputs: + if input_data is None: + continue + if isinstance(input_data, (list, tuple)): + flatten_inputs.extend(_flattenize_inputs(input_data)) + elif isinstance(input_data, dict): + flatten_inputs.extend(_flattenize_inputs(list( + input_data.values()))) + else: + flatten_inputs.append(input_data) + return flatten_inputs + + +def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type, + is_cpu: bool): + # Apply hardware dependent modifications to KV tensors + for parameter in model.get_parameters(): + input = parameter.get_output_tensor(0) + input_names = input.get_names() + if len(input_names) != 1: + continue + input_name = next(iter(input_names)) + shape = parameter.get_partial_shape() + # use real block size if available, just a placeholder + # to provide the expected rank + num_blocks = ov.Dimension() + block_size = ov.Dimension() + head_size = ov.Dimension() + if input_name.startswith("key_cache."): + cpu_shape = [num_blocks, shape[1], block_size, head_size] + gpu_shape = [num_blocks, shape[1], shape[2], block_size] + elif input_name.startswith("value_cache."): + cpu_shape = [num_blocks, shape[1], block_size, head_size] + gpu_shape = [num_blocks, shape[1], block_size, shape[2]] + else: + continue + parameter.set_partial_shape( + ov.PartialShape(cpu_shape if is_cpu else gpu_shape)) + parameter.set_element_type(kv_cache_dtype) + model.validate_nodes_and_infer_types() + + +def _require_model_export(model_id, revision=None, subfolder=None): + model_dir = Path(model_id) + if subfolder is not None: + model_dir = model_dir / subfolder + if model_dir.is_dir(): + return (not (model_dir / "openvino_model.xml").exists() + or not (model_dir / "openvino_model.bin").exists()) + + hf_api = HfApi() + try: + model_info = hf_api.model_info(model_id, revision=revision or "main") + normalized_subfolder = (None if subfolder is None else + Path(subfolder).as_posix()) + model_files = [ + file.rfilename for file in model_info.siblings + if normalized_subfolder is None + or 
file.rfilename.startswith(normalized_subfolder) + ] + ov_model_path = ("openvino_model.xml" if normalized_subfolder is None + else f"{normalized_subfolder}/openvino_model.xml") + return (ov_model_path not in model_files + or ov_model_path.replace(".xml", ".bin") not in model_files) + except Exception: + return True + + +class OpenVINOCasualLM(nn.Module): + + def __init__( + self, + ov_core: ov.Core, + model_config: ModelConfig, + device_config: DeviceConfig, + kv_cache_dtype: ov.Type, + ) -> None: + super().__init__() + self.logits_processor = LogitsProcessor( + model_config.hf_config.vocab_size, logits_as_input=True) + self.sampler = Sampler() + + export = _require_model_export(model_config.model) + if export: + logger.warning( + f"Provided model id {model_config.model} does not " # noqa: G004 + "contain OpenVINO IR, the model will be converted to IR with " + "default options. If you need to use specific options for " + "model conversion, use optimum-cli export openvino with " + "desired options.") + else: + logger.warning( + "OpenVINO IR is available for provided model id " # noqa: G004 + f"{model_config.model}. This IR will be used for inference " + "as-is, all possible options that may affect model conversion " + "are ignored.") + + load_in_8bit = envs.VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS + pt_model = OVModelForCausalLM.from_pretrained( + model_config.model, + export=export, + compile=False, + load_in_8bit=load_in_8bit, + trust_remote_code=model_config.trust_remote_code, + ) + + ov_device = envs.VLLM_OPENVINO_DEVICE + paged_attention_transformation(pt_model.model) + _modify_cache_parameters(pt_model.model, kv_cache_dtype, + is_openvino_cpu()) + + ov_compiled = ov_core.compile_model(pt_model.model, ov_device) + self.ov_request = ov_compiled.create_infer_request() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[ov.Tensor, ov.Tensor]], + attn_metadata: OpenVINOAttentionMetadata, + ) -> torch.Tensor: + flatten_kv_cache = _flattenize_inputs(kv_caches) + + inputs = [ + input_ids, + positions, + *flatten_kv_cache, + attn_metadata.past_lens, + attn_metadata.subsequence_begins, + attn_metadata.block_indices, + attn_metadata.block_indices_begins, + attn_metadata.max_context_len, + ] + + self.ov_request.start_async(inputs, share_inputs=True) + self.ov_request.wait() + + logits = torch.from_numpy(self.ov_request.get_tensor("logits").data) + + # TODO: remove 'view' once OpenVINO PA will drop 'seq_len' dimension + return logits.view(-1, logits.shape[-1]) + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + +def get_model( + model_config: ModelConfig, + device_config: DeviceConfig, + kv_cache_dtype: ov.Type, + **kwargs, +) -> torch.nn.Module: + lora_config = kwargs.get("lora_config", None) + ov_core = kwargs.get("ov_core") + if lora_config: + raise ValueError( + "OpenVINO modeling does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. 
If this is important to you, " + "please open an issue on github.") + + return OpenVINOCasualLM(ov_core, model_config, device_config, + kv_cache_dtype) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py new file mode 100644 index 0000000..36f33d6 --- /dev/null +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -0,0 +1,480 @@ +import argparse +import dataclasses +import io +import os +import re +import time +from dataclasses import dataclass +from functools import partial +from typing import BinaryIO, Generator, Optional, Tuple, Type, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.config import ModelConfig, ParallelConfig +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.utils import FlexibleArgumentParser + +tensorizer_error_msg = None + +try: + from tensorizer import (DecryptionParams, EncryptionParams, + TensorDeserializer, TensorSerializer) + from tensorizer.stream_io import open_stream + from tensorizer.utils import (convert_bytes, get_mem_usage, + no_init_or_tensor) + + _read_stream, _write_stream = (partial( + open_stream, + mode=mode, + ) for mode in ("rb", "wb+")) +except ImportError as e: + tensorizer_error_msg = str(e) + +__all__ = [ + 'EncryptionParams', 'DecryptionParams', 'TensorDeserializer', + 'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage', + 'no_init_or_tensor', 'TensorizerConfig' +] + +logger = init_logger(__name__) + + +@dataclass +class TensorizerConfig: + tensorizer_uri: str + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + model_class: Optional[Type[torch.nn.Module]] = None + hf_config: Optional[PretrainedConfig] = None + dtype: Optional[Union[str, torch.dtype]] = None + _is_sharded: bool = False + + def __post_init__(self): + # check if the configuration is for a sharded vLLM model + self._is_sharded = isinstance(self.tensorizer_uri, str) \ + and re.search(r'%0\dd', self.tensorizer_uri) is not None + + def _construct_tensorizer_args(self) -> "TensorizerArgs": + tensorizer_args = { + "tensorizer_uri": self.tensorizer_uri, + "vllm_tensorized": self.vllm_tensorized, + "verify_hash": self.verify_hash, + "num_readers": self.num_readers, + "encryption_keyfile": self.encryption_keyfile, + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + return TensorizerArgs(**tensorizer_args) # type: ignore + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + if parallel_config.tensor_parallel_size > 1 \ + and not self._is_sharded: + raise ValueError( + "For a sharded model, tensorizer_uri should include a" + " string format template like '%04d' to be formatted" + " with the rank of the shard") + + def verify_with_model_config(self, model_config: "ModelConfig") -> None: + if (model_config.quantization is not None + and self.tensorizer_uri is not None): + logger.warning( + "Loading a model using Tensorizer 
with quantization on vLLM" + " is unstable and may lead to errors.") + + def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): + if tensorizer_args is None: + tensorizer_args = self._construct_tensorizer_args() + + return open_stream(self.tensorizer_uri, + **tensorizer_args.stream_params) + + +def load_with_tensorizer(tensorizer_config: TensorizerConfig, + **extra_kwargs) -> nn.Module: + tensorizer = TensorizerAgent(tensorizer_config, **extra_kwargs) + return tensorizer.deserialize() + + +@dataclass +class TensorizerArgs: + tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str, + bytes, os.PathLike, int] + vllm_tensorized: Optional[bool] = False + verify_hash: Optional[bool] = False + num_readers: Optional[int] = None + encryption_keyfile: Optional[str] = None + s3_access_key_id: Optional[str] = None + s3_secret_access_key: Optional[str] = None + s3_endpoint: Optional[str] = None + """ + Args for the TensorizerAgent class. These are used to configure the behavior + of the TensorDeserializer when loading tensors from a serialized model. + + Args: + tensorizer_uri: Path to serialized model tensors. Can be a local file + path or a S3 URI. + vllm_tensorized: If True, indicates that the serialized model is a + vLLM model. This is used to determine the behavior of the + TensorDeserializer when loading tensors from a serialized model. + It is far faster to deserialize a vLLM model as it utilizes + tensorizer's optimized GPU loading. Note that this is now + deprecated, as serialized vLLM models are now automatically + inferred as vLLM models. + verify_hash: If True, the hashes of each tensor will be verified against + the hashes stored in the metadata. A `HashMismatchError` will be + raised if any of the hashes do not match. + num_readers: Controls how many threads are allowed to read concurrently + from the source file. Default is `None`, which will dynamically set + the number of readers based on the number of available + resources and model size. This greatly increases performance. + encryption_keyfile: File path to a binary file containing a + binary key to use for decryption. `None` (the default) means + no decryption. See the example script in + examples/tensorize_vllm_model.py. + s3_access_key_id: The access key for the S3 bucket. Can also be set via + the S3_ACCESS_KEY_ID environment variable. + s3_secret_access_key: The secret access key for the S3 bucket. Can also + be set via the S3_SECRET_ACCESS_KEY environment variable. + s3_endpoint: The endpoint for the S3 bucket. Can also be set via the + S3_ENDPOINT_URL environment variable. 
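+
+    Example (illustrative values only)::
+
+        TensorizerArgs(
+            tensorizer_uri="s3://my-bucket/vllm/model.tensors",
+            num_readers=4,
+        )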
+ """ + + def __post_init__(self): + self.file_obj = self.tensorizer_uri + self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID + self.s3_secret_access_key = (self.s3_secret_access_key + or envs.S3_SECRET_ACCESS_KEY) + self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL + self.stream_params = { + "s3_access_key_id": self.s3_access_key_id, + "s3_secret_access_key": self.s3_secret_access_key, + "s3_endpoint": self.s3_endpoint, + } + + self.deserializer_params = { + "verify_hash": self.verify_hash, + "encryption": self.encryption_keyfile, + "num_readers": self.num_readers + } + + if self.encryption_keyfile: + with open_stream( + self.encryption_keyfile, + **self.stream_params, + ) as stream: + key = stream.read() + decryption_params = DecryptionParams.from_key(key) + self.deserializer_params['encryption'] = decryption_params + + @staticmethod + def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + """Tensorizer CLI arguments""" + + # Tensorizer options arg group + group = parser.add_argument_group( + 'tensorizer options', + description=('Options for configuring the behavior of the' + ' tensorizer deserializer when ' + 'load_format=tensorizer is specified when ' + 'initializing an LLMEngine, either via the CLI ' + 'when running the vLLM OpenAI inference server ' + 'with a JSON string passed to ' + '--model-loader-extra-config or as arguments given ' + 'to TensorizerConfig when passed to ' + 'model_loader_extra_config in the constructor ' + 'for LLMEngine.')) + + group.add_argument( + "--tensorizer-uri", + help="Path to serialized model tensors. Can be a local file path," + " or an HTTP(S) or S3 URI.", + ) + group.add_argument( + "--verify-hash", + action="store_true", + help="If enabled, the hashes of each tensor will be verified" + " against the hashes stored in the file metadata. An exception" + " will be raised if any of the hashes do not match.", + ) + group.add_argument( + "--encryption-keyfile", + default=None, + help="The file path to a binary file containing a binary key to " + "use for decryption. Can be a file path or S3 network URI.") + group.add_argument( + "--num-readers", + default=None, + type=int, + help="Controls how many threads are allowed to read concurrently " + "from the source file. Default is `None`, which will dynamically " + "set the number of readers based on the available resources " + "and model size. This greatly increases performance.") + group.add_argument( + "--s3-access-key-id", + default=None, + help="The access key for the S3 bucket. Can also be set via the " + "S3_ACCESS_KEY_ID environment variable.", + ) + group.add_argument( + "--s3-secret-access-key", + default=None, + help="The secret access key for the S3 bucket. Can also be set via " + "the S3_SECRET_ACCESS_KEY environment variable.", + ) + group.add_argument( + "--s3-endpoint", + default=None, + help="The endpoint for the S3 bucket. Can also be set via the " + "S3_ENDPOINT_URL environment variable.", + ) + + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs": + attrs = [attr.name for attr in dataclasses.fields(cls)] + tensorizer_args = cls(**{ + attr: getattr(args, attr) + for attr in attrs if hasattr(args, attr) + }) + return tensorizer_args + + +class TensorizerAgent: + """ + A class for performing tensorizer deserializations specifically for + vLLM models using plaid_mode. Uses TensorizerArgs to configure the + behavior of the TensorDeserializer when loading tensors from a serialized + model. 
For deserializations of HuggingFace models, TensorDeserializer is + instead used as an iterator directly in the func hf_model_weights_iterator + in vllm/model_executor/model_loader/weight_utils.py + """ + + def __init__(self, tensorizer_config: TensorizerConfig, + quant_config: QuantizationConfig, **extra_kwargs): + if tensorizer_error_msg is not None: + raise ImportError( + "Tensorizer is not installed. Please install tensorizer " + "to use this feature with `pip install vllm[tensorizer]`. " + "Error message: {}".format(tensorizer_error_msg)) + + self.tensorizer_config = tensorizer_config + self.tensorizer_args = ( + self.tensorizer_config._construct_tensorizer_args()) + self.extra_kwargs = extra_kwargs + if extra_kwargs.get("quant_config", None) is not None: + self.quant_config = extra_kwargs["quant_config"] + else: + self.quant_config = quant_config + self.model = self._init_model() + + def _init_model(self): + assert self.tensorizer_config.hf_config is not None + model_args = self.tensorizer_config.hf_config + model_args.torch_dtype = self.tensorizer_config.dtype + assert self.tensorizer_config.model_class is not None + with no_init_or_tensor(): + return self.tensorizer_config.model_class( + config=model_args, + quant_config=self.quant_config, + **self.extra_kwargs) + + def _resize_lora_embeddings(self): + """Modify LoRA embedding layers to use bigger tensors + to allow for adapter added tokens.""" + for child in self.model.modules(): + if (isinstance(child, VocabParallelEmbedding) + and child.weight.shape[0] < + child.num_embeddings_per_partition): + new_weight = torch.empty(child.num_embeddings_per_partition, + child.embedding_dim, + dtype=child.weight.dtype, + device=child.weight.device) + new_weight[:child.weight.shape[0]].copy_(child.weight.data) + new_weight[child.weight.shape[0]:].fill_(0) + child.weight.data = new_weight + + def _check_tensors_on_meta_device(self): + for tensor in self.model.state_dict().values(): + if tensor.device.type == 'meta': + raise ValueError( + "The serialized model contains tensors on the meta device," + " indicating that some tensors were not loaded properly." + " Please check that the parameters of the model being" + " specified match that of the serialized model, such as" + " its quantization.") + + def deserialize(self): + """ + Deserialize the model using the TensorDeserializer. This method is + specifically for vLLM models using tensorizer's plaid_mode. + + The deserializer makes use of tensorizer_args.stream_params + to configure the behavior of the stream when loading tensors from a + serialized model. The deserializer_params are used to configure the + behavior of the TensorDeserializer when loading tensors themselves. + Documentation on these params can be found in TensorizerArgs + + Returns: + nn.Module: The deserialized model. 
+ """ + before_mem = get_mem_usage() + start = time.perf_counter() + with _read_stream( + self.tensorizer_config.tensorizer_uri, + **self.tensorizer_args.stream_params + ) as stream, TensorDeserializer( + stream, + dtype=self.tensorizer_config.dtype, + device=f'cuda:{torch.cuda.current_device()}', + **self.tensorizer_args.deserializer_params) as deserializer: + deserializer.load_into_module(self.model) + end = time.perf_counter() + + total_bytes_str = convert_bytes(deserializer.total_tensor_bytes) + duration = end - start + per_second = convert_bytes(deserializer.total_tensor_bytes / duration) + after_mem = get_mem_usage() + deserializer.close() + logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str, + end - start, per_second) + logger.info("Memory usage before: %s", before_mem) + logger.info("Memory usage after: %s", after_mem) + + self._check_tensors_on_meta_device() + self._resize_lora_embeddings() + del self.model.vllm_tensorized_marker + return self.model.eval() + + +def tensorizer_weights_iterator( + tensorizer_args: "TensorizerArgs" +) -> Generator[Tuple[str, torch.Tensor], None, None]: + logger.warning( + "Deserializing HuggingFace models is not optimized for " + "loading on vLLM, as tensorizer is forced to load to CPU. " + "Consider deserializing a vLLM model instead for faster " + "load times. See the examples/tensorize_vllm_model.py example " + "script for serializing vLLM models.") + + deserializer_args = tensorizer_args.deserializer_params + stream_params = tensorizer_args.stream_params + stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params) + with TensorDeserializer(stream, **deserializer_args, + device="cpu") as state: + for name, param in state.items(): + yield name, param + del state + + +def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: + """ + Infer if the model is a vLLM model by checking the weights for + a vLLM tensorized marker. + + Args: + tensorizer_config: The TensorizerConfig object containing the + tensorizer_uri to the serialized model. + + Returns: + bool: True if the model is a vLLM model, False otherwise. 
+ """ + tensorizer_args = tensorizer_config._construct_tensorizer_args() + deserializer = TensorDeserializer(open_stream( + tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params), + **tensorizer_args.deserializer_params, + lazy_load=True) + if tensorizer_config.vllm_tensorized: + logger.warning( + "Please note that newly serialized vLLM models are automatically " + "inferred as vLLM models, so setting vllm_tensorized=True is " + "only necessary for models serialized prior to this change.") + return True + return ".vllm_tensorized_marker" in deserializer + + +def serialize_vllm_model( + model: nn.Module, + tensorizer_config: TensorizerConfig, +) -> nn.Module: + model.register_parameter( + "vllm_tensorized_marker", + nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) + tensorizer_args = tensorizer_config._construct_tensorizer_args() + + encryption_params = None + if (keyfile := tensorizer_config.encryption_keyfile) is not None: + with open(keyfile, "rb") as f: + key = f.read() + encryption_params = EncryptionParams(key=key) + + output_file = tensorizer_args.tensorizer_uri + if tensorizer_config._is_sharded: + from vllm.distributed import get_tensor_model_parallel_rank + output_file = output_file % get_tensor_model_parallel_rank() + + with _write_stream(output_file, **tensorizer_args.stream_params) as stream: + serializer = TensorSerializer(stream, encryption=encryption_params) + serializer.write_module(model) + serializer.close() + logger.info("Successfully serialized model to %s", str(output_file)) + return model + + +def tensorize_vllm_model(engine_args: EngineArgs, + tensorizer_config: TensorizerConfig, + generate_keyfile: bool = True): + """Utility to load a model and then serialize it with Tensorizer + + Intended to be used separately from running a vLLM server since it + creates its own Engine instance. + """ + engine_config = engine_args.create_engine_config() + tensorizer_config.verify_with_model_config(engine_config.model_config) + tensorizer_config.verify_with_parallel_config( + engine_config.parallel_config) + + # generate the encryption key before creating the engine to support sharding + if generate_keyfile and (keyfile := + tensorizer_config.encryption_keyfile) is not None: + encryption_params = EncryptionParams.random() + with _write_stream( + keyfile, + s3_access_key_id=tensorizer_config.s3_access_key_id, + s3_secret_access_key=tensorizer_config.s3_secret_access_key, + s3_endpoint=tensorizer_config.s3_endpoint, + ) as stream: + stream.write(encryption_params.key) + + engine = LLMEngine.from_engine_args(engine_args) + if tensorizer_config._is_sharded: + # if the engine is a distributed engine (for tensor parallel) then each + # worker shard needs to serialize its part of the model. 
+ engine.model_executor._run_workers( + "save_tensorized_model", + tensorizer_config=tensorizer_config, + ) + else: + # with a single worker, we can get to the underlying model directly + serialize_vllm_model( + engine.model_executor.driver_worker.model_runner.model, + tensorizer_config, + ) diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py new file mode 100644 index 0000000..b95c0b7 --- /dev/null +++ b/vllm/model_executor/model_loader/utils.py @@ -0,0 +1,39 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn + +from vllm.config import ModelConfig +from vllm.model_executor.models import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + mixtral_supported = [ + "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin" + ] + + if (model_config.quantization is not None + and model_config.quantization not in mixtral_supported + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + return ModelRegistry.resolve_model_cls(architectures) + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py new file mode 100644 index 0000000..1e2857e --- /dev/null +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -0,0 +1,682 @@ +"""Utilities for downloading and initializing model weights.""" +import fnmatch +import glob +import hashlib +import json +import os +import tempfile +from collections import defaultdict +from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, + Tuple, Union) + +import filelock +import gguf +import huggingface_hub.constants +import numpy as np +import torch +from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from safetensors.torch import load_file, safe_open, save_file +from tqdm.auto import tqdm + +from vllm.config import LoadConfig, ModelConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (QuantizationConfig, + get_quantization_config) +from vllm.model_executor.layers.quantization.schema import QuantParamSchema +from vllm.platforms import current_platform +from vllm.utils import print_warning_once + +logger = init_logger(__name__) + +# use system-level temp directory for file locks, so that multiple users +# can share the same lock without error. 
+# lock files in the temp directory will be automatically deleted when the +# system reboots, so users will not complain about annoying lock files +temp_dir = tempfile.gettempdir() + + +def enable_hf_transfer(): + """automatically activates hf_transfer + """ + if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ: + try: + # enable hf hub transfer if available + import hf_transfer # type: ignore # noqa + huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True + except ImportError: + pass + + +enable_hf_transfer() + + +class DisabledTqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir or temp_dir + os.makedirs(os.path.dirname(lock_dir), exist_ok=True) + model_name = model_name_or_path.replace("/", "-") + hash_name = hashlib.sha256(model_name.encode()).hexdigest() + # add hash to avoid conflict with old users' lock files + lock_file_name = hash_name + model_name + ".lock" + # mode 0o666 is required for the filelock to be shared across users + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), + mode=0o666) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config(model_config: ModelConfig, + load_config: LoadConfig) -> QuantizationConfig: + + quant_cls = get_quantization_config(model_config.quantization) + + # GGUF doesn't have config file + if model_config.quantization == "gguf": + return quant_cls.from_config({}) + + # Read the quantization config from the HF model config, if available. 
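+    # e.g. (illustrative) a GPTQ checkpoint's config.json carries something
+    # like:
+    #   "quantization_config": {"quant_method": "gptq", "bits": 4, ...}
+    # which is handed to the matching QuantizationConfig.from_config below.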
+ hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + # some vision model may keep quantization_config in their text_config + hf_text_config = getattr(model_config.hf_config, "text_config", None) + if hf_quant_config is None and hf_text_config is not None: + hf_quant_config = getattr(hf_text_config, "quantization_config", None) + if hf_quant_config is None: + # compressed-tensors uses a compressions_config + hf_quant_config = getattr(model_config.hf_config, "compression_config", + None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + # In case of bitsandbytes/QLoRA, get quant config from the adapter model. + if model_config.quantization == "bitsandbytes": + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + return quant_cls.from_config({"adapter_name_or_path": ""}) + model_name_or_path = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + else: + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, load_config.download_dir): + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) + else: + hf_folder = model_name_or_path + + possible_config_filenames = quant_cls.get_config_filenames() + + # If the quantization config is not found, use the default config. + if not possible_config_filenames: + return quant_cls() + + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in possible_config_filenames) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + + if model_config.quantization == "bitsandbytes": + config["adapter_name_or_path"] = model_name_or_path + elif model_config.quantization == "modelopt": + if config["producer"]["name"] == "modelopt": + return quant_cls.from_config(config) + else: + raise ValueError( + f"Unsupported quantization config" + f" found for {model_config.quantization} in {f}.") + + return quant_cls.from_config(config) + + +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, + ignore_patterns: Optional[Union[str, List[str]]] = None, +) -> str: + """Download model weights from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + allow_patterns (List[str]): The allowed patterns for the + weight files. Files matched by any of the patterns will be + downloaded. + revision (Optional[str]): The revision of the model. + ignore_patterns (Optional[Union[str, List[str]]]): The patterns to + filter out the weight files. Files matched by any of the patterns + will be ignored. + + Returns: + str: The path to the downloaded model weights. 
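A hedged usage sketch for download_weights_from_hf, whose signature and docstring appear here: callers list patterns in order of preference, since only the first pattern that matches anything in the repo is kept (see the fnmatch loop in the body below). The model id is illustrative and the call needs network access unless HF_HUB_OFFLINE is set:

    from vllm.model_executor.model_loader.weight_utils import (
        download_weights_from_hf)

    folder = download_weights_from_hf(
        "facebook/opt-125m",                       # illustrative model id
        cache_dir=None,                            # use the HF default cache
        allow_patterns=["*.safetensors", "*.bin"],
    )
    # folder is the local snapshot directory holding the matched weight files.
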
+ """ + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info("Using model weights format %s", allow_patterns) + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + return hf_folder + + +def download_safetensors_index_file_from_hf( + model_name_or_path: str, + index_file: str, + cache_dir: Optional[str], + revision: Optional[str] = None, +) -> None: + """Download hf safetensors index file from Hugging Face Hub. + + Args: + model_name_or_path (str): The model name or path. + cache_dir (Optional[str]): The cache directory to store the model + weights. If None, will use HF defaults. + revision (Optional[str]): The revision of the model. + """ + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + try: + # Download the safetensors index file. + hf_hub_download( + repo_id=model_name_or_path, + filename=index_file, + cache_dir=cache_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) + # If file not found on remote or locally, we should not fail since + # only some models will have index_file. + except huggingface_hub.utils.EntryNotFoundError: + logger.info("No %s found in remote.", index_file) + except huggingface_hub.utils.LocalEntryNotFoundError: + logger.info("No %s found in local cache.", index_file) + + +# For models like Mistral-7B-v0.3, there are both sharded +# safetensors files and a consolidated safetensors file. +# Passing both of these to the weight loader functionality breaks. +# So, we use the index_file to +# look up which safetensors files should be used. +def filter_duplicate_safetensors_files(hf_weights_files: List[str], + hf_folder: str, + index_file: str) -> List[str]: + # model.safetensors.index.json is a mapping from keys in the + # torch state_dict to safetensors file holding that weight. + index_file_name = os.path.join(hf_folder, index_file) + if not os.path.isfile(index_file_name): + return hf_weights_files + + # Iterate through the weight_map (weight_name: safetensors files) + # to identify weights that we should use. + with open(index_file_name, "r") as f: + weight_map = json.load(f)["weight_map"] + weight_files_in_index = set() + for weight_name in weight_map: + weight_files_in_index.add( + os.path.join(hf_folder, weight_map[weight_name])) + # Filter out any fields that are not found in the index file. + hf_weights_files = [ + f for f in hf_weights_files if f in weight_files_in_index + ] + return hf_weights_files + + +def filter_files_not_needed_for_inference( + hf_weights_files: List[str]) -> List[str]: + """ + Exclude files that are not needed for inference. 
+ + See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + """ + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + return hf_weights_files + + +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +def np_cache_weights_iterator( + model_name_or_path: str, cache_dir: Optional[str], hf_folder: str, + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model np files. + + Will dump the model weights to numpy files if they are not already dumped. + """ + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names: List[str] = [] + for bin_file in tqdm( + hf_weights_files, + desc="Loading np_cache checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + + +def safetensors_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model safetensor files.""" + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + for st_file in tqdm( + hf_weights_files, + desc="Loading safetensors checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + + +def pt_weights_iterator( + hf_weights_files: List[str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Iterate over the weights in the model bin/pt files.""" + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + for bin_file in tqdm( + hf_weights_files, + desc="Loading pt checkpoint shards", + disable=not enable_tqdm, + bar_format=_BAR_FORMAT, + ): + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def get_gguf_extra_tensor_names( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str]) -> 
List[str]: + reader = gguf.GGUFReader(gguf_file) + expected_gguf_keys = set(gguf_to_hf_name_map.keys()) + exact_gguf_keys = set([tensor.name for tensor in reader.tensors]) + extra_keys = expected_gguf_keys - exact_gguf_keys + return [gguf_to_hf_name_map[key] for key in extra_keys] + + +def gguf_quant_weights_iterator( + gguf_file: str, gguf_to_hf_name_map: Dict[str, str] +) -> Generator[Tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights in the model gguf files and convert + them to torch tensors + """ + + reader = gguf.GGUFReader(gguf_file) + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + + if weight_type.name != "F32": + name = name.replace("weight", "qweight") + param = torch.tensor(weight) + yield name, param + + +def kv_cache_scales_loader( + filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, + model_type: Optional[str]) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. + Keep this function in sync with the output of examples/fp8/extract_scales.py + """ + try: + with open(filename) as f: + context = { + "model_type": model_type, + "num_hidden_layers": num_hidden_layers, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + schema_dct = json.load(f) + schema = QuantParamSchema.model_validate(schema_dct, + context=context) + layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] + return layer_scales_map.items() + + except FileNotFoundError: + logger.error("File or directory '%s' not found.", filename) + except json.JSONDecodeError: + logger.error("Error decoding JSON in file '%s'.", filename) + except Exception as e: + logger.error("An error occurred while reading '%s': %s", filename, e) + # This section is reached if and only if any of the excepts are hit + # Return an empty iterable (list) => no KV cache scales are loaded + # which ultimately defaults to 1.0 scales + logger.warning( + "Defaulting to KV cache scaling factors = 1.0 for all " + "layers in TP rank %d as an error occurred during loading.", tp_rank) + return [] + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. 
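For kv_cache_scales_loader above, the docstring describes the expected JSON: TP ranks map to per-layer scaling factors. A rough illustration of that layout as a Python dict; field names other than scaling_factor are inferred rather than authoritative, and examples/fp8/extract_scales.py produces the real format:

    # Rough shape of the serialized scales file (illustrative values):
    scales = {
        "model_type": "llama",          # inferred field, used for validation
        "kv_cache": {
            "dtype": "float8_e4m3fn",   # inferred field
            "scaling_factor": {
                # TP rank -> {layer index -> KV cache scaling factor}
                "0": {"0": 1.0, "1": 1.0},
            },
        },
    }
    # kv_cache_scales_loader(path, tp_rank=0, tp_size=1, num_hidden_layers=2,
    #                        model_type="llama") would then return the
    # (layer, scale) pairs for rank 0, or an empty list on any error
    # (which downstream treats as scales of 1.0).
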
+ """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + try: + if param.numel() == 1 and loaded_weight.numel() == 1: + # Sometimes scalar values aren't considered tensors with shapes + # so if both param and loaded_weight are a scalar, + # "broadcast" instead of copy + param.data.fill_(loaded_weight.item()) + else: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})") + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. + raise + + +def row_parallel_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Load weights that are row-parallelized.""" + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + + if shard_dim is not None: + shard_size = param.data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] + + +def sharded_weight_loader(shard_axis: int) -> LoaderFunction: + """Create a weight loader that shards the weights along the given axis""" + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + tp_rank = get_tensor_model_parallel_rank() + + shard_size = param.data.shape[shard_axis] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) + + return loader + + +def composed_weight_loader( + loader: LoaderFunction, fn: Callable[[torch.Tensor], + torch.Tensor]) -> LoaderFunction: + """Create a weight loader that post-processes the weights after loading""" + + def composed_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + loader(param, loaded_weight) + param.data.copy_(fn(param)) + return + + return composed_loader + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, + seed: int = 1234, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + + We use per-parameter random seed, so that dummy weights are consistent, + even if the model is partitioned across multiple devices. When the seed + is fixed, the random values generated by this function only depends on + the parameter's number of elements and its data type. 
+ """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + if current_platform.is_tpu(): + # XLA device does not support torch.Generator() + param.uniform_(low, high) + continue + + generator = torch.Generator(device=param.data.device) + generator.manual_seed(seed) + if torch.finfo(param.data.dtype).bits < 16: + # uniform_ doesn't support < 16-bit datatypes (FP8) + dtype = param.data.dtype + tmp_param = param.data.to(torch.float16) + tmp_param = tmp_param.uniform_(low, high, + generator=generator).to(dtype) + param.data.copy_(tmp_param) + else: + param.uniform_(low, high, generator=generator) + + +def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: + """Remap the name of FP8 k/v_scale parameters. + + This function handles the remapping of FP8 k/v_scale parameter names. + It detects if the given name ends with a suffix and attempts to remap + it to the expected name format in the model. If the remapped name is not + found in the params_dict, a warning is printed and None is returned. + + Args: + name (str): The original loaded checkpoint parameter name. + params_dict (dict): Dictionary containing the model's named parameters. + + Returns: + str: The remapped parameter name if successful, or the original name + if no remapping is needed. + None: If the remapped name is not found in params_dict. + """ + if name.endswith(".kv_scale"): + print_warning_once( + "DEPRECATED. Found kv_scale in the checkpoint. " + "This format is deprecated in favor of separate k_scale and " + "v_scale tensors and will be removed in a future release. " + "Functionally, we will remap kv_scale to k_scale and duplicate " + "k_scale to v_scale") + # NOTE: we remap the deprecated kv_scale to k_scale + remapped_name = name.replace(".kv_scale", ".attn.k_scale") + if remapped_name not in params_dict: + print_warning_once( + f"Found kv_scale in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). kv_scale is " + "not loaded.") + return None + return remapped_name + + possible_scale_names = [".k_scale", ".v_scale"] + for scale_name in possible_scale_names: + if name.endswith(scale_name): + remapped_name = name.replace(scale_name, f".attn{scale_name}") + if remapped_name not in params_dict: + print_warning_once( + f"Found {scale_name} in the checkpoint (e.g. {name}), " + "but not found the expected name in the model " + f"(e.g. {remapped_name}). 
{scale_name} is " + "not loaded.") + return None + return remapped_name + + # If there were no matches, return the untouched param name + return name diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py new file mode 100644 index 0000000..d663735 --- /dev/null +++ b/vllm/model_executor/models/__init__.py @@ -0,0 +1,23 @@ +from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, + SupportsPP, has_inner_state, supports_lora, + supports_multimodal, supports_pp) +from .interfaces_base import (VllmModelForEmbedding, + VllmModelForTextGeneration, is_embedding_model, + is_text_generation_model) +from .registry import ModelRegistry + +__all__ = [ + "ModelRegistry", + "VllmModelForEmbedding", + "is_embedding_model", + "VllmModelForTextGeneration", + "is_text_generation_model", + "HasInnerState", + "has_inner_state", + "SupportsLoRA", + "supports_lora", + "SupportsMultiModal", + "supports_multimodal", + "SupportsPP", + "supports_pp", +] \ No newline at end of file diff --git a/vllm/model_executor/models/__pycache__/__init__.cpython-310.pyc b/vllm/model_executor/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..5ef0ec3 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/arctic.cpython-310.pyc b/vllm/model_executor/models/__pycache__/arctic.cpython-310.pyc new file mode 100644 index 0000000..7eac109 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/arctic.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/baichuan.cpython-310.pyc b/vllm/model_executor/models/__pycache__/baichuan.cpython-310.pyc new file mode 100644 index 0000000..7318751 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/baichuan.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/bart.cpython-310.pyc b/vllm/model_executor/models/__pycache__/bart.cpython-310.pyc new file mode 100644 index 0000000..4c44deb Binary files /dev/null and b/vllm/model_executor/models/__pycache__/bart.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/blip.cpython-310.pyc b/vllm/model_executor/models/__pycache__/blip.cpython-310.pyc new file mode 100644 index 0000000..d540977 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/blip.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/blip2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/blip2.cpython-310.pyc new file mode 100644 index 0000000..382fc21 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/blip2.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/bloom.cpython-310.pyc b/vllm/model_executor/models/__pycache__/bloom.cpython-310.pyc new file mode 100644 index 0000000..81e0e5e Binary files /dev/null and b/vllm/model_executor/models/__pycache__/bloom.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/chameleon.cpython-310.pyc b/vllm/model_executor/models/__pycache__/chameleon.cpython-310.pyc new file mode 100644 index 0000000..8f0dc82 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/chameleon.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/chatglm.cpython-310.pyc b/vllm/model_executor/models/__pycache__/chatglm.cpython-310.pyc new file mode 100644 index 0000000..dadaa45 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/chatglm.cpython-310.pyc 
differ diff --git a/vllm/model_executor/models/__pycache__/clip.cpython-310.pyc b/vllm/model_executor/models/__pycache__/clip.cpython-310.pyc new file mode 100644 index 0000000..7dccad7 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/clip.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/commandr.cpython-310.pyc b/vllm/model_executor/models/__pycache__/commandr.cpython-310.pyc new file mode 100644 index 0000000..cdfcd5d Binary files /dev/null and b/vllm/model_executor/models/__pycache__/commandr.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/dbrx.cpython-310.pyc b/vllm/model_executor/models/__pycache__/dbrx.cpython-310.pyc new file mode 100644 index 0000000..ddb1e64 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/dbrx.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/decilm.cpython-310.pyc b/vllm/model_executor/models/__pycache__/decilm.cpython-310.pyc new file mode 100644 index 0000000..9479d4e Binary files /dev/null and b/vllm/model_executor/models/__pycache__/decilm.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/deepseek.cpython-310.pyc b/vllm/model_executor/models/__pycache__/deepseek.cpython-310.pyc new file mode 100644 index 0000000..4089242 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/deepseek.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/deepseek_v2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/deepseek_v2.cpython-310.pyc new file mode 100644 index 0000000..cde5da2 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/deepseek_v2.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/eagle.cpython-310.pyc b/vllm/model_executor/models/__pycache__/eagle.cpython-310.pyc new file mode 100644 index 0000000..52cb839 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/eagle.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/exaone.cpython-310.pyc b/vllm/model_executor/models/__pycache__/exaone.cpython-310.pyc new file mode 100644 index 0000000..48f287a Binary files /dev/null and b/vllm/model_executor/models/__pycache__/exaone.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/falcon.cpython-310.pyc b/vllm/model_executor/models/__pycache__/falcon.cpython-310.pyc new file mode 100644 index 0000000..90692d5 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/falcon.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/fuyu.cpython-310.pyc b/vllm/model_executor/models/__pycache__/fuyu.cpython-310.pyc new file mode 100644 index 0000000..81ff33d Binary files /dev/null and b/vllm/model_executor/models/__pycache__/fuyu.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gemma.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gemma.cpython-310.pyc new file mode 100644 index 0000000..e2608e5 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gemma.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gemma2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gemma2.cpython-310.pyc new file mode 100644 index 0000000..6b8f9e3 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gemma2.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gemma2_embedding.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gemma2_embedding.cpython-310.pyc new file mode 
100644 index 0000000..ce06079 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gemma2_embedding.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/glm4.cpython-310.pyc b/vllm/model_executor/models/__pycache__/glm4.cpython-310.pyc new file mode 100644 index 0000000..b6c8e16 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/glm4.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/glm4_vision_encoder.cpython-310.pyc b/vllm/model_executor/models/__pycache__/glm4_vision_encoder.cpython-310.pyc new file mode 100644 index 0000000..1efd14e Binary files /dev/null and b/vllm/model_executor/models/__pycache__/glm4_vision_encoder.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gpt2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gpt2.cpython-310.pyc new file mode 100644 index 0000000..3cc35e5 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gpt2.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gpt_bigcode.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gpt_bigcode.cpython-310.pyc new file mode 100644 index 0000000..97c4c4c Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gpt_bigcode.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gpt_j.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gpt_j.cpython-310.pyc new file mode 100644 index 0000000..2025e1d Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gpt_j.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/gpt_neox.cpython-310.pyc b/vllm/model_executor/models/__pycache__/gpt_neox.cpython-310.pyc new file mode 100644 index 0000000..f551e48 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/gpt_neox.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/granite.cpython-310.pyc b/vllm/model_executor/models/__pycache__/granite.cpython-310.pyc new file mode 100644 index 0000000..2117e32 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/granite.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/granitemoe.cpython-310.pyc b/vllm/model_executor/models/__pycache__/granitemoe.cpython-310.pyc new file mode 100644 index 0000000..915ea81 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/granitemoe.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/idefics2_vision_model.cpython-310.pyc b/vllm/model_executor/models/__pycache__/idefics2_vision_model.cpython-310.pyc new file mode 100644 index 0000000..c0eb9ce Binary files /dev/null and b/vllm/model_executor/models/__pycache__/idefics2_vision_model.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/interfaces.cpython-310.pyc b/vllm/model_executor/models/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000..bfc0a43 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/interfaces.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/interfaces_base.cpython-310.pyc b/vllm/model_executor/models/__pycache__/interfaces_base.cpython-310.pyc new file mode 100644 index 0000000..bd8289c Binary files /dev/null and b/vllm/model_executor/models/__pycache__/interfaces_base.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/intern_vit.cpython-310.pyc b/vllm/model_executor/models/__pycache__/intern_vit.cpython-310.pyc new file mode 100644 index 
0000000..2c2fe53 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/intern_vit.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/internlm2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/internlm2.cpython-310.pyc new file mode 100644 index 0000000..cdf8002 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/internlm2.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/internvl.cpython-310.pyc b/vllm/model_executor/models/__pycache__/internvl.cpython-310.pyc new file mode 100644 index 0000000..c2207f6 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/internvl.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/jais.cpython-310.pyc b/vllm/model_executor/models/__pycache__/jais.cpython-310.pyc new file mode 100644 index 0000000..68b8894 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/jais.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/jamba.cpython-310.pyc b/vllm/model_executor/models/__pycache__/jamba.cpython-310.pyc new file mode 100644 index 0000000..088a165 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/jamba.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/llama.cpython-310.pyc b/vllm/model_executor/models/__pycache__/llama.cpython-310.pyc new file mode 100644 index 0000000..09e72b5 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/llama.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/llama_embedding.cpython-310.pyc b/vllm/model_executor/models/__pycache__/llama_embedding.cpython-310.pyc new file mode 100644 index 0000000..be7ed35 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/llama_embedding.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/llava.cpython-310.pyc b/vllm/model_executor/models/__pycache__/llava.cpython-310.pyc new file mode 100644 index 0000000..b0e994d Binary files /dev/null and b/vllm/model_executor/models/__pycache__/llava.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/llava_next.cpython-310.pyc b/vllm/model_executor/models/__pycache__/llava_next.cpython-310.pyc new file mode 100644 index 0000000..491a949 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/llava_next.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/llava_next_video.cpython-310.pyc b/vllm/model_executor/models/__pycache__/llava_next_video.cpython-310.pyc new file mode 100644 index 0000000..ba9e031 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/llava_next_video.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/llava_onevision.cpython-310.pyc b/vllm/model_executor/models/__pycache__/llava_onevision.cpython-310.pyc new file mode 100644 index 0000000..3a3049e Binary files /dev/null and b/vllm/model_executor/models/__pycache__/llava_onevision.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mamba.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mamba.cpython-310.pyc new file mode 100644 index 0000000..32a4981 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/mamba.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mamba_cache.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mamba_cache.cpython-310.pyc new file mode 100644 index 0000000..9ea4aa7 Binary files /dev/null and 
b/vllm/model_executor/models/__pycache__/mamba_cache.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/medusa.cpython-310.pyc b/vllm/model_executor/models/__pycache__/medusa.cpython-310.pyc new file mode 100644 index 0000000..1576df2 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/medusa.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/minicpm.cpython-310.pyc b/vllm/model_executor/models/__pycache__/minicpm.cpython-310.pyc new file mode 100644 index 0000000..f028c81 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/minicpm.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/minicpm3.cpython-310.pyc b/vllm/model_executor/models/__pycache__/minicpm3.cpython-310.pyc new file mode 100644 index 0000000..24638a6 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/minicpm3.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/minicpmv.cpython-310.pyc b/vllm/model_executor/models/__pycache__/minicpmv.cpython-310.pyc new file mode 100644 index 0000000..7db5ee3 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/minicpmv.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mixtral.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mixtral.cpython-310.pyc new file mode 100644 index 0000000..ddfa63f Binary files /dev/null and b/vllm/model_executor/models/__pycache__/mixtral.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mixtral_quant.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mixtral_quant.cpython-310.pyc new file mode 100644 index 0000000..f1fc24f Binary files /dev/null and b/vllm/model_executor/models/__pycache__/mixtral_quant.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mllama.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mllama.cpython-310.pyc new file mode 100644 index 0000000..f41d430 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/mllama.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mlp_speculator.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mlp_speculator.cpython-310.pyc new file mode 100644 index 0000000..527a4fd Binary files /dev/null and b/vllm/model_executor/models/__pycache__/mlp_speculator.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/module_mapping.cpython-310.pyc b/vllm/model_executor/models/__pycache__/module_mapping.cpython-310.pyc new file mode 100644 index 0000000..bf4fb66 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/module_mapping.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/molmo.cpython-310.pyc b/vllm/model_executor/models/__pycache__/molmo.cpython-310.pyc new file mode 100644 index 0000000..6619caf Binary files /dev/null and b/vllm/model_executor/models/__pycache__/molmo.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/mpt.cpython-310.pyc b/vllm/model_executor/models/__pycache__/mpt.cpython-310.pyc new file mode 100644 index 0000000..7238d6f Binary files /dev/null and b/vllm/model_executor/models/__pycache__/mpt.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/nemotron.cpython-310.pyc b/vllm/model_executor/models/__pycache__/nemotron.cpython-310.pyc new file mode 100644 index 0000000..54a91ea Binary files /dev/null and b/vllm/model_executor/models/__pycache__/nemotron.cpython-310.pyc differ diff --git 
a/vllm/model_executor/models/__pycache__/nvlm_d.cpython-310.pyc b/vllm/model_executor/models/__pycache__/nvlm_d.cpython-310.pyc new file mode 100644 index 0000000..829adaf Binary files /dev/null and b/vllm/model_executor/models/__pycache__/nvlm_d.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/olmo.cpython-310.pyc b/vllm/model_executor/models/__pycache__/olmo.cpython-310.pyc new file mode 100644 index 0000000..db57da9 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/olmo.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/olmoe.cpython-310.pyc b/vllm/model_executor/models/__pycache__/olmoe.cpython-310.pyc new file mode 100644 index 0000000..d428615 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/olmoe.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/opt.cpython-310.pyc b/vllm/model_executor/models/__pycache__/opt.cpython-310.pyc new file mode 100644 index 0000000..1998f13 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/opt.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/orion.cpython-310.pyc b/vllm/model_executor/models/__pycache__/orion.cpython-310.pyc new file mode 100644 index 0000000..3e1747f Binary files /dev/null and b/vllm/model_executor/models/__pycache__/orion.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/paligemma.cpython-310.pyc b/vllm/model_executor/models/__pycache__/paligemma.cpython-310.pyc new file mode 100644 index 0000000..06cef71 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/paligemma.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/persimmon.cpython-310.pyc b/vllm/model_executor/models/__pycache__/persimmon.cpython-310.pyc new file mode 100644 index 0000000..7102bb8 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/persimmon.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/phi.cpython-310.pyc b/vllm/model_executor/models/__pycache__/phi.cpython-310.pyc new file mode 100644 index 0000000..e081a8c Binary files /dev/null and b/vllm/model_executor/models/__pycache__/phi.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/phi3.cpython-310.pyc b/vllm/model_executor/models/__pycache__/phi3.cpython-310.pyc new file mode 100644 index 0000000..55c4858 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/phi3.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/phi3_small.cpython-310.pyc b/vllm/model_executor/models/__pycache__/phi3_small.cpython-310.pyc new file mode 100644 index 0000000..c1d83d4 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/phi3_small.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/phi3v.cpython-310.pyc b/vllm/model_executor/models/__pycache__/phi3v.cpython-310.pyc new file mode 100644 index 0000000..daff7ba Binary files /dev/null and b/vllm/model_executor/models/__pycache__/phi3v.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/phimoe.cpython-310.pyc b/vllm/model_executor/models/__pycache__/phimoe.cpython-310.pyc new file mode 100644 index 0000000..9c26fe1 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/phimoe.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/pixtral.cpython-310.pyc b/vllm/model_executor/models/__pycache__/pixtral.cpython-310.pyc new file mode 100644 index 0000000..2c1458f Binary files 
/dev/null and b/vllm/model_executor/models/__pycache__/pixtral.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen.cpython-310.pyc new file mode 100644 index 0000000..d032ca4 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen2.cpython-310.pyc new file mode 100644 index 0000000..06ac033 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen2.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen2_moe.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen2_moe.cpython-310.pyc new file mode 100644 index 0000000..6c33bbe Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen2_moe.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen2_rm.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen2_rm.cpython-310.pyc new file mode 100644 index 0000000..8b2d8b0 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen2_rm.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen2_vl.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen2_vl.cpython-310.pyc new file mode 100644 index 0000000..de7b1b9 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen2_vl.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen3.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen3.cpython-310.pyc new file mode 100644 index 0000000..11f7142 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen3.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/qwen3_moe.cpython-310.pyc b/vllm/model_executor/models/__pycache__/qwen3_moe.cpython-310.pyc new file mode 100644 index 0000000..6aed6f1 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/qwen3_moe.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/registry.cpython-310.pyc b/vllm/model_executor/models/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000..a3999c8 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/registry.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/siglip.cpython-310.pyc b/vllm/model_executor/models/__pycache__/siglip.cpython-310.pyc new file mode 100644 index 0000000..186536a Binary files /dev/null and b/vllm/model_executor/models/__pycache__/siglip.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/solar.cpython-310.pyc b/vllm/model_executor/models/__pycache__/solar.cpython-310.pyc new file mode 100644 index 0000000..bde0046 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/solar.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/stablelm.cpython-310.pyc b/vllm/model_executor/models/__pycache__/stablelm.cpython-310.pyc new file mode 100644 index 0000000..ba2924f Binary files /dev/null and b/vllm/model_executor/models/__pycache__/stablelm.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/starcoder2.cpython-310.pyc b/vllm/model_executor/models/__pycache__/starcoder2.cpython-310.pyc new file mode 100644 index 0000000..d16115e Binary files /dev/null and b/vllm/model_executor/models/__pycache__/starcoder2.cpython-310.pyc differ diff --git 
a/vllm/model_executor/models/__pycache__/ultravox.cpython-310.pyc b/vllm/model_executor/models/__pycache__/ultravox.cpython-310.pyc new file mode 100644 index 0000000..4c598e6 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/ultravox.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/utils.cpython-310.pyc b/vllm/model_executor/models/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..b94a9ba Binary files /dev/null and b/vllm/model_executor/models/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/model_executor/models/__pycache__/xverse.cpython-310.pyc b/vllm/model_executor/models/__pycache__/xverse.cpython-310.pyc new file mode 100644 index 0000000..71a46f3 Binary files /dev/null and b/vllm/model_executor/models/__pycache__/xverse.cpython-310.pyc differ diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py new file mode 100644 index 0000000..30b1f1c --- /dev/null +++ b/vllm/model_executor/models/arctic.py @@ -0,0 +1,562 @@ +"""Inference-only Snowflake Arctic model.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.deepspeedfp import ( + DeepSpeedFPConfig, DeepSpeedFPParameter) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.arctic import ArcticConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +logger = init_logger(__name__) + + +class ArcticMLP(nn.Module): + + def __init__(self, + config: ArcticConfig, + layer_id: int, + expert_id: int = -1, + is_residual_mlp: bool = False, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMLP, self).__init__() + self.hidden_size = config.hidden_size + self.expert_id = expert_id + self.layer_id = layer_id + + self.ffn_dim = config.intermediate_size if not is_residual_mlp \ + else self.hidden_size + + self.w13 = MergedColumnParallelLinear(self.hidden_size, + [self.ffn_dim] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(self.ffn_dim, + self.hidden_size, + bias=False, + reduce_results=reduce_results, + quant_config=quant_config) + if config.hidden_act != 
"silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, hidden_states): + gate_up, _ = self.w13(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.w2(hidden_states) + return hidden_states + + +class ArcticMoE(nn.Module): + """ + Model-parallel implementation of Arctic MoE Layer. + """ + + def __init__(self, + config: ArcticConfig, + layer_id: int, + tp_size: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True): + super(ArcticMoE, self).__init__() + + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.hidden_size = config.hidden_size + self.num_experts = config.num_local_experts + self.layer_id = layer_id + self.top_k = config.num_experts_per_tok + self.intermediate_size = config.intermediate_size // self.tp_size + + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 + self.is_quant = isinstance(quant_config, DeepSpeedFPConfig) + self.reduce_results = reduce_results + # Some other parameters + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + if not self.is_moe_layer: + self.mlp = ArcticMLP(config, + layer_id=layer_id, + quant_config=quant_config, + reduce_results=reduce_results) + else: + self.gate = ReplicatedLinear(self.hidden_size, + self.num_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=quant_config) + if self.is_quant: + self.ws = DeepSpeedFPParameter( + torch.Size((self.num_experts, 2 * self.intermediate_size, + self.hidden_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + self.w2s = DeepSpeedFPParameter( + torch.Size((self.num_experts, self.hidden_size, + self.intermediate_size)), + params_dtype=params_dtype, + quant_config=quant_config, + ) + else: + self.ws = nn.Parameter( + torch.empty(self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.ds_dequantize() if self.is_quant else param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + if self.is_quant: + param.ds_quantize_(param_data) + + def local_moe_fused(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + do_normalize = self.top_k > 1 + topk_weights, topk_ids = fused_topk(hidden_states, + router_logits, + self.top_k, + renormalize=do_normalize) + # 
topk_ids: (num_tokens, k) + if self.is_quant: + if 2 * num_tokens <= self.num_experts: + # If much fewer tokens than experts, use selective dequantize. + ws_dequantized = self.ws.ds_selective_dequantize( + topk_ids.flatten()) + w2s_dequantized = self.w2s.ds_selective_dequantize( + topk_ids.flatten()) + # We gathered the experts to the tokens so update the mapping. + topk_ids = torch.arange( + 0, + topk_ids.numel(), + device=topk_ids.device, + ).reshape(topk_ids.shape) + else: + ws_dequantized = self.ws.ds_dequantize() + w2s_dequantized = self.w2s.ds_dequantize() + + final_hidden_states = fused_experts( + hidden_states, + ws_dequantized if self.is_quant else self.ws, + w2s_dequantized if self.is_quant else self.w2s, + topk_weights, + topk_ids, + inplace=True) + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward(self, hidden_states: torch.Tensor): + if self.is_moe_layer: + final_hidden_states = self.local_moe_fused(hidden_states) + else: + final_hidden_states = self.mlp(hidden_states) + return final_hidden_states + + +class ArcticAttention(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: Optional[int] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + reduce_results=True, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ArcticDecoderLayer(nn.Module): + + def __init__( + self, + config: ArcticConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + 
quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 + self.use_residual = config.use_residual and is_moe_layer + self.self_attn = ArcticAttention(config, + layer_idx, + cache_config, + quant_config=quant_config) + self.block_sparse_moe = ArcticMoE( + config, + layer_id=layer_idx, + quant_config=quant_config, + reduce_results=(not self.use_residual)) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + if self.use_residual: + self.residual_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.residual_mlp = ArcticMLP(config, + layer_id=layer_idx, + is_residual_mlp=True, + reduce_results=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual_input = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual_input + hidden_states + + residual_attn = hidden_states + if self.use_residual: + hidden_states = self.residual_layernorm(hidden_states) + hidden_states = self.residual_mlp(hidden_states) + residual_mlp = hidden_states + hidden_states = self.post_attention_layernorm(residual_input) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_mlp + hidden_states + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + hidden_states = residual_attn + hidden_states + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual_attn + hidden_states + return hidden_states + + +class ArcticModel(nn.Module): + + def __init__( + self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=self.vocab_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: ArcticDecoderLayer(config, int( + prefix.split(".")[-1]), cache_config, quant_config), + prefix=f"{prefix}.layers") + self._attn_implementation = config._attn_implementation + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + 
attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class ArcticForCausalLM(nn.Module, SupportsPP): + + def __init__(self, + config: ArcticConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + **kwargs) -> None: + super().__init__() + self.config = config + self.model = ArcticModel(config, cache_config, quant_config) + self.vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.unpadded_vocab_size = config.vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + mlp_params_mapping: List[Tuple[str, str, int]] = [] + expert_params_mapping: List[Tuple[str, str, int]] = [] + num_layers = self.config.num_hidden_layers + + for layer in range(num_layers): + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.residual_mlp.w13.weight", + f"layers.{layer}.residual_mlp.w3.weight", 1)) + if layer % 2 == 0: + # MLP layers + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w1.weight", 0)) + mlp_params_mapping.append( + (f"layers.{layer}.block_sparse_moe.mlp.w13.weight", + f"layers.{layer}.block_sparse_moe.mlp.w3.weight", 1)) + else: + # MoE layers + for expert_id in range(self.config.num_local_experts): + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w1.weight", expert_id)) + expert_params_mapping.append( + ("w2s", f"experts.{expert_id}.w2.weight", expert_id)) + expert_params_mapping.append( + ("ws", f"experts.{expert_id}.w3.weight", expert_id)) + + params_dict = dict(self.named_parameters()) + + logger.info( + "It will take ~10 minutes loading from the 16-bit weights. 
" + "Alternatively, use the prequantized 8-bit weights of arctic " + "and set load-format to `sharded_state` will accelerate loading.") + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id in mlp_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, shard_id \ + in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py new file mode 100644 index 0000000..54ed548 --- /dev/null +++ b/vllm/model_executor/models/baichuan.py @@ -0,0 +1,463 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only BaiChuan model compatible with HuggingFace weights.""" +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BaiChuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class BaiChuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + position_embedding: str, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.postion_embedding = position_embedding + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # pylint: disable=invalid-name + self.W_pack = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + # Create the alibi slopes and slice them. + if self.postion_embedding == "ALIBI": + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + quant_config=quant_config) + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.postion_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class BaiChuanDecoderLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = BaiChuanAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + position_embedding=position_embedding, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + self.mlp = BaiChuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + 
self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BaiChuanModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BaiChuanDecoderLayer(config, position_embedding, + cache_config, quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "W_pack": ["W_pack"], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "W_pack", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + position_embedding: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = BaiChuanModel(config, position_embedding, cache_config, + quant_config) + self.lm_head = 
ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name == "lm_head.weight": + # Unlike Baichuan, Baichuan2 normalizes the head weights. + # Refer to: + # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 + # Distinguish between Baichuan and Baichuan2 by checking the + # vocab size. This is suggested by + # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 + is_baichuan2 = self.config.vocab_size == 125696 + if is_baichuan2: + loaded_weight = torch.nn.functional.normalize( + loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +class BaichuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 13B and Baichuan2 7B/13B.""" + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + if config.hidden_size == 4096: # baichuan2 7b + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) + else: # baichuan 13b, baichuan2 13b + super().__init__(config, "ALIBI", cache_config, quant_config, + lora_config) + + +class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 7B.""" + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__(config, "ROPE", cache_config, quant_config, + lora_config) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py new file mode 100644 index 0000000..cbdacf7 --- /dev/null +++ b/vllm/model_executor/models/bart.py @@ -0,0 +1,1003 @@ +# Derived from BART implementation posted on HuggingFace; license below: +# +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BART model.""" +import math +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import BartConfig +from transformers.utils import logging + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +logger = logging.get_logger(__name__) + + +def get_bsz_seq_len(input_ids): + shp = input_ids.shape + ndim = len(shp) + if ndim == 1: + return 1, input_ids.numel() + else: + return shp[:2] + + +class BartLearnedPositionalEmbedding(VocabParallelEmbedding): + """ + This module learns positional embeddings up to a fixed maximum size. 
+ """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # Bart is set up so that if padding_idx is + # specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. + # Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward( + self, + positions: torch.Tensor, + attn_type: AttentionType, + ) -> torch.Tensor: + """`input_ids' shape is expected to be [bsz x seqlen].""" + + assert attn_type != AttentionType.ENCODER_DECODER + + return super().forward(positions + self.offset) + + +class BartScaledWordEmbedding(VocabParallelEmbedding): + """ + This module overrides VocabParallelEmbedding's + forward by multiplying with embeddings scale. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) * self.embed_scale + + +class BartParallelLMHead(ParallelLMHead): + """ + This module overrides ParallelLMHead's + forward by dividing by embeddings scale, + yielding effectively the inverse of + BartScaledWordEmbedding + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + embed_scale: float = 1.0): + super().__init__(num_embeddings, embedding_dim) + self.embed_scale = embed_scale + + def forward(self, input_ids: torch.Tensor) -> torch.Tensor: + return super().forward(input_ids) / self.embed_scale + + +class BartEncoderAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER) + + output, _ = self.out_proj(attn_output) + return output + + +class BartDecoderSelfAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.DECODER) + + output, _ = self.out_proj(attn_output) + return output + + +class BartCrossAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + config: Optional[BartConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.embed_dim = embed_dim + self.total_num_heads = num_heads + self.total_num_kv_heads = self.total_num_heads + self.head_dim = embed_dim // num_heads + self.config = config + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads}).") + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + + # (afeldman-nm 2024/07/22) TODO: + # Need a more efficient solution for q/k/v + qkv_dec, _ = self.qkv_proj(decoder_hidden_states) + q, _, _ = qkv_dec.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + if encoder_hidden_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(encoder_hidden_states) + _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) + + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + + output, _ = self.out_proj(attn_output) + return output + + +class BartEncoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartEncoderAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.activation_fn = get_act_fn(config.activation_function, + quant_config) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + hidden_states + torch.Tensor of *encoder* input embeddings. 
+ kv_cache: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Encoder layer output torch.Tensor + """ + residual = hidden_states + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + if hidden_states.dtype == torch.float16 and ( + torch.isinf(hidden_states).any() + or torch.isnan(hidden_states).any()): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp(hidden_states, + min=-clamp_value, + max=clamp_value) + + return hidden_states + + +class BartDecoderLayer(nn.Module): + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = BartDecoderSelfAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + config=config, + cache_config=cache_config, + quant_config=quant_config) + self.activation_fn = get_act_fn(config.activation_function, + quant_config) + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + ''' + afeldman-nm: personally I would call this "cross-attention", + however I left the name as "encoder_attn" to maintain consistency + with the name of the pretrained weights. + ''' + self.encoder_attn = BartCrossAttention( + self.embed_dim, + config.decoder_attention_heads, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + + ffn_hidden_size = self.embed_dim + ffn_intermediate_size = config.encoder_ffn_dim + ffn_has_bias = True + self.fc1 = ColumnParallelLinear( + ffn_hidden_size, + ffn_intermediate_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + ffn_intermediate_size, + ffn_hidden_size, + bias=ffn_has_bias, + quant_config=quant_config, + ) + + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + decoder_hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + encoder_hidden_states: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + decoder_hidden_states + torch.Tensor of *decoder* input embeddings. + kv_cache: + KV cache tensor + attn_metadata: + vLLM Attention metadata structure + encoder_hidden_states + torch.Tensor of *encoder* input embeddings. 
+ Returns: + Decoder layer output torch.Tensor + """ + residual = decoder_hidden_states + + # Self Attention + hidden_states = self.self_attn(hidden_states=decoder_hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + + residual = hidden_states + + hidden_states = self.encoder_attn( + decoder_hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + fc1_out, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(fc1_out) + + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states + + +class BartEncoder(nn.Module): + """ + Transformer encoder consisting of *config.encoder_layers* + self attention layers. Each layer is a [`BartEncoderLayer`]. + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None): + super().__init__() + + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + embed_dim = config.d_model + self.max_source_positions = config.max_position_embeddings + embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + embed_dim, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + embed_dim, + ) + self.layers = nn.ModuleList( + [BartEncoderLayer(config,cache_config,quant_config) \ + for _ in range(config.encoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(embed_dim) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + positions + Positions of *encoder* input sequence tokens. + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Decoder output torch.Tensor + """ + # retrieve input_ids and inputs_embeds + + input_ids = input_ids.view(-1, input_ids.shape[-1]) + inputs_embeds = self.embed_tokens(input_ids) + + embed_pos = self.embed_positions( + positions, + AttentionType.ENCODER, + ) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + for idx, encoder_layer in enumerate(self.layers): + hidden_states = encoder_layer( + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + + return hidden_states + + +class BartDecoder(nn.Module): + """ + Transformer decoder consisting of *config.decoder_layers* layers. 
+ Each layer is a [`BartDecoderLayer`] + Args: + config: BartConfig + embed_tokens (nn.Embedding): output embedding + """ + + def __init__( + self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + embed_tokens: Optional[nn.Embedding] = None, + ): + super().__init__() + self.cache_config = cache_config + self.quant_config = quant_config + self.lora_config = lora_config + self.max_target_positions = config.max_position_embeddings + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.embed_tokens = BartScaledWordEmbedding(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + if embed_tokens is not None: + self.embed_tokens.weight = embed_tokens.weight + + self.embed_positions = BartLearnedPositionalEmbedding( + config.max_position_embeddings, + config.d_model, + ) + + self.layers = nn.ModuleList( + [BartDecoderLayer(config,cache_config,quant_config) \ + for _ in range(config.decoder_layers)]) + + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + def forward(self, decoder_input_ids: torch.Tensor, + decoder_positions: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + decoder_input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. + decoder_positions + Positions of *decoder* input sequence tokens. + encoder_hidden_states: + Tensor of encoder output embeddings + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Decoder output torch.Tensor + """ + + inputs_embeds = self.embed_tokens(decoder_input_ids) + + # embed positions + embed_pos = self.embed_positions( + decoder_positions, + AttentionType.DECODER, + ) + embed_pos = embed_pos.to(inputs_embeds.device) + + hidden_states = inputs_embeds + embed_pos + hidden_states = self.layernorm_embedding(hidden_states) + + # decoder layers + + for idx, decoder_layer in enumerate(self.layers): + hidden_states = decoder_layer( + decoder_hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + encoder_hidden_states=encoder_hidden_states, + ) + + return hidden_states + + +class BartModel(nn.Module): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", "decoder.embed_tokens.weight" + ] + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + super().__init__() + + self.config = config + + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.encoder = BartEncoder(config, + cache_config, + quant_config=quant_config) + self.decoder = BartDecoder(config, + cache_config, + quant_config=quant_config) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata) -> torch.Tensor: + r""" + Args: + input_ids + Indices of *decoder* input sequence tokens in the vocabulary. + Padding will be ignored by default should you + provide it. 
+ positions + Positions of *decoder* input sequence tokens. + encoder_input_ids + Indices of *encoder* input sequence tokens in the vocabulary. + encoder_positions: + Positions of *encoder* input sequence tokens. + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Model output torch.Tensor + """ + + encoder_hidden_states = None + + if encoder_input_ids.numel() > 0: + # Run encoder attention if a non-zero number of encoder tokens + # are provided as input + encoder_hidden_states = self.encoder(input_ids=encoder_input_ids, + positions=encoder_positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + # decoder outputs consists of + # (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + decoder_input_ids=input_ids, + decoder_positions=positions, + encoder_hidden_states=encoder_hidden_states, + kv_caches=kv_caches, + attn_metadata=attn_metadata) + + return decoder_outputs + + +class BartForConditionalGeneration(nn.Module): + base_model_prefix = "model" + + def __init__(self, + config: BartConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + + super().__init__() + # currently all existing BART models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.config = config + self.model = BartModel(config, + cache_config, + quant_config, + lora_config=lora_config) + + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + embed_scale = math.sqrt( + config.d_model) if config.scale_embedding else 1.0 + + self.lm_head = BartParallelLMHead(config.vocab_size, + config.d_model, + embed_scale=embed_scale) + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + *, + encoder_input_ids: torch.Tensor, + encoder_positions: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + r""" + Args: + input_ids + torch.Tensor of *decoder* input token ids. + positions + torch.Tensor of *decoder* position indices. + encoder_input_ids + torch.Tensor of *encoder* input token ids. 
+ encoder_positions + torch.Tensor of *encoder* position indices + kv_caches: + Layer-wise list of KV cache tensors + attn_metadata: + vLLM Attention metadata structure + Returns: + Output torch.Tensor + """ + return self.model(input_ids, positions, encoder_input_ids, + encoder_positions, kv_caches, attn_metadata) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + stacked_params_mapping = { + "q_proj": { + "param_name": "qkv_proj", + "shard_id": "q", + }, + "k_proj": { + "param_name": "qkv_proj", + "shard_id": "k", + }, + "v_proj": { + "param_name": "qkv_proj", + "shard_id": "v", + }, + } + + params_mapping = { + "beta": "bias", + "gamma": "weight", + "LayerNorm": "layernorm", + } + + def _rename_key(self, key: str): + prefix = f"{self.base_model_prefix}." + key = key[len(prefix):] if key.startswith(prefix) else key + + for src, dst in self.params_mapping.items(): + key = key.replace(src, dst) + + return key + + def _rename_stacked_param( + self, + name: str, + ) -> Tuple[str, Optional[str]]: + for key, mapping in self.stacked_params_mapping.items(): + if key in name: + name = name.replace(key, mapping["param_name"]) + return name, mapping["shard_id"] + return name, None + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + model_params_dict = dict(self.model.named_parameters()) + top_params_dict = dict(self.named_parameters()) + + weights_tuple_list = list(weights) + + shared_embedding_weight = None + shared_embedding_shard_id = None + + for name, loaded_weight in weights_tuple_list: + + name = self._rename_key(name) + name, shard_id = self._rename_stacked_param(name) + + if ('shared.weight' in name + or 'encoder.embed_tokens.weight' in name + or 'decoder.embed_tokens.weight' in name + or 'lm_head.weight' in name): + assert shared_embedding_weight is None, ( + "Conflicting embedding weights.") + shared_embedding_weight = loaded_weight + shared_embedding_shard_id = shard_id + else: + # Skip the specific downstream task weight. + if name.startswith('cls.'): + continue + # use Pooler instead. + if name.startswith('pooler.'): + continue + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in model_params_dict: + continue + + param = model_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if shard_id: + weight_loader(param, loaded_weight, shard_id) + else: + weight_loader(param, loaded_weight) + + # Assign shared weight values + encoder_in_param = model_params_dict['encoder.embed_tokens.weight'] + encoder_in_weight_loader = getattr(encoder_in_param, "weight_loader", + default_weight_loader) + + decoder_in_param = model_params_dict['decoder.embed_tokens.weight'] + decoder_in_weight_loader = getattr(decoder_in_param, "weight_loader", + default_weight_loader) + + lm_head_in_param = top_params_dict['lm_head.weight'] + lm_head_in_weight_loader = getattr(lm_head_in_param, "weight_loader", + default_weight_loader) + + assert shared_embedding_weight is not None + + if shared_embedding_shard_id: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight, + shared_embedding_shard_id) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight, + shared_embedding_shard_id) + else: + encoder_in_weight_loader(encoder_in_param, shared_embedding_weight) + decoder_in_weight_loader(decoder_in_param, shared_embedding_weight) + lm_head_in_weight_loader(lm_head_in_param, shared_embedding_weight) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py new file mode 100644 index 0000000..228908d --- /dev/null +++ b/vllm/model_executor/models/blip.py @@ -0,0 +1,423 @@ +"""Minimal implementation of BlipVisionModel intended to be only used +within a vision language model.""" +from typing import Iterable, Optional, Tuple, Union + +import torch +import torch.nn as nn +from PIL import Image +from transformers import Blip2VisionConfig, BlipVisionConfig +from transformers.models.blip.modeling_blip import BlipAttention + +from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import SequenceData + +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + +USE_XFORMERS_OPS = True + +def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_blip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_blip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_blip_image_feature_size( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: + return get_blip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_blip_image_tokens( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig]) -> int: + return get_blip_image_feature_size(hf_config) + + +def dummy_seq_data_for_blip( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + seq_len: int, + num_images: int, + *, 
+ image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_blip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_image_for_blip( + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def input_processor_for_blip( + model_config: ModelConfig, + hf_config: Union[BlipVisionConfig, Blip2VisionConfig], + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_feature_size = get_blip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/blip/modeling_blip.py#L164 # noqa +class BlipVisionEmbeddings(nn.Module): + + def __init__(self, config: BlipVisionConfig): + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + ) + + self.num_patches = get_blip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + + position_embeds = self.position_embedding.to(target_dtype) + embeddings = embeddings + position_embeds[:, :embeddings.size(1), :] + + return embeddings + + +class BlipParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: BlipVisionConfig, + quant_config: 
Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + query_states = query_states.transpose(1,2) + key_states = key_states.transpose(1,2) + value_states = value_states.transpose(1,2) + attn = (query_states @ (key_states * self.scale).permute(0,1,3,2)).float() + attn = torch.softmax(attn, dim=-1).to(value_states.dtype) + out = attn @ value_states + out = out.transpose(1,2).reshape(bsz, tgt_len, -1) + # TODO: use use F.attention when supported + # out = xops.memory_efficient_attention_forward(query_states, + # key_states, + # value_states, + # p=self.dropout, + # scale=self.scale) + # out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.projection(out) + + return attn_output, None + + +class BlipMLP(nn.Module): + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + self.config = config + + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class BlipEncoderLayer(nn.Module): + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + # fallback to sdpa attention if tp unavailable + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.self_attn = BlipParallelAttention(config, + quant_config=quant_config) + else: + # Blip doesn't have SDPA attention implemented in transformers + # use eager attention instead for cpu backend + self.self_attn = 
BlipAttention(config) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = BlipMLP(config, quant_config=quant_config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class BlipEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`BlipEncoderLayer`]. + + Args: + config: BlipConfig + """ + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + BlipEncoderLayer(config=config, quant_config=quant_config) + for _ in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class BlipVisionModel(nn.Module): + config_class = BlipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + + self.config = config + + self.embeddings = BlipVisionEmbeddings(config) + self.encoder = BlipEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + ) + + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + elif len(self.encoder.layers) == config.num_hidden_layers: + self.post_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + else: + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + hidden_states = self.embeddings(pixel_values) + hidden_states = self.encoder(inputs_embeds=hidden_states) + + if self.post_layernorm is None: + return hidden_states + + return self.post_layernorm(hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] + params_dict = dict(self.named_parameters()) + layer_count = len(self.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in BlipVisionModel + if (name.startswith("post_layernorm") + and self.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("encoder.layers"): + layer_idx = int(name.split(".")[2]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py new file mode 100644 index 0000000..3ab2357 --- /dev/null +++ b/vllm/model_executor/models/blip2.py @@ -0,0 +1,690 @@ +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +from transformers import (Blip2Config, Blip2QFormerConfig, Blip2VisionConfig, + apply_chunking_to_forward) + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors, SequenceData + +from .blip import (BlipVisionModel, dummy_image_for_blip, + get_max_blip_image_tokens) +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + merge_multimodal_embeddings) + +# We use this internally as placeholders since there is no image token +# defined on the HuggingFace repo +BLIP2_IMAGE_TOKEN = "" +BLIP2_IMAGE_TOKEN_ID = 50265 + + +class Blip2ImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class Blip2ImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of 
language model backbone. + """ + + +Blip2ImageInputs = Union[Blip2ImagePixelInputs, Blip2ImageEmbeddingInputs] + + +class Blip2QFormerMultiHeadAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + ) -> None: + super().__init__() + + self.config = config + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of " + f"the number of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = (config.hidden_size // + config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.scaling = self.attention_head_size**-0.5 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + kv_hidden_size = config.encoder_hidden_size + else: + kv_hidden_size = config.hidden_size + self.key = nn.Linear(kv_hidden_size, self.all_head_size) + self.value = nn.Linear(kv_hidden_size, self.all_head_size) + + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + if self.position_embedding_type != "absolute": + raise NotImplementedError("Unsupported position_embedding_type: " + f"{self.position_embedding_type}") + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + x = x.view(*x.size()[:-1], self.num_attention_heads, + self.attention_head_size) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ): + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores( + self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores( + self.value(encoder_hidden_states)) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + attention_scores = torch.matmul(query_layer, + key_layer.transpose(-1, -2)) + attention_probs = torch.softmax(attention_scores * self.scaling, + dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ attention_probs_dropped = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + context_layer = context_layer.view(*context_layer.size()[:-2], + self.all_head_size) + + return context_layer + + +class Blip2QFormerSelfOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerAttention(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + is_cross_attention: bool = False, + ) -> None: + super().__init__() + + self.attention = Blip2QFormerMultiHeadAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=is_cross_attention, + ) + + self.output = Blip2QFormerSelfOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.Tensor]: + self_output = self.attention( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + attention_output = self.output(self_output, hidden_states) + + return attention_output + + +class Blip2QFormerIntermediate(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class Blip2QFormerOutput(nn.Module): + + def __init__(self, config: Blip2QFormerConfig) -> None: + super().__init__() + + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward( + self, + hidden_states: torch.Tensor, + input_tensor: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class Blip2QFormerLayer(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + layer_idx: int, + ) -> None: + super().__init__() + + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = Blip2QFormerAttention(config, + quant_config=quant_config, + cache_config=cache_config) + + self.layer_idx = layer_idx + + if layer_idx % config.cross_attention_frequency == 0: + self.crossattention = Blip2QFormerAttention( + config, + quant_config=quant_config, + cache_config=cache_config, + is_cross_attention=True) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + 
self.intermediate_query = Blip2QFormerIntermediate(config) + self.output_query = Blip2QFormerOutput(config) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ): + attention_output = self.attention(hidden_states) + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + query_attention_output = self.crossattention( + query_attention_output, + encoder_hidden_states=encoder_hidden_states, + ) + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], + dim=1) + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + + return layer_output + + def feed_forward_chunk(self, + attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query( + self, attention_output: torch.Tensor) -> torch.Tensor: + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class Blip2QFormerEncoder(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + ) -> None: + super().__init__() + + self.config = config + + self.layer = nn.ModuleList([ + Blip2QFormerLayer(config, + quant_config=quant_config, + cache_config=cache_config, + layer_idx=layer_idx) + for layer_idx in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + query_length: int, + ) -> torch.Tensor: + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + + hidden_states = layer_module( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return hidden_states + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1025 +class Blip2QFormerModel(nn.Module): + + def __init__( + self, + config: Blip2QFormerConfig, + *, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + ) -> None: + super().__init__() + + self.config = config + + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + self.encoder = Blip2QFormerEncoder(config, + quant_config=quant_config, + cache_config=cache_config) + + def forward( + self, + query_embeds: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor, + ) -> torch.Tensor: + query_length = query_embeds.shape[1] + + embedding_output = self.layernorm(query_embeds) + embedding_output = self.dropout(embedding_output) + + sequence_output = self.encoder( + embedding_output, + encoder_hidden_states=encoder_hidden_states, + query_length=query_length, + ) + + return sequence_output + + +def 
get_blip2_image_feature_size(hf_config: Blip2Config) -> int: + return hf_config.num_query_tokens + + +def get_max_blip2_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(Blip2Config) + vision_config = hf_config.vision_config + + if isinstance(vision_config, Blip2VisionConfig): + return get_max_blip_image_tokens(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def dummy_seq_data_for_blip2( + hf_config: Blip2Config, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_blip2_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_data_for_blip2(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(Blip2Config) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + seq_data = dummy_seq_data_for_blip2( + hf_config, + seq_len, + num_images, + image_token_id=BLIP2_IMAGE_TOKEN_ID, + ) + + if isinstance(vision_config, Blip2VisionConfig): + mm_data = dummy_image_for_blip(vision_config, num_images) + + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + hf_config = ctx.get_hf_config(Blip2Config) + image_feature_size = get_blip2_image_feature_size(hf_config) + + # The original model places image tokens at the front + # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/blip_2/modeling_blip_2.py#L1514 + new_token_ids = [BLIP2_IMAGE_TOKEN_ID] * image_feature_size + new_token_ids += llm_inputs["prompt_token_ids"] + + new_prompt = llm_inputs.get("prompt") + if new_prompt is not None: + new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt + + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) +@INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) +class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, + config: Blip2Config, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. 
+ self.vision_model = BlipVisionModel(config.vision_config) + + self.query_tokens = nn.Parameter( + torch.zeros(1, config.num_query_tokens, + config.qformer_config.hidden_size)) + + self.qformer = Blip2QFormerModel(config.qformer_config, + cache_config=cache_config, + quant_config=quant_config) + + self.language_projection = nn.Linear( + config.qformer_config.hidden_size, + config.text_config.hidden_size, + bias=True, + ) + + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Blip2ImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + + return Blip2ImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. 
+ image_embeds = image_embeds.squeeze(1) + + return Blip2ImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features(self, vision_model: BlipVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_model(pixel_values) + + return image_features + + def _process_image_pixels(self, + inputs: Blip2ImagePixelInputs) -> torch.Tensor: + assert self.vision_model is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_model, pixel_values) + + def _process_image_input(self, + image_input: Blip2ImageInputs) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + image_features = self._process_image_pixels(image_input) + + query_tokens = self.query_tokens.expand(image_features.shape[0], -1, + -1) + query_output = self.qformer( + query_embeds=query_tokens, + encoder_hidden_states=image_features, + ) + + return self.language_projection(query_output) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[SamplerOutput, IntermediateTensors]: + """Run forward pass for BLIP-2. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + + Concretely, consider a text prompt: + `"Question: What's the content of the image? Answer:"`. + + Tokenizer outputs: + `[2, 45641, 35, 653, 18, 5, 1383, 9, 5, 2274, 116, 31652, 35]`. + + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + dummy tokens (denoted as `50265`), resulting in: + `[50265, ..., 50265, 2, 45641, 35, ..., 31652, 35]`. + + We insert 32 tokens since it corresponds to the number of query + embeddings outputted by the Q-Former and inputted to the language model. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each input image. 
+ + See also: + :class:`Blip2ImageInputs` + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + BLIP2_IMAGE_TOKEN_ID) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py new file mode 100644 index 0000000..b2c9e22 --- /dev/null +++ b/vllm/model_executor/models/bloom.py @@ -0,0 +1,362 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py +# Copyright 2023 The vLLM team. +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only BLOOM model compatible with HuggingFace weights.""" +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import BloomConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BloomAttention(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.n_head + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + # Create the alibi slopes and slice them. 
+ tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + del position_ids # Unused. + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.dense(attn_output) + return output + + +class BloomMLP(nn.Module): + + def __init__( + self, + config: BloomConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + quant_config=quant_config, + ) + self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.dense_h_to_4h(x) + x = self.gelu_impl(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class BloomBlock(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + self.input_layernorm = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + self.self_attention = BloomAttention(config, cache_config, + quant_config) + self.post_attention_layernorm = nn.LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.mlp = BloomMLP(config, quant_config) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attention_output = self.self_attention( + position_ids=position_ids, + hidden_states=layernorm_output, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + attention_output = attention_output + residual + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. 
+ output = self.mlp(layernorm_output) + residual + return output + + +class BloomModel(nn.Module): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = config.hidden_size + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.word_embeddings_layernorm = nn.LayerNorm( + self.embed_dim, eps=config.layer_norm_epsilon) + + # Transformer blocks + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: BloomBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h") + + # Final Layer Norm + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.word_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class BloomForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: BloomConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.transformer = BloomModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + 
if name == "lm_head.weight": + continue + if not name.startswith("transformer."): + name = "transformer." + name + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py new file mode 100644 index 0000000..03c7419 --- /dev/null +++ b/vllm/model_executor/models/chameleon.py @@ -0,0 +1,1104 @@ +from functools import cached_property +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, + Tuple, TypedDict, Union) + +import torch +import torch.nn.functional as F +from PIL import Image +from torch import nn +from transformers import ChameleonConfig, ChameleonVQVAEConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import IntermediateTensors, SequenceData +from vllm.utils import print_warning_once + +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +# These configs are not part of the model config but the preprocessor +# and processor files, so we hardcode them in the model file for now. 
+CHAMELEON_CROP_SIZE_HEIGHT = CHAMELEON_CROP_SIZE_WIDTH = 512 +CHAMELEON_IMAGE_SEQ_LENGTH = 1024 +CHAMELEON_IMAGE_TOKEN_ID = 8711 +CHAMELEON_IMAGE_START_TOKEN_ID = 8197 +CHAMELEON_IMAGE_END_TOKEN_ID = 8196 +CHAMELEON_SEP_TOKEN_ID = 8710 + + +class ChameleonImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +def get_max_chameleon_image_tokens(ctx: InputContext): + return CHAMELEON_IMAGE_SEQ_LENGTH + + +def dummy_seq_data_for_chameleon( + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = CHAMELEON_IMAGE_SEQ_LENGTH + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_image_for_chameleon( + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = CHAMELEON_CROP_SIZE_WIDTH + height = CHAMELEON_CROP_SIZE_HEIGHT + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_data_for_chameleon(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + + seq_data = dummy_seq_data_for_chameleon( + seq_len, + num_images, + image_token_id=CHAMELEON_IMAGE_TOKEN_ID, + ) + + mm_data = dummy_image_for_chameleon(num_images) + return seq_data, mm_data + + +def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): + + """ + Processing input prompt to insert required tokens for image placeholder. 
+ + See https://github.com/huggingface/transformers/blob/0fdea8607d7e01eb0e38a1ebeb7feee30a22f0cf/src/transformers/models/chameleon/processing_chameleon.py#L58 + """ # noqa + + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer) + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, + repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, + pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, + pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, + ) + + # Appending sep token for chat mode to follow default processor + # behavior + if new_prompt is not None: + new_prompt += tokenizer.sep_token + new_token_ids += [CHAMELEON_SEP_TOKEN_ID] + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class ChameleonLayerNorm(nn.LayerNorm): + + def __init__(self, hidden_size, *args, **kwargs): + super().__init__(hidden_size, *args, **kwargs) + self.normalized_shape = (hidden_size[-1], ) + + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.bias, + {"weight_loader": row_parallel_weight_loader}) + + def forward(self, hidden_states): + hidden_states = F.layer_norm(hidden_states, + self.normalized_shape, + None, + None, + eps=1e-5) + hidden_states = hidden_states * self.weight + self.bias + return hidden_states + + +# Copied from vllm.model_executor.models.llama.LlamaMLP -> ChameleonMLP +class ChameleonMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config) + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +# Modified from vllm.model_executor.models.llama.LlamaAttention -> ChameleonAttention #noqa +class ChameleonAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 4096, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + ) + self.q_norm = ChameleonLayerNorm((self.num_heads, self.head_dim)) + self.k_norm = ChameleonLayerNorm((self.num_kv_heads, self.head_dim)) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # reshape for layernorm + q = q.reshape(-1, self.num_heads, self.head_dim) + k = k.reshape(-1, self.num_kv_heads, self.head_dim) + q = self.q_norm(q) + k = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self._apply_qk_norm(q, k) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class ChameleonDecoderLayer(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = ChameleonAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + ) + self.mlp = ChameleonMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + ) + self.input_layernorm = RMSNorm(config.hidden_size, + 
eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class ChameleonSwinDecoderLayer(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = ChameleonAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=False, + cache_config=cache_config, + ) + self.mlp = ChameleonMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEVectorQuantizer #noqa +class ChameleonVQVAEVectorQuantizer(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.num_embeddings = config.num_embeddings + self.embedding_dim = config.embed_dim + self.beta = getattr(config, "beta", 0.25) + + self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim) + self.re_embed = self.num_embeddings + + def forward(self, hidden_state: torch.Tensor): + hidden_state = 
hidden_state.permute(0, 2, 3, 1).contiguous() + hidden_state_flattened = hidden_state.view(-1, self.embedding_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + distances = ( + torch.sum(hidden_state_flattened**2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight**2, dim=1) - + 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, + self.embedding.weight.transpose(0, 1))) + + min_encoding_indices = torch.argmin(distances, dim=1) + hidden_state_quant = self.embedding(min_encoding_indices).view( + hidden_state.shape) + + # compute loss for embedding + loss = torch.mean((hidden_state_quant.detach() - hidden_state)** + 2) + self.beta * torch.mean( + (hidden_state_quant - hidden_state.detach())**2) + + # preserve gradients + hidden_state_quant = hidden_state + (hidden_state_quant - + hidden_state).detach() + + # reshape back to match original input shape + hidden_state_quant = hidden_state_quant.permute(0, 3, 1, + 2).contiguous() + + return hidden_state_quant, loss, min_encoding_indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderConvDownsample #noqa +class ChameleonVQVAEEncoderConvDownsample(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + # no asymmetric padding in torch conv, must do it ourselves + hidden_states = F.pad(hidden_states, + pad=(0, 1, 0, 1), + mode="constant", + value=0) + hidden_states = self.conv(hidden_states) + return hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderResnetBlock #noqa +class ChameleonVQVAEEncoderResnetBlock(nn.Module): + + def __init__( + self, + config: ChameleonVQVAEConfig, + in_channels: int, + out_channels=None, + conv_shortcut=False, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None \ + else out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=32, + num_channels=out_channels, + eps=1e-6, + affine=True) + self.dropout = torch.nn.Dropout(config.dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + hidden_states *= torch.sigmoid(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + residual = self.conv_shortcut(residual) + else: + residual = self.nin_shortcut(residual) + + return residual + hidden_states + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoderAttnBlock 
#noqa +class ChameleonVQVAEEncoderAttnBlock(nn.Module): + + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=32, + num_channels=in_channels, + eps=1e-6, + affine=True) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm(hidden_states) + query_states = self.q(hidden_states) + key_states = self.k(hidden_states) + value_states = self.v(hidden_states) + + # compute attention + batch_size, channels, height, width = query_states.shape + query_states = query_states.reshape(batch_size, channels, + height * width).permute(0, 2, 1) + key_states = key_states.reshape(batch_size, channels, height * width) + attn_weights = torch.bmm(query_states, key_states) + attn_weights = attn_weights * (int(channels)**(-0.5)) + attn_weights = F.softmax(attn_weights, dim=2) + + # attend to values + value_states = value_states.reshape(batch_size, channels, + height * width) + attn_weights = attn_weights.permute(0, 2, 1) + attn_output = torch.bmm(value_states, + attn_weights).reshape(batch_size, channels, + height, width) + + attn_output = self.proj_out(attn_output) + return residual + attn_output + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAEEncoder #noqa +class ChameleonVQVAEEncoder(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + + self.num_resolutions = len(config.channel_multiplier) + self.num_res_blocks = config.num_res_blocks + base_channels = config.base_channels + resolution = config.resolution + in_channels = config.in_channels + double_latent = config.double_latent + latent_channels = config.latent_channels + channel_multiplier = config.channel_multiplier + + self.conv_in = torch.nn.Conv2d(in_channels, + base_channels, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_channel_multiplier = (1, ) + tuple(channel_multiplier) + self.in_channel_multiplier = in_channel_multiplier + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = base_channels * in_channel_multiplier[i_level] + block_out = base_channels * channel_multiplier[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_out, + )) + block_in = block_out + if (config.attn_resolutions is not None + and curr_res in config.attn_resolutions + and config.attn_type == "vanilla"): + attn.append(ChameleonVQVAEEncoderAttnBlock(block_in)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = ChameleonVQVAEEncoderConvDownsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + self.mid = nn.Module() + self.mid.block_1 = ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + self.mid.attn_1 = ChameleonVQVAEEncoderAttnBlock( + block_in) if config.attn_type == "vanilla" else nn.Identity() + self.mid.block_2 = 
ChameleonVQVAEEncoderResnetBlock( + config=config, + in_channels=block_in, + out_channels=block_in, + ) + + self.norm_out = torch.nn.GroupNorm(num_groups=32, + num_channels=block_in, + eps=1e-6, + affine=True) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * latent_channels if double_latent else latent_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, pixel_values: torch.Tensor): + pixel_values = pixel_values.to(self.conv_in.weight.dtype) + + # downsampling + hidden_states = [self.conv_in(pixel_values)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + hidden_state = self.down[i_level].block[i_block]( + hidden_states[-1], ) + if len(self.down[i_level].attn) > 0: + hidden_state = self.down[i_level].attn[i_block]( + hidden_state) + hidden_states.append(hidden_state) + if i_level != self.num_resolutions - 1: + hidden_states.append(self.down[i_level].downsample( + hidden_states[-1])) + + # middle + last_hidden_state = hidden_states[-1] + last_hidden_state = self.mid.block_1(last_hidden_state) + last_hidden_state = self.mid.attn_1(last_hidden_state) + last_hidden_state = self.mid.block_2(last_hidden_state) + + # end + last_hidden_state = self.norm_out(last_hidden_state) + last_hidden_state *= torch.sigmoid(last_hidden_state) + last_hidden_state = self.conv_out(last_hidden_state) + return last_hidden_state + + +# Adapted from transformers.models.chameleon.modeling_chameleon.ChameleonVQVAE #noqa +class ChameleonVQVAE(nn.Module): + + def __init__(self, config: ChameleonVQVAEConfig): + super().__init__() + self.encoder = ChameleonVQVAEEncoder(config) + self.quantize = ChameleonVQVAEVectorQuantizer(config) + self.quant_conv = torch.nn.Conv2d(config.latent_channels, + config.embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, + config.latent_channels, 1) + self.eval() # Chameleon's VQ model is frozen + + def encode( + self, pixel_values: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_states = self.encoder(pixel_values) + hidden_states = self.quant_conv(hidden_states) + quant, emb_loss, indices = self.quantize(hidden_states) + return quant, emb_loss, indices + + +# Copied from transformers.models.chameleon.modeling_chameleon.ChameleonImageVocabularyMapping #noqa +class ChameleonImageVocabularyMapping: + """ + A class for mapping discrete image tokens from VQGAN to BPE tokens. 
+ """ + + def __init__(self, vocab_map: Dict[str, int]): + self.vocab_map = vocab_map + self.image_token_id = vocab_map.get("") + + @cached_property + def val2name(self): + return {v: k for k, v in self.vocab_map.items()} + + @cached_property + def image_tokens(self): + return sorted([ + val for name, val in self.vocab_map.items() + if name.startswith("IMGIMG") + ]) + + @cached_property + def bpe2img(self): + img_tkn_chr_mapping = {chr(ord("A") + i): str(i) for i in range(10)} + + def remap(old_name: str) -> str: + return "".join( + img_tkn_chr_mapping.get(c, c) + for c in old_name[len("IMGIMG"):-1]) + + return { + tok: int(remap(self.val2name[tok])) + for tok in self.image_tokens + } + + @cached_property + def img2bpe(self): + return {v: k for k, v in self.bpe2img.items()} + + @cached_property + def bpe2img_search_tensors(self): + return torch.tensor(sorted(self.bpe2img.keys())), torch.tensor( + sorted(self.bpe2img.values())) + + @cached_property + def img2bpe_mapping_tensor(self): + mapping = torch.zeros(max(self.img2bpe.keys()) + 1, dtype=torch.int) + for k, v in self.img2bpe.items(): + mapping[k] = v + return mapping + + def convert_img2bpe(self, img_batch: torch.Tensor) -> torch.Tensor: + device = img_batch.device + img_tokens = self.img2bpe_mapping_tensor[img_batch.to("cpu")] + return img_tokens.to(device) + + +class ChameleonModel(nn.Module): + + def __init__( + self, + config: ChameleonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + ) + self.vocabulary_mapping = ChameleonImageVocabularyMapping( + config.vocabulary_map) + decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ + else ChameleonSwinDecoderLayer + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.vqmodel = ChameleonVQVAE(config.vq_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def get_image_tokens(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Tokenizes images into discrete tokens with VQGAN module. Converts + obtained image tokens into BPE tokens and wraps with "boi" and "eoi" + special tokens. 
+ """ + batch_size = pixel_values.shape[0] + _, _, image_toks = self.vqmodel.encode(pixel_values) + bpe_toks = self.vocabulary_mapping.convert_img2bpe(image_toks) + bpe_toks = bpe_toks.view(batch_size, -1) + return bpe_toks + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) +@INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) +class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__( + self, + config: ChameleonConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.multimodal_config = multimodal_config + self.model = ChameleonModel(config, cache_config, quant_config) + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + expected_dims = (3, CHAMELEON_CROP_SIZE_HEIGHT, + CHAMELEON_CROP_SIZE_WIDTH) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[ChameleonImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. 
+ pixel_values = pixel_values.squeeze(1) + + return ChameleonImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + + if intermediate_tensors is not None: + input_ids = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens( + image_input["data"].to(self.config.torch_dtype)) + image_token_id = self.model.vocabulary_mapping.image_token_id + special_image_mask = input_ids == image_token_id + image_tokens = image_tokens.to(input_ids.device, + input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, + image_tokens) + + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + # Disallow image tokens which does not include special + # begin-image and end-image tokens + if logits is not None: + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, image_tokens] = torch.finfo(logits.dtype).min + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + use_default_weight_loading = False + if "vqmodel" in name: + if self.model.vqmodel is not None: + # We only do sharding for language model and + # not vqvae for now. + use_default_weight_loading = True + else: + for (param_name, weight_name, + shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
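# Hedged example of the kv-scale remap handled below (the checkpoint name is
# illustrative): an entry such as "model.layers.0.self_attn.kv_scale" is
# rewritten to "model.layers.0.self_attn.attn.kv_scale" so it matches the
# parameter registered by the Attention module; if the remapped name is not in
# params_dict the scale is skipped with a one-time warning instead of raising.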
+ if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint (e.g. " + f"{name}), but not found the expected name in " + f"the model (e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + if use_default_weight_loading and name in params_dict: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py new file mode 100644 index 0000000..f26c9f9 --- /dev/null +++ b/vllm/model_executor/models/chatglm.py @@ -0,0 +1,671 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/GLM-4 +"""Inference-only ChatGLM model compatible with THUDM weights.""" +from argparse import Namespace +from array import array +from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict + +import torch +from PIL import Image +from torch import nn +from torch.nn import LayerNorm + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalInputs) +from vllm.multimodal.base import MultiModalData +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) +from vllm.transformers_utils.configs import ChatGLMConfig + +from .interfaces import SupportsLoRA, SupportsMultiModal + +logger = init_logger(__name__) + + +def calculate_image_placeholder(vision_config): + return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2 + + +def mm_input_mapper_for_glmv( + ctx: InputContext, + data: MultiModalData[object], +) -> Dict: + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + if tokenizer is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + try: + raw_batch_data = tokenizer.apply_chat_template( + 
conversation=[{ + "role": "user", + "image": data + }], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True).data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + pixel_values = raw_batch_data['images'] + + return MultiModalInputs({'pixel_values': pixel_values}) + + +def merge_glm_vision_embeddings( + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + vision_embeddings: torch.Tensor, + boi_token_id: int, + eoi_token_id: int, +) -> torch.Tensor: + + boi_positions = (input_ids == boi_token_id).nonzero(as_tuple=True)[0] + eoi_positions = (input_ids == eoi_token_id).nonzero(as_tuple=True)[0] + + mask = torch.zeros_like(input_ids, dtype=torch.bool) + + for boi_pos, eoi_pos in zip(boi_positions, eoi_positions): + assert boi_pos < eoi_pos + mask[boi_pos:eoi_pos + 1] = True + inputs_embeds[mask] = vision_embeddings.view(-1, + vision_embeddings.shape[-1]) + return inputs_embeds + + +class GLMImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + +def get_max_glmv_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(ChatGLMConfig) + + vision_config = getattr(hf_config, 'vision_config', None) + if vision_config is None: + return 1 + elif isinstance(vision_config, dict): + return calculate_image_placeholder(vision_config) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def dummy_data_for_glmv( + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] +) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: + hf_config = ctx.get_hf_config(ChatGLMConfig) + vision_config = getattr(hf_config, 'vision_config', None) + + if vision_config is None: + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len) + seq_data = SequenceData(token_ids) + return seq_data, None + elif isinstance(vision_config, dict): + image_size = vision_config["image_size"] + image_placeholder_length = calculate_image_placeholder(vision_config) + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [hf_config.boi_token_id] + + [0] * image_placeholder_length + + [hf_config.eoi_token_id]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0] * (seq_len - image_placeholder_length - 2)) + seq_data = SequenceData(token_ids) + + mm_data = { + "image": Image.new("RGB", (image_size, image_size), color=0) + } + + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def find_all_positions(input_ids: List[int], target: int) -> List[int]: + return [index for index, value in enumerate(input_ids) if value == target] + + +def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): + hf_config = ctx.get_hf_config(ChatGLMConfig) + vision_config = getattr(hf_config, 'vision_config', None) + + if vision_config is None: + return llm_inputs + elif isinstance(vision_config, dict): + image_placeholder_length = calculate_image_placeholder(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + input_ids = llm_inputs.get("prompt_token_ids") + position_ids = llm_inputs.get("position_ids") + tokenizer = cached_get_tokenizer( + ctx.model_config.model, + trust_remote_code=ctx.model_config.trust_remote_code) + + try: + raw_batch_data = tokenizer.apply_chat_template( + conversation=[{ + "role": "user", + "image": llm_inputs['multi_modal_data']["image"], + "content": llm_inputs['prompt'] + }], + 
add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True).data + except Exception: + logger.error("Failed to process content (%s)", llm_inputs['prompt']) + raise + input_ids = raw_batch_data['input_ids'][0].tolist() + + if position_ids is None: + position_ids = list(range(len(input_ids))) + boi_token_id = hf_config.boi_token_id + eoi_token_id = hf_config.eoi_token_id + boi_positions = find_all_positions(input_ids, boi_token_id) + eoi_positions = find_all_positions(input_ids, eoi_token_id) + + assert len(boi_positions) == len(eoi_positions) + + new_input_ids = [] + new_position_ids = [] + final_processed_position = 0 + final_processed_position = 0 + + for boi_position, eoi_position in zip(boi_positions, eoi_positions): + assert boi_position < eoi_position + new_input_ids.extend(input_ids[final_processed_position:boi_position + + 1]) + new_position_ids.extend( + list(range(final_processed_position, boi_position + 1))) + new_input_ids.extend([input_ids[boi_position + 1]] * + image_placeholder_length) + new_position_ids.extend([boi_position + 1] * image_placeholder_length) + final_processed_position = eoi_position + + new_input_ids.extend(input_ids[final_processed_position:]) + new_position_ids.extend( + list(range(final_processed_position, len(input_ids)))) + + assert len(new_input_ids) == len(new_position_ids) + + llm_inputs["prompt_token_ids"] = new_input_ids + llm_inputs["position_ids"] = new_position_ids + return llm_inputs + + +class GLMAttention(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.multi_query_attention = config.multi_query_attention + self.total_num_kv_heads = (config.multi_query_group_num + if config.multi_query_attention else + config.num_attention_heads) + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
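# Worked example for the replicated-KV branch below (numbers are illustrative
# only): with multi_query_group_num = 2 KV heads and tensor parallel size 8,
# the assert requires 8 % 2 == 0 and each rank keeps
# max(1, 2 // 8) = 1 KV head, i.e. the two KV heads are replicated across the
# eight ranks rather than partitioned.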
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.add_bias_linear or config.add_qkv_bias, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=config.add_bias_linear, + quant_config=quant_config, + ) + + # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 + rope_ratio = getattr(config, "rope_ratio", 1.0) + max_positions = getattr(config, "seq_length", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim // 2, + max_position=max_positions, + base=10000 * rope_ratio, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + context_layer = self.attn( + q, + k, + v, + kv_cache, + attn_metadata, + ) + attn_output, _ = self.dense(context_layer) + return attn_output + + +class GLMMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. + self.dense_h_to_4h = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=config.add_bias_linear, + quant_config=quant_config, + ) + + self.activation_func = SiluAndMul() + + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + quant_config=quant_config, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, _ = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output, _ = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = layer_norm_func(config.hidden_size, + eps=config.layernorm_epsilon) + + # Self attention. 
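# Sizing sketch for the attention module constructed below (figures assume a
# GLM-style config with hidden_size=4096, 32 query heads, multi_query_group_num=2
# and tp_size=1; they are illustrative, not read from a checkpoint):
# head_dim = 4096 // 32 = 128, q_size = 32 * 128 = 4096, kv_size = 2 * 128 = 256,
# so query_key_value projects 4096 -> 4096 + 2 * 256 = 4608 and forward() splits
# the result into q/k/v chunks of [4096, 256, 256].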
+ self.self_attention = GLMAttention(config, cache_config, quant_config) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GLMMLP(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # hidden_states: [num_tokens, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.self_attention( + hidden_states=layernorm_output, + position_ids=position_ids, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = residual + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.mlp(layernorm_output) + residual + + return output + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.layers = nn.ModuleList([ + GLMBlock(config, cache_config, quant_config) + for i in range(self.num_layers) + ]) + + if self.post_layer_norm: + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i in range(self.num_layers): + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + position_ids=position_ids, + kv_cache=kv_caches[i], + attn_metadata=attn_metadata, + ) + # Final layer norm. 
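# Residual bookkeeping recap for the blocks stacked above (pseudo-trace, with
# apply_residual_connection_post_layernorm=False assumed, the common GLM
# setting): each block computes
#   h1  = x  + attn(input_layernorm(x))
#   out = h1 + mlp(post_attention_layernorm(h1))
# whereas with the flag set the residual branch uses the layernorm outputs
# instead of x / h1.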
+ if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class ChatGLMModel(nn.Module): + + def __init__( + self, + config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.config = config + + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config) + + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.encoder = GLMTransformer(config, cache_config, quant_config) + + self.output_layer = ParallelLMHead(config.padded_vocab_size, + config.hidden_size, + quant_config=quant_config) + + vision_config_flag = getattr(config, 'vision_config', None) + if vision_config_flag is not None: + self.vision_config = Namespace(**config.vision_config) + self.vision = EVA2CLIPModel(self.config, quant_config) + else: + self.vision = None + + def _parse_and_validate_image_input( + self, **kwargs: object) -> GLMImagePixelInputs: + + pixel_values = kwargs.pop("pixel_values", None) + if pixel_values is not None and self.vision is not None: + if isinstance(pixel_values, torch.Tensor): + if pixel_values.ndim > 2: + pixel_values = torch.concat(list(pixel_values)) + elif isinstance(pixel_values, list): + return torch.concat(pixel_values) + else: + raise TypeError("""pixel_values must be a torch.Tensor + or a list of torch.Tensor + """) + return GLMImagePixelInputs(pixel_values=pixel_values) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> torch.Tensor: + + inputs_embeds = self.embedding(input_ids) + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input["pixel_values"] is not None: + pixel_values = image_input["pixel_values"].to( + dtype=inputs_embeds.dtype) + image_embeds = self.vision(pixel_values) + + boi_token_id = self.config.boi_token_id + eoi_token_id = self.config.eoi_token_id + + inputs_embeds = merge_glm_vision_embeddings( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + vision_embeddings=image_embeds, + boi_token_id=boi_token_id, + eoi_token_id=eoi_token_id) + + # Run encoder. 
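# Toy illustration of the vision/text merge performed above (token ids are made
# up): for input_ids = [boi, p, p, p, eoi, t1, t2] the mask covers boi through
# eoi inclusive (5 positions), and those rows of inputs_embeds are overwritten
# with rows of the flattened vision_embeddings, so the encoder below sees image
# features in place of the placeholder embeddings while t1/t2 keep their word
# embeddings.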
+ hidden_states = self.encoder( + hidden_states=inputs_embeds, + position_ids=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv) +class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal): + packed_modules_mapping = { + "query_key_value": ["query_key_value"], + "dense_h_to_4h": ["dense_h_to_4h"] + } + # LoRA specific attributes + supported_lora_modules = [ + "query_key_value", + "dense", + "dense_h_to_4h", + "dense_4h_to_h", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: ChatGLMConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + + self.config = config + self.lora_config = lora_config + self.multimodal_config = multimodal_config + + self.quant_config = quant_config + self.max_position_embeddings = getattr(config, "max_sequence_length", + 8192) + self.transformer = ChatGLMModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.transformer.output_layer.weight = ( + self.transformer.embedding.weight) + self.lm_head = self.transformer.output_layer + self.logits_processor = LogitsProcessor(config.padded_vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, **kwargs) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # Merge two ColumnParallelLinear into one MergedColumnParallelLinear + merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = { + "transformer.vision.linear_proj.merged_proj.weight": { + "transformer.vision.linear_proj.gate_proj.weight": None, + "transformer.vision.linear_proj.dense_h_to_4h.weight": None, + } + } + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + is_weight_to_be_merge = False + for _, merged_weight_dict in merged_weights_dict.items(): + if name in merged_weight_dict: + assert merged_weight_dict[name] is None + merged_weight_dict[name] = loaded_weight + is_weight_to_be_merge = True + if is_weight_to_be_merge: + continue + if "rotary_pos_emb.inv_freq" in name: + continue + if "word_embeddings" in name: + name = name.replace(".word_embeddings", "") + # Skip loading extra bias for GPTQ models. 
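# Name-handling examples for this loader (checkpoint names shown are assumed,
# for illustration only): "transformer.embedding.word_embeddings.weight" is
# loaded into "transformer.embedding.weight" after the rename above, while the
# two vision projection weights listed in merged_weights_dict are buffered and
# concatenated into "transformer.vision.linear_proj.merged_proj.weight" once
# both halves have been seen.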
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + for combined_name, merged_weight_dict in merged_weights_dict.items(): + if combined_name in params_dict: + param = params_dict[combined_name] + combined_weight = torch.cat(list(merged_weight_dict.values()), + dim=0) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, combined_weight) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py new file mode 100644 index 0000000..0fb0bf2 --- /dev/null +++ b/vllm/model_executor/models/clip.py @@ -0,0 +1,474 @@ +"""Minimal implementation of CLIPVisionModel intended to be only used +within a vision language model.""" +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from transformers import CLIPVisionConfig +from transformers.models.clip.modeling_clip import CLIPSdpaAttention + +from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import SequenceData + +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + +USE_XFORMERS_OPS = True + + +def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_clip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_clip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int: + return get_clip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + 1 + + +def get_max_clip_image_tokens(hf_config: CLIPVisionConfig) -> int: + return get_clip_image_feature_size(hf_config) + + +def dummy_seq_data_for_clip( + hf_config: CLIPVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_clip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_image_for_clip( + hf_config: CLIPVisionConfig, + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_video_for_clip( + hf_config: 
CLIPVisionConfig, + num_frames: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + pil_frame = dummy_image_for_clip( + hf_config, + num_images=1, + image_width_override=image_width_override, + image_height_override=image_height_override) + np_frame = np.array(pil_frame["image"]) + mm_data_per_video = np.repeat([np_frame], num_frames, axis=0) + mm_data = {"video": mm_data_per_video} + return mm_data + + +def input_processor_for_clip( + model_config: ModelConfig, + hf_config: CLIPVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[Union[int, List[int]]] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_clip_image_feature_size(hf_config) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa +class CLIPVisionEmbeddings(nn.Module): + + def __init__(self, config: CLIPVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.num_patches = get_clip_num_patches(image_size=self.image_size, + patch_size=self.patch_size) + self.num_positions = self.num_patches + 1 + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + self.register_buffer("position_ids", + torch.arange(self.num_positions).expand((1, -1)), + persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = patch_embeds.flatten(2).transpose(1, 2) + + class_embeds = self.class_embedding.expand(batch_size, 1, -1) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + embeddings = embeddings + self.position_embedding(self.position_ids) + + return embeddings + + +class CLIPParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = 
config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + query_states = query_states.transpose(1,2) + key_states = key_states.transpose(1,2) + value_states = value_states.transpose(1,2) + attn = (query_states @ (key_states * self.scale).permute(0,1,3,2)).float() + attn = torch.softmax(attn, dim=-1).to(value_states.dtype) + out = attn @ value_states + out = out.transpose(1,2).reshape(bsz, tgt_len, -1) + # TODO: use F.attention when supported + # out = xops.memory_efficient_attention_forward(query_states, + # key_states, + # value_states, + # p=self.dropout, + # scale=self.scale) + attn_output, _ = self.out_proj(out) + + return attn_output, None + + +class CLIPMLP(nn.Module): + + def __init__(self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class CLIPEncoderLayer(nn.Module): + + def __init__(self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.self_attn = CLIPParallelAttention(config, + quant_config=quant_config) + else: + self.self_attn = CLIPSdpaAttention(config) + self.layer_norm1 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = CLIPMLP(config, quant_config=quant_config) + self.layer_norm2 = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + + 
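# Pre-norm transformer step written out for reference (shapes assumed to be
# [batch, seq, hidden]): x = x + attn(layer_norm1(x)) followed by
# x = x + mlp(layer_norm2(x)), which is exactly what the two residual
# additions below implement.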
residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class CLIPEncoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self + attention layers. Each layer is a [`CLIPEncoderLayer`]. + + Args: + config: CLIPConfig + """ + + def __init__(self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList([ + CLIPEncoderLayer(config=config, quant_config=quant_config) + for _ in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class CLIPVisionTransformer(nn.Module): + + def __init__(self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = CLIPVisionEmbeddings(config) + + # NOTE: This typo of "layrnorm" is not fixed on purpose to match + # the original transformers code and name of the model weights. + self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override) + + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." 
+ ) + elif len(self.encoder.layers) == config.num_hidden_layers: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None + + def forward( + self, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + hidden_states = self.embeddings(pixel_values) + hidden_states = self.pre_layrnorm(hidden_states) + hidden_states = self.encoder(inputs_embeds=hidden_states) + + if self.post_layernorm is None: + return hidden_states + + return self.post_layernorm(hidden_states) + + +class CLIPVisionModel(nn.Module): + + config_class = CLIPVisionConfig + main_input_name = "pixel_values" + + def __init__(self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None): + super().__init__() + + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + + self.vision_model = CLIPVisionTransformer( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + return self.vision_model(pixel_values) + + @property + def device(self): + return next(self.parameters()).device + + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is not needed in CLIPVisionModel + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py new file mode 100644 index 0000000..578cd2f --- /dev/null +++ b/vllm/model_executor/models/commandr.py @@ -0,0 +1,441 @@ +# coding=utf-8 +# Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is based on the LLama model definition file in transformers +"""PyTorch Cohere model.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers import CohereConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name, + row_parallel_weight_loader) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +@torch.compile +def layer_norm_func(hidden_states, weight, variance_epsilon): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) * torch.rsqrt(variance + + variance_epsilon) + hidden_states = weight.to(torch.float32) * hidden_states + return hidden_states.to(input_dtype) + + +class LayerNorm(nn.Module): + + def __init__(self, param_shape=None, eps=1e-5): + super().__init__() + self.weight = nn.Parameter(torch.ones(param_shape)) + self.variance_epsilon = eps + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + + def forward(self, hidden_states, residuals=None): + hidden_states = layer_norm_func(hidden_states, self.weight, + self.variance_epsilon) + return hidden_states, residuals + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere +class CohereMLP(nn.Module): + + def __init__( + self, + config: CohereConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = 
self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class CohereAttention(nn.Module): + + def __init__( + self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + tp_size = get_tensor_model_parallel_world_size() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.max_position_embeddings = getattr( + config, "model_max_length", None) or getattr( + config, "max_position_embeddings", 8192) + self.rope_theta = config.rope_theta + self.rope_scaling = getattr(config, "rope_scaling", None) + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + rope_scaling=self.rope_scaling, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + if self.use_qk_norm: + self.q_norm = LayerNorm(param_shape=(self.num_heads, + self.head_dim), + eps=config.layer_norm_eps) + self.k_norm = LayerNorm(param_shape=(self.num_kv_heads, + self.head_dim), + eps=config.layer_norm_eps) + + def _apply_qk_norm(self, q, k): + q = q.view(*q.shape[:-1], -1, self.head_dim) + k = k.view(*k.shape[:-1], -1, self.head_dim) + q, _ = self.q_norm(q) + k, _ = self.k_norm(k) + q = q.view(*q.shape[:-2], -1) + k = k.view(*k.shape[:-2], -1) + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class CohereDecoderLayer(nn.Module): + + def __init__(self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = CohereAttention(config, + cache_config, + quant_config=quant_config) + + self.mlp = 
CohereMLP(config, quant_config=quant_config) + self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states, residual = self.input_layernorm(hidden_states, residual) + hidden_states_attention = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states_mlp = self.mlp(hidden_states) + # Add everything together + hidden_states = residual + hidden_states_attention + hidden_states_mlp + + return hidden_states, residual + + +class CohereModel(nn.Module): + + def __init__( + self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CohereDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") + self.norm = LayerNorm(param_shape=(config.hidden_size), + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens" + ] + embedding_modules = {"embed_tokens": "input_embeddings"} + embedding_padding_modules = [] + + def __init__( + self, + config: CohereConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + # currently all existing command R models have `tie_word_embeddings` + # enabled + assert config.tie_word_embeddings + 
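# Because the embeddings are tied, no separate lm_head is created for this
# model: compute_logits() further down feeds self.model.embed_tokens (or its
# LoRA base layer) straight into the logits processor, and load_weights()
# drops any "lm_head.weight" entry found in the checkpoint.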
self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.quant_config = quant_config + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=config.logit_scale) + self.model = CohereModel(config, + cache_config, + quant_config, + lora_config=lora_config) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + is_not_lora = hasattr(self.model.embed_tokens, 'weight') + if is_not_lora: + logits = self.logits_processor(self.model.embed_tokens, + hidden_states, sampling_metadata) + else: + logits = self.logits_processor(self.model.embed_tokens.base_layer, + hidden_states, sampling_metadata) + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
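                # (Editorial note: if the helper cannot map a legacy kv-scale
                # name onto a parameter registered by this model it returns
                # None, and the weight is skipped just below.)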
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py new file mode 100644 index 0000000..aae7ab7 --- /dev/null +++ b/vllm/model_executor/models/dbrx.py @@ -0,0 +1,439 @@ +# coding=utf-8 +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.dbrx import DbrxConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class DbrxRouter(nn.Module): + """A Router implementation for DBRX that returns logits for each expert + per token. 
+ """ + + def __init__( + self, + config: DbrxConfig, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.ffn_config.moe_num_experts + self.d_model = config.d_model + self.layer = ReplicatedLinear( + self.d_model, + self.num_total_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + router_logits, _ = self.layer(hidden_states) + return router_logits + + +class DbrxExperts(FusedMoE): + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__( + num_experts=config.ffn_config.moe_num_experts, + top_k=config.ffn_config.moe_top_k, + hidden_size=config.d_model, + intermediate_size=config.ffn_config.ffn_hidden_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=get_tensor_model_parallel_world_size(), + ) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.d_model = config.d_model + self.intermediate_size = (self.config.ffn_config.ffn_hidden_size // + self.tp_size) + + # Define custom weight loader for dbrx model + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + # DBRX uses GLU for each experts. + # GLU has 3 linear layers: w1, v1 and w2. + if weight_name.endswith("w1"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :] + if weight_name.endswith("v1"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ) + param_data[:, + shard_size:2 * shard_size, :] = loaded_weight[:, + shard, :] + if weight_name.endswith("w2"): + loaded_weight = torch.reshape( + loaded_weight, + [-1, self.intermediate_size * self.tp_size, self.d_model], + ).transpose(1, 2) + param_data[:] = loaded_weight[:, :, shard] + + +class DbrxMoE(nn.Module): + """A tensor-parallel MoE implementation for DBRX. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
+ """ + + def __init__( + self, + config: DbrxConfig, + quant_config: Optional[QuantizationConfig] = None, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.d_model = config.d_model + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.router = DbrxRouter(config, self.params_dtype) + + self.experts = DbrxExperts(config=config, + quant_config=quant_config, + params_dtype=self.params_dtype) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.d_model) + # router_logits: (num_tokens, n_experts) + router_logits = self.router(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class DbrxAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.total_num_kv_heads = config.attn_config.kv_n_heads + self.clip_qkv = config.attn_config.clip_qkv + self.rope_theta = config.attn_config.rope_theta + self.max_position = config.max_seq_len + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + self.tp_size = tp_world_size + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + hidden_states, _ = self.out_proj(attn_output) + return hidden_states + + +class DbrxFusedNormAttention(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.attn = DbrxAttention(config, cache_config, quant_config) + self.norm_1 = nn.LayerNorm(self.d_model) + self.norm_2 = nn.LayerNorm(self.d_model) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + x + residual = hidden_states + hidden_states = self.norm_2(hidden_states) + return hidden_states, residual + + +class DbrxBlock(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, + quant_config) + self.ffn = DbrxMoE(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states, residual = self.norm_attn_norm( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = self.ffn(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class DbrxModel(nn.Module): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: DbrxBlock(config, cache_config, quant_config), + prefix=f"{prefix}.blocks", + ) + self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) + for module in self.modules(): + if hasattr(module, "bias") and isinstance(module.bias, + nn.Parameter): + # Remove the bias term in Linear and LayerNorm. 
+ module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + block = self.blocks[i] + hidden_states = block( + position_ids, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + +class DbrxForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: DbrxConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + if config.tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models.") + self.quant_config = quant_config + self.unpadded_vocab_size = config.vocab_size + self.transformer = DbrxModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.d_model, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + expert_params_mapping = [( + "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight", + f"mlp.{weight_name}", + ) for weight_name in ["w1", "v1", "w2"]] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + for param_name, weight_name in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, weight_name) + break + else: + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py new file mode 100644 index 0000000..7ed2b96 --- /dev/null +++ b/vllm/model_executor/models/decilm.py @@ -0,0 +1,129 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 DeciAI Research Team. All rights reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeciLM model compatible with HuggingFace weights.""" + +from typing import Iterable, Optional, Tuple + +import torch +from transformers import LlamaConfig + +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaForCausalLM + +from .utils import is_pp_missing_parameter + + +class DeciLMForCausalLM(LlamaForCausalLM): + """ + Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct. + Based on the llama executor. + + The main difference is that DeciLM uses Variable Grouped Query Attention. + The constant number of GQA heads in the decoder is overridden with a value + per layer. + + Usually, in the HuggingFace implementation, instead of + "config.num_key_value_heads", we use + "config.num_key_value_heads_per_layer[i]" which varies. + + Currently, PagedAttention does not work well with variable GQA, so we + normalize the weights upon loading, and use uniform GQA with the max value + instead. 
+ """ + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + config.num_key_value_heads = max(config.num_key_value_heads_per_layer) + delattr(config, "num_key_value_heads_per_layer") + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "k_proj" in name or "v_proj" in name: + loaded_weight = self._degroup_weight(loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: + hidden_size = self.config.hidden_size + head_size = self.config.hidden_size // self.config.num_attention_heads + target_num_kv_heads = self.config.num_key_value_heads + num_kv_heads = loaded_weight.shape[0] // head_size + n_repeats = target_num_kv_heads / num_kv_heads + assert n_repeats == int(n_repeats) + + n_repeats = int(n_repeats) + loaded_weight = loaded_weight.view(num_kv_heads, head_size, + hidden_size) + loaded_weight = torch.repeat_interleave(loaded_weight, + repeats=n_repeats, + dim=0) + loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size, + hidden_size) + + return loaded_weight diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py new file mode 100644 index 0000000..5b4db8f --- /dev/null +++ b/vllm/model_executor/models/deepseek.py @@ -0,0 +1,480 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Deepseek model.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + + self.experts = nn.ModuleList([ + DeepseekMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() + + self.gate = ReplicatedLinear(config.hidden_size, + self.n_routed_experts, + bias=False, + quant_config=None) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so 
we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + 
kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekDecoderLayer(config, + int(prefix.split(".")[-1]), + cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = DeepseekModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + 
return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py new file mode 100644 index 0000000..702be7b --- /dev/null +++ b/vllm/model_executor/models/deepseek_v2.py @@ -0,0 +1,617 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
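Editorial aside on DeepseekMoE.pack_params in deepseek.py above: a minimal, self-contained sketch of the weight packing it performs, using three hypothetical 4x8 expert weights. Each per-expert tensor is flattened into one contiguous buffer, re-aliased as a view of that buffer, and the buffer is then viewed as a single [num_experts, ...] tensor that fused_moe can consume. Shapes and names below are made up for illustration.

import torch

experts = [torch.randn(4, 8) for _ in range(3)]               # hypothetical expert weights
flat = torch._utils._flatten_dense_tensors(experts)           # one contiguous 1-D buffer
views = torch._utils._unflatten_dense_tensors(flat, experts)  # per-expert views into the buffer
for view, w in zip(views, experts):
    w.data = view                                             # experts now alias the shared buffer
stacked = flat.view(len(experts), 4, 8)                       # [num_experts, 4, 8] for fused_moe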
+"""Inference-only DeepseekV2 model.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class DeepseekV2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV2MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.experts = FusedMoE(num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV2Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + 
self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + rope_scaling['type'] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # self.attn = Attention(self.num_heads, + # self.qk_head_dim, + # self.scaling, + # num_kv_heads=self.num_heads) + + # TODO, support head_size 192 + self.attn = Attention(self.num_local_heads, + 256, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], + value=0).view(-1, + self.num_local_heads * 256) + k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], + value=0).view(-1, + self.num_local_heads * 256) + v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], + value=0).view(-1, + self.num_local_heads * 256) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV2DecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + 
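        # (Editorial note: DeepseekV2Attention above rewrites
        # rope_scaling["type"] to "deepseek_yarn" and folds the YaRN mscale
        # into its softmax scaling, so the config's rope_scaling block is
        # expected to provide at least a "factor" entry.)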
max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.self_attn = DeepseekV2Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer( + config, + prefix, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + 
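            # (Editorial note: only the first pipeline rank embeds tokens;
            # later ranks resume from the hidden_states/residual tensors
            # passed along in IntermediateTensors.)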
residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV2ForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = DeepseekV2Model(config, + cache_config, + quant_config, + prefix="model") + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. 
+ # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py new file mode 100644 index 0000000..13811d3 --- /dev/null +++ b/vllm/model_executor/models/eagle.py @@ -0,0 +1,170 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.eagle import EAGLEConfig + + +class EAGLE(nn.Module): + """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077 + Reference implementation: https://github.com/SafeAILab/EAGLE + + Differences from reference implementation: + 1. In reference, LlamaDecoderLayer implementation doesn't have + input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427) + but we do as HF implementation also does. + 2. We allow any decoder layer to be used in EAGLE whereas in reference + decoder layer is fixed to be LlamaDecoderLayer. + 3. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). 
Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" + + def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None: + super().__init__() + self.config = config + + architectures = getattr(self.config.model, "architectures", []) + model_cls, _ = ModelRegistry.resolve_model_cls(architectures) + + self.model = model_cls(self.config.model, *args, **kwargs) + self.fc = nn.Linear(config.model.hidden_size * 2, + config.model.hidden_size, + bias=getattr(self.config, "bias", False)) + + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. + self.token_map = None + + @property + def sampler(self): + return self.model.sampler + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + + tok_embeds = self.model.model.embed_tokens(input_ids) + inputs_embeds = self.fc( + torch.cat([tok_embeds, previous_hidden_states], dim=-1)) + + inputs_embeds[positions == 0] = 0 # masking inputs at position=0 + + hidden_states = self.model.model( + input_ids=None, + inputs_embeds=inputs_embeds, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + + if self.token_map is not None: + _logits = logits + logits = -torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype) + + logits[..., self.token_map] = _logits + + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # This implementation is incompitable with https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B + # due to missing lm_head weights and its config being that of a + # Llama model. 
Here's a compatible version with the same weights: + # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm + # Also, here's an example script for converting trained EAGLE + # checkpoint to vLLM compatible version: https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d + model_weights = {} + for name, loaded_weight in weights: + if name == "token_map": + if self.config.truncated_vocab_size < self.config.vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name.startswith("fc.weight"): + weight_loader = getattr(self.fc.weight, "weight_loader", + default_weight_loader) + weight_loader(self.fc.weight, loaded_weight) + elif name.startswith("fc.bias"): + if self.fc.bias is not None: + weight_loader = getattr(self.fc.bias, "weight_loader", + default_weight_loader) + weight_loader(self.fc.bias, loaded_weight) + else: + raise ValueError("Found bias in the loaded weights " + "but the model config doesn't have bias") + elif name.startswith("model.lm_head.") or name.startswith( + "model.model."): + model_weights[name.split("model.", 1)[-1]] = loaded_weight + elif name.startswith("lm_head.") or name.startswith("model."): + model_weights[name] = loaded_weight + else: + model_weights[f"model.{name}"] = loaded_weight + + lm_head_weight = model_weights.pop("lm_head.weight") + + if self.token_map is not None and\ + lm_head_weight.shape[0] > self.token_map.shape[0]: + + lm_head_weight = lm_head_weight[self.token_map] + + weight_loader = getattr(self.lm_head.weight, "weight_loader", + default_weight_loader) + weight_loader(self.lm_head.weight, lm_head_weight) + + self.model.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py new file mode 100644 index 0000000..dfb8fe5 --- /dev/null +++ b/vllm/model_executor/models/exaone.py @@ -0,0 +1,606 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py +# Copyright 2024 The LG U+ CTO AI Tech Lab. +# Copyright 2021 The LG AI Research EXAONE Lab +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.exaone import ExaoneConfig +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class ExaoneGatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneBlockAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ExaoneAttention( + config=config, + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + return self.attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + +class ExaoneDecoderLayer(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and 
getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.attn = ExaoneBlockAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ExaoneModel(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.wte = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.wte = PPMissingLayer() + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.h", + ) + if get_pp_group().is_last_rank: + self.ln_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + else: + self.ln_f = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], 
+ attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "c_fc_0", + "c_fc_1", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "out_proj", + "gate_up_proj", + "c_proj", + "wte", + "lm_head", + ] + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "c_fc_0": ("gate_up_proj", 0), + "c_fc_1": ("gate_up_proj", 1), + } + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.transformer = ExaoneModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, 
+ sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".c_fc_0", 0), + (".gate_up_proj", ".c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.transformer.h[layer_idx], nn.Identity): + layer_self_attn = self.transformer.h[layer_idx].attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py new file mode 100644 index 0000000..467a335 --- /dev/null +++ b/vllm/model_executor/models/falcon.py @@ -0,0 +1,508 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PyTorch Falcon model.""" + +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from transformers import FalconConfig as HF_FalconConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import RWConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +FalconConfig = Union[HF_FalconConfig, RWConfig] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + + return slopes + + +class FalconAttention(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + + if self.new_decoder_architecture: + self.total_num_kv_heads = config.num_kv_heads + elif self.multi_query: + self.total_num_kv_heads = 1 + else: + self.total_num_kv_heads = self.total_num_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config, + reduce_results=self.reduce_row_parallel_results) + + self.use_rotary = config.rotary + self.use_alibi = config.alibi + assert not (self.use_rotary and self.use_alibi), ( + "Rotary and alibi are mutually exclusive.") + + if self.use_rotary: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config) + elif self.use_alibi: + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * + self.inv_norm_factor) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + self.attn = Attention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes, + quant_config=quant_config) + else: + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_rotary: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output, bias = self.dense(attn_output) + return attn_output, bias + + +class FalconMLP(nn.Module): + + def __init__( + self, + config: FalconConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + self.dense_h_to_4h = ColumnParallelLinear(hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + quant_config=quant_config) + self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + bias=config.bias, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results, + quant_config=quant_config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. 
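+        # (Both projections are created with skip_bias_add=True, so each call
+        # returns an (output, bias) pair and the bias is added explicitly
+        # below, mirroring the HuggingFace numerics instead of fusing the
+        # bias add into the GEMM.)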
+ x, bias = self.dense_h_to_4h(x) + if bias is not None: + x += bias + x = self.act(x) + x, bias = self.dense_4h_to_h(x) + return x, bias + + +class FalconDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = FalconAttention(config, cache_config, + quant_config) + self.mlp = FalconMLP(config, quant_config) + self.config = config + + if (config.num_ln_in_parallel_attn is None + and config.new_decoder_architecture): + config.num_ln_in_parallel_attn = 2 + + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + if config.num_ln_in_parallel_attn == 2: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + + if self.config.num_ln_in_parallel_attn == 2: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention( + positions=positions, + hidden_states=attention_layernorm_out, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + if self.reduce_row_parallel_results and attention_bias is not None: + attention_output += attention_bias + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual += attention_output + mlp_layernorm_out = self.post_attention_layernorm(residual) + + if (self.config.new_decoder_architecture and self.config.parallel_attn + and self.config.num_ln_in_parallel_attn == 1): + mlp_layernorm_out = attention_layernorm_out + + # MLP. + mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) + if self.reduce_row_parallel_results and mlp_bias is not None: + mlp_output += mlp_bias + + if not self.reduce_row_parallel_results: + # When MLP and Attention layers are parallel, we can use + # only one all-reduce operator to reduce the results from + # both MLP and Attention layers. 
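+            # A rough sketch of the resulting per-rank data flow (assuming
+            # the parallel-attention path, where both RowParallelLinear
+            # layers above were built with reduce_results=False):
+            #
+            #   partial = mlp_partial + attn_partial   # unreduced per-rank sums
+            #   hidden  = all_reduce(partial) + biases + residual
+            #
+            # so a single collective covers both branches instead of one
+            # all-reduce per projection.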
+ mlp_output += attention_output + mlp_output = tensor_model_parallel_all_reduce(mlp_output) + if attention_bias is not None: + mlp_output += attention_bias + if mlp_bias is not None: + mlp_output += mlp_bias + + output = mlp_output + residual + return output + + +class FalconModel(nn.Module): + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + + # Transformer blocks + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: FalconDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.h") + + # Final Layer Norm + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.word_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class FalconForCausalLM(nn.Module, SupportsPP): + + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = {} + default_bitsandbytes_target_modules = [ + ".query_key_value.", + ".dense.", + ".dense_h_to_4h.", + ".dense_4h_to_h.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".dense_4h_to_h.", ".dense."] + + def __init__( + self, + config: FalconConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.transformer = FalconModel(config, cache_config, quant_config) + # only Falcon-11B doesn't share lm_head weight with word embeddings + # and previous Falcon model doesn't have tie_word_embeddings config + # so we set tie_word_embeddings to True by default + self.tie_word_embeddings = (config.tie_word_embeddings + if config.tie_word_embeddings is not None + else True) + if self.tie_word_embeddings: + self.lm_head = self.transformer.word_embeddings + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> 
torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + total_num_heads = self.config.num_attention_heads + if self.config.new_decoder_architecture: + total_num_kv_heads = self.config.num_kv_heads + elif self.config.multi_query: + total_num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if name == "lm_head.weight" and self.tie_word_embeddings: + # Falcon uses tied embeddings except Falcon-11b. + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + if "query_key_value" in name: + output_dim = getattr(param, "output_dim", None) + loaded_weight_shape = loaded_weight.shape + if output_dim is not None: + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, + -1) + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, + num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py new file mode 100644 index 0000000..62a1b1f --- /dev/null +++ b/vllm/model_executor/models/fuyu.py @@ -0,0 +1,351 @@ +# coding=utf-8 +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/fuyu/modeling_fuyu.py +# Copyright 2023 The vLLM team. +# Copyright 2023 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Fuyu model.""" +import math +from array import array +from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from PIL import Image +from transformers import FuyuConfig, FuyuImageProcessor + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.models.persimmon import PersimmonForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) + +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import AutoWeightsLoader, flatten_bn, merge_multimodal_embeddings + +# Cannot find the following 2 numbers from hf config. +_IMAGE_TOKEN_ID = 71011 +_NEWLINE_TOKEN_ID = 71019 + +MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080 +MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920 + + +class FuyuImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: + (batch_size, num_patches, patch_size_x * patch_size_y * num_channels) + """ + + +def _calculate_num_image_tokens( + height: int, + width: int, +) -> Tuple[int, int]: + """ + calculate number of image tokens needed for a given image size + The expected Fuyu image prompts is in format: + (image_token * ncols + newline_token) * nrows + args: + image_size: Tuple[int, int] - (width, height) of the image + returns: + ncols: int - number of image tokens in x direction + nrows: int - number of image tokens in y direction + """ + ncol = math.ceil(width / 30) + nrow = math.ceil(height / 30) + return ncol, nrow + + +def get_max_fuyu_image_feature_size(): + + return _calculate_num_image_tokens( + height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + ) + + +def get_max_fuyu_image_tokens(ctx: InputContext): + ncol, nrow = get_max_fuyu_image_feature_size() + return (ncol + 1) * nrow + + +def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int): + ncol, nrow = get_max_fuyu_image_feature_size() + image_feature_size = get_max_fuyu_image_tokens(ctx) + + image_token_ids = ( + array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol + + array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - image_feature_size * num_images) + return SequenceData(token_ids) + + +def dummy_image_for_fuyu( + num_images: int, + *, + image_width: int, + image_height: int, +): + image = Image.new("RGB", (image_width, image_height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_data_for_fuyu(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + seq_data = dummy_seq_data_for_fuyu(ctx, seq_len, num_images) + mm_data = dummy_image_for_fuyu(num_images, + image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + 
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT) + return seq_data, mm_data + + +def _fuyu_image_preprocess(image_processor: FuyuImageProcessor, + data: Image.Image): + image_encoding = image_processor.preprocess(data, return_tensors="pt") + batch_images = torch.stack([img[0] for img in image_encoding["images"] + ]).unsqueeze(1) + image_unpadded_heights = torch.tensor( + image_encoding["image_unpadded_heights"]) + image_unpadded_widths = torch.tensor( + image_encoding["image_unpadded_widths"]) + + batch_size = len(image_encoding["images"]) + image_present = torch.ones(batch_size, 1, 1) + model_image_input = image_processor.preprocess_with_tokenizer_info( + image_input=batch_images, + image_present=image_present, + image_unpadded_h=image_unpadded_heights, + image_unpadded_w=image_unpadded_widths, + image_placeholder_id=_IMAGE_TOKEN_ID, + image_newline_id=_NEWLINE_TOKEN_ID, + variable_sized=True, + ) + return model_image_input + + +def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + image_data = multi_modal_data["image"] + new_multi_modal_data = {} + # process image data + if isinstance(image_data, Image.Image): + # Fuyu's image_processor can also finish token padding + image_processor: FuyuImageProcessor = cached_get_image_processor( + model_config.model) + + model_image_input = _fuyu_image_preprocess(image_processor, image_data) + image_patches = torch.cat([ + image_patch[0] + for image_patch in model_image_input["image_patches"] + ]) + new_multi_modal_data["image"] = image_patches + + elif isinstance(image_data, torch.Tensor): + raise NotImplementedError("Embeddings input is not supported yet") + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + # process prompts + prompt = llm_inputs.get("prompt") + prompt_token_ids = llm_inputs["prompt_token_ids"] + tokenizer = cached_get_tokenizer(model_config.model) + # dim0 is batch_size, dim1 is subseq_size which will always be 1 + image_input_ids: List[List[ + torch.Tensor]] = model_image_input["image_input_ids"] + image_input_ids = image_input_ids[0][0].tolist() + bos_token = tokenizer.encode("", add_special_tokens=False)[1:] + boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:] + + new_prompt = prompt + "\x04" + new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[ + 1:] + boa_token + + return LLMInputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=new_multi_modal_data) + + +def input_mapper_for_fuyu(ctx: InputContext, data: object): + model_config = ctx.model_config + if isinstance(data, Image.Image): + # Fuyu's image_processor can also finish token padding + image_processor: FuyuImageProcessor = cached_get_image_processor( + model_config.model) + + model_image_input = _fuyu_image_preprocess(image_processor, data) + data = torch.stack([ + image_patch[0] + for image_patch in model_image_input["image_patches"] + ]) + + # image has been processed with prompt in input processor + return MultiModalInputs({"pixel_values": data}) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) +@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) +class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + def 
__init__(self, + config: FuyuConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.config = config + self.multimodal_config = multimodal_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.text_config.vocab_size + self.image_token_id = _IMAGE_TOKEN_ID + self.image_feature_size = config.patch_size**2 * config.num_channels + + self.vision_embed_tokens = ColumnParallelLinear( + self.image_feature_size, + config.hidden_size, + quant_config=quant_config, + gather_output=True, + ) + self.language_model = PersimmonForCausalLM(config.text_config, + cache_config=cache_config, + quant_config=quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.patch_size + num_channels = self.config.num_channels + expected_dims = num_channels * h * w + + def _validate_shape(d: torch.Tensor): + actual_dims = d.size(-1) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data.to(self.vision_embed_tokens.weight.dtype) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[FuyuImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image patches. 
" + f"Got type: {type(pixel_values)}") + + return FuyuImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + return None + + def _process_image_input( + self, image_input: FuyuImagePixelInputs) -> torch.Tensor: + + assert self.vision_embed_tokens is not None + vision_embeddings, _ = self.vision_embed_tokens(image_input["data"]) + return vision_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ): + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.embed_tokens( + input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) + + else: + inputs_embeds = None + + hidden_states = self.language_model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.language_model.logits_processor( + self.language_model.lm_head, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.language_model.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py new file mode 100644 index 0000000..91e556d --- /dev/null +++ b/vllm/model_executor/models/gemma.py @@ -0,0 +1,455 @@ +# coding=utf-8 +# Copyright 2023 The vLLM team. +# Copyright (c) Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Gemma model compatible with HuggingFace weights.""" +from functools import lru_cache +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import GemmaConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +logger = init_logger(__name__) + + +@lru_cache(maxsize=None) +def _get_gemma_act_fn( + hidden_act: Optional[str], + hidden_activation: Optional[str], +) -> nn.Module: + if hidden_activation is None: + if hidden_act is not None: + logger.warning( + "Gemma's activation function was incorrectly set to exact GeLU " + "in the config JSON file when it was initially released. " + "Changing the activation function to approximate GeLU " + "(`gelu_pytorch_tanh`). If you want to use the legacy " + "`%s`, edit the config JSON to set " + "`hidden_activation=%s` instead of `hidden_act`. 
" + "See https://github.com/huggingface/transformers/pull/29402 " + "for more details.", hidden_act, hidden_act) + return GeluAndMul(approximate="tanh") + elif hidden_activation == "gelu_pytorch_tanh": + return GeluAndMul(approximate="tanh") + elif hidden_activation == "gelu": + return GeluAndMul(approximate="none") + else: + raise ValueError(f"Activation function {hidden_act} is not " + "supported for Gemma models.") + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: Optional[str] = None, + hidden_activation: Optional[str] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GemmaAttention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int = 8192, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GemmaDecoderLayer(nn.Module): + + def __init__( + self, + config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + ) + self.mlp = GemmaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=getattr(config, "hidden_activation", None), + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class GemmaModel(nn.Module): + + def __init__( + self, + config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = 
make_layers( + config.num_hidden_layers, + lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + # Gemma does not apply LoRA to the embedding layer. 
+ embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: GemmaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = GemmaModel(config, cache_config, quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # lm_head is not used in vllm as it is tied with embed_token. + # To prevent errors, skip loading lm_head.weight. + if "lm_head.weight" in name: + continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py new file mode 100644 index 0000000..bcb03ef --- /dev/null +++ b/vllm/model_executor/models/gemma2.py @@ -0,0 +1,463 @@ +# coding=utf-8 +# Copyright 2024 The vLLM team. +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. 
+# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Gemma2Config + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.layernorm import GemmaRMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +logger = init_logger(__name__) + + +class Gemma2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + hidden_activation: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"): + raise ValueError( + "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation " + "function. 
Please set `hidden_act` and `hidden_activation` to " + "`gelu_pytorch_tanh`.") + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Gemma2Attention(nn.Module): + + def __init__(self, + layer_idx: int, + config: Gemma2Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int, + rope_theta: float, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.query_pre_attn_scalar**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + + # FIXME(woosuk): While Gemma 2 uses sliding window attention for every + # odd layer, vLLM currently ignores it and uses global attention for + # all layers. + use_sliding_window = (layer_idx % 2 == 1 + and config.sliding_window is not None) + del use_sliding_window # Unused. 
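+        # Illustrative sketch only (hypothetical API, not assumed to exist in
+        # this build): if the attention layer accepted a per-layer window, the
+        # interleaved pattern computed above could be forwarded instead of
+        # being dropped, e.g.
+        #     window = (config.sliding_window
+        #               if layer_idx % 2 == 1 else None)
+        #     self.attn = Attention(..., sliding_window=window)
+        # Until then, every layer here uses global attention.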
+ self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Gemma2DecoderLayer(nn.Module): + + def __init__( + self, + layer_idx: int, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Gemma2Attention( + layer_idx=layer_idx, + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + cache_config=cache_config, + quant_config=quant_config, + attn_logits_soft_cap=config.attn_logit_softcapping, + ) + self.hidden_size = config.hidden_size + self.mlp = Gemma2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + hidden_activation=config.hidden_activation, + quant_config=quant_config, + ) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = self.post_attention_layernorm(hidden_states) + + hidden_states, residual = self.pre_feedforward_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_feedforward_layernorm(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": 0, + "inputs_embeds": 0, + "intermediate_tensors": 0, + }) +class Gemma2Model(nn.Module): + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ + -1]), config, cache_config, quant_config), + prefix=f"{prefix}.layers") + self.norm = 
GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # Normalize the embedding by sqrt(hidden_size) + # The normalizer's data type should be downcasted to the model's + # data type such as bfloat16, not float32. + # See https://github.com/huggingface/transformers/pull/29402 + normalizer = self.config.hidden_size**0.5 + self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + hidden_states *= self.normalizer + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + logger.warning( + "Some weights are not initialized from checkpoints: %s", + unloaded_params) + + +class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer. 
+ embedding_modules = {} + embedding_padding_modules = [] + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: Gemma2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + del lora_config # Unused. + super().__init__() + self.config = config + # currently all existing Gemma models have `tie_word_embeddings` enabled + assert config.tie_word_embeddings + self.quant_config = quant_config + self.model = Gemma2Model(config, cache_config, quant_config) + self.logits_processor = LogitsProcessor( + config.vocab_size, soft_cap=config.final_logit_softcapping) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.model.embed_tokens, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/gemma2_embedding.py b/vllm/model_executor/models/gemma2_embedding.py new file mode 100644 index 0000000..e8e1059 --- /dev/null +++ b/vllm/model_executor/models/gemma2_embedding.py @@ -0,0 +1,57 @@ +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .gemma2 import Gemma2Model +from .interfaces import SupportsPP + + +class Gemma2EmbeddingModel(nn.Module, SupportsPP): + """A model that uses Gemma2 with additional embedding functionalities. + + This class encapsulates the Gemma2Model and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of Gemma2Model used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = Gemma2Model(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.model(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/glm4.py b/vllm/model_executor/models/glm4.py new file mode 100644 index 0000000..e74fd85 --- /dev/null +++ b/vllm/model_executor/models/glm4.py @@ -0,0 +1,362 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2025 The Zhipu AI team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" +from typing import Iterable, Optional, Set, Tuple, Union, List + +import torch +from torch import nn +from transformers import Glm4Config + +from vllm.attention import Attention, AttentionType, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig , LoRAConfig #, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, Sampler #get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .llama import LlamaMLP as Glm4MLP +from .llama import LlamaModel +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + + +class Glm4Attention(nn.Module): + + def __init__(self, + config: Glm4Config, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.rotary_dim = self.head_dim #int(partial_rotary_factor * self.head_dim) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=partial_rotary_factor, + is_neox_style=False, + ) + # self.attn = Attention(self.num_heads, + # self.head_dim, + # self.scaling, + # num_kv_heads=self.num_kv_heads, + # cache_config=cache_config, + # quant_config=quant_config, + # prefix=f"{prefix}.attn", + # attn_type=attn_type) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Glm4DecoderLayer(nn.Module): + + def __init__( + self, + config: Glm4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + self.self_attn = Glm4Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=AttentionType.DECODER, + ) + self.mlp = Glm4MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_mlp_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: 
Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata + ) + + hidden_states = self.post_self_attn_layernorm(hidden_states) + # hidden_states = residual + hidden_states + + # Fully Connected + # hidden_states = self.post_attention_layernorm(hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_mlp_layernorm(hidden_states) + # hidden_states = residual + hidden_states + + # Fully Connected + # residual = hidden_states + # hidden_states = self.post_attention_layernorm(hidden_states) + # hidden_states = self.mlp(hidden_states) + # hidden_states = self.post_mlp_layernorm(hidden_states) + # hidden_states = residual + hidden_states + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Glm4DecoderLayer, +} + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Glm4Model(LlamaModel): + + # def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # super().__init__(vllm_config=vllm_config, + # prefix=prefix, + # layer_type=Glm4DecoderLayer) + + def __init__(self, *, + config: Glm4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = ""): + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config, + prefix=prefix, + layer_type=Glm4DecoderLayer) + + +class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, + config: Glm4Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = ""): + super().__init__() + # config = vllm_config.model_config.hf_config + # quant_config = vllm_config.quant_config + # lora_config = vllm_config.lora_config + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Glm4Model(config=config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + 
intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py new file mode 100644 index 0000000..2ce8939 --- /dev/null +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/GLM-4 +"""Inference-only GLM-4v model visual encoder compatible with THUDM weights.""" +from argparse import Namespace +from typing import Optional + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class PatchEmbedding(nn.Module): + + def __init__(self, config): + super().__init__() + self.proj = nn.Conv2d(config.in_channels, + config.hidden_size, + kernel_size=config.patch_size, + stride=config.patch_size) + self.cls_embedding = nn.Parameter(torch.zeros(1, config.hidden_size)) + self.position_embedding = nn.Embedding(config.num_positions, + config.hidden_size) + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + images = images.to(self.proj.weight.device) + x = self.proj(images) + x = x.flatten(2).transpose(1, 2) + cls_token = self.cls_embedding.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.position_embedding.weight.unsqueeze(0) + return x + + +class Attention(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_rank = config.num_heads // self.tp_size + self.head_dim = config.hidden_size // config.num_heads + self.scale = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_dim, + config.num_heads, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.output_dropout = torch.nn.Dropout(config.dropout_prob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, 
L, _ = x.shape + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + q = q.reshape(B, L, self.num_heads_per_rank, + self.head_dim).permute(0, 2, 1, 3) # B, H, L, D + k = k.reshape(B, L, self.num_heads_per_rank, + self.head_dim).permute(0, 2, 1, 3) # B, H, L, D + v = v.reshape(B, L, self.num_heads_per_rank, + self.head_dim).permute(0, 2, 1, 3) # B, H, L, D + + out = torch.nn.functional.scaled_dot_product_attention(q, + k, + v, + attn_mask=None, + dropout_p=0., + is_causal=False) + + output, _ = self.dense(out.transpose(1, 2).contiguous().view(B, L, -1)) + output = self.output_dropout(output) + return output + + +class MLP(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.fc1(x) + x = self.activation_fn(x) + x, _ = self.fc2(x) + return x + + +class TransformerLayer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.input_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = Attention(config, quant_config=quant_config) + self.mlp = MLP(config, quant_config=quant_config) + self.post_attention_layernorm = LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states): + attention_input = hidden_states + attention_output = self.input_layernorm( + self.attention(attention_input)) + hidden_states = attention_input + attention_output + mlp_input = hidden_states + mlp_output = self.post_attention_layernorm(self.mlp(mlp_input)) + output = mlp_input + mlp_output + return output + + +class Transformer(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.layers = nn.ModuleList([ + TransformerLayer(config, quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward(self, hidden_states): + for layer_module in self.layers: + hidden_states = layer_module(hidden_states) + return hidden_states + + +class GLU(nn.Module): + + def __init__( + self, + config, + in_features, + quant_config: Optional[QuantizationConfig] = None, + ): + """ + The original implementation is the same as: + ```python + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + + self.gate_proj = ColumnParallelLinear( + config.hidden_size, + config.ffn_hidden_size, + bias=False, + quant_config=quant_config + ) + ``` + ``` + gate_proj_output, _ = self.gate_proj(x) + dense_h_to_4h_output, _ = self.dense_h_to_4h(x) + x = torch.cat([gate_proj_output, dense_h_to_4h_output], dim=-1) + ``` + + We merge two ColumnParallelLinear into one MergedColumnParallelLinear: + ``` + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config + ) + ``` + ``` + x, _ = self.merged_proj(x) + ``` + """ + super().__init__() + self.linear_proj = ReplicatedLinear(in_features, + config.hidden_size, + bias=False, + quant_config=quant_config) + self.norm1 = 
nn.LayerNorm(config.hidden_size) + self.act1 = nn.GELU() + self.act2 = SiluAndMul() + + self.merged_proj = MergedColumnParallelLinear( + config.hidden_size, [config.ffn_hidden_size] * 2, + bias=False, + quant_config=quant_config) + + self.dense_4h_to_h = RowParallelLinear(config.ffn_hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config) + + def forward(self, x): + x, _ = self.linear_proj(x) + x = self.act1(self.norm1(x)) + x, _ = self.merged_proj(x) + x = self.act2(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class EVA2CLIPModel(nn.Module): + + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + vision_config = Namespace(**config.vision_config) + self.patch_embedding = PatchEmbedding(vision_config) + self.transformer = Transformer(vision_config, + quant_config=quant_config) + self.linear_proj = GLU(config, + in_features=config.hidden_size, + quant_config=quant_config) + self.conv = nn.Conv2d(in_channels=vision_config.hidden_size, + out_channels=config.hidden_size, + kernel_size=2, + stride=2) + self.boi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.eoi = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) + self.scaling_factor = vision_config.scaling_factor + + def forward(self, images: torch.Tensor) -> torch.Tensor: + """ + Parameters: + images : torch.Tensor + Input image tensor with shape (B, C, H, W) + + Returns: + torch.Tensor + Transformed tensor with shape (B, L, D) + """ + x = self.patch_embedding(images) + x = self.transformer(x) + x = x[:, 1:] + + b, s, h = x.shape + grid_size = int(s**0.5) + x = x.view(b, grid_size, grid_size, h).permute(0, 3, 1, 2) + x = self.conv(x) + + x = x.flatten(2).transpose(1, 2) + x = self.linear_proj(x) + boi = self.boi.expand(x.shape[0], -1, -1) + eoi = self.eoi.expand(x.shape[0], -1, -1) + x = torch.cat((boi, x, eoi), dim=1) + x = x / self.scaling_factor + return x diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py new file mode 100644 index 0000000..9755023 --- /dev/null +++ b/vllm/model_executor/models/gpt2.py @@ -0,0 +1,324 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
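+# Minimal usage sketch (illustration only, not part of the model definition;
+# it assumes a standard vLLM installation in which this module backs the
+# GPT2LMHeadModel architecture):
+#
+#     from vllm import LLM, SamplingParams
+#     llm = LLM(model="gpt2")
+#     out = llm.generate(["Hello, my name is"],
+#                        SamplingParams(max_tokens=16))
+#     print(out[0].outputs[0].text)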
+"""Inference-only GPT-2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import GPT2Config + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed.parallel_state import ( + get_pp_group, get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class GPT2Attention(nn.Module): + + def __init__( + self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_attn", + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPT2MLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPT2Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_fc", + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + + def __init__( + 
self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, + config, + quant_config, + prefix=f"{prefix}.mlp") + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPT2Model(nn.Module): + + def __init__( + self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPT2Block( + config, cache_config, quant_config, prefix=prefix), + prefix=f"{prefix}.h") + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPT2LMHeadModel(nn.Module, SupportsPP): + + def __init__( + self, + config: GPT2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.transformer = GPT2Model(config, + cache_config, + quant_config, + prefix="transformer") + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = 
ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py new file mode 100644 index 0000000..6c4a046 --- /dev/null +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -0,0 +1,341 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 CTranslate2, and Michael Feil +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
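+# Note (summarizing GPTBigCodeAttention below): when `config.multi_query` is
+# set, a single key/value head is shared by every query head (MQA), so the
+# fused c_attn projection emits roughly hidden_size + 2 * head_dim features
+# per token (before tensor-parallel partitioning) instead of 3 * hidden_size,
+# and the forward pass splits them as
+#     q, k, v = qkv.split([hidden_size // tp_size, kv_dim, kv_dim], dim=-1)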
+"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import GPTBigCodeConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class GPTBigCodeAttention(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + self.tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % self.tensor_model_parallel_world_size == 0 + self.num_heads = (total_num_heads // + self.tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.multi_query = config.multi_query + if self.multi_query: + total_num_kv_heads = 1 + self.num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + self.num_kv_heads = self.num_heads + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ + self.hidden_size // self.tensor_model_parallel_world_size, + self.kv_dim, self.kv_dim + ], + dim=-1, + ) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPTBigMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTBigCodeConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + self.act = get_act_fn(config.activation_function, quant_config, + 
intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(config, cache_config, quant_config) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPTBigCodeModel(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + + self.embed_dim = config.hidden_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.wte = VocabParallelEmbedding(self.vocab_size, + self.embed_dim, + org_num_embeddings=config.vocab_size) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h", + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = {"c_attn": ["c_attn"]} + + supported_lora_modules = 
["c_fc", "c_proj", "wte", "c_attn"] + + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + + embedding_padding_modules = [] + + def __init__( + self, + config: GPTBigCodeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.transformer = GPTBigCodeModel(config, cache_config, quant_config, + lora_config) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead( + self.transformer.vocab_size, + self.transformer.embed_dim, + org_num_embeddings=self.config.vocab_size) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + # TODO (@robertgshaw2-neuralmagic): move to fp8 linear method + if "c_attn.input_scale" in name or "c_attn.weight_scale" in name: + weight_loader(param, loaded_weight, 'q') + weight_loader(param, loaded_weight, 'k') + weight_loader(param, loaded_weight, 'v') + else: + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py new file mode 100644 index 0000000..d40bf8c --- /dev/null +++ b/vllm/model_executor/models/gpt_j.py @@ -0,0 +1,317 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py +# Copyright 2023 The vLLM team. +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-J model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import GPTJConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class GPTJAttention(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=False, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + scaling = self.head_size**-0.5 + assert getattr(config, "rotary", True) + assert config.rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=config.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + is_neox_style=False, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class GPTJMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + 
config: GPTJConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.n_embd + self.fc_in = ColumnParallelLinear( + hidden_size, + intermediate_size, + quant_config=quant_config, + ) + self.fc_out = RowParallelLinear( + intermediate_size, + hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc_out(hidden_states) + return hidden_states + + +class GPTJBlock(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + inner_dim = (4 * config.n_embd + if config.n_inner is None else config.n_inner) + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJAttention(config, cache_config, quant_config) + self.mlp = GPTJMLP(inner_dim, config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + mlp_output = self.mlp(hidden_states) + hidden_states = attn_output + mlp_output + residual + return hidden_states + + +class GPTJModel(nn.Module): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_dim = config.n_embd + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.start_layer, self.end_layer, self.h = make_layers( + config.n_layer, + lambda prefix: GPTJBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h", + ) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTJForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: GPTJConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + assert not config.tie_word_embeddings + self.transformer = GPTJModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + 
config.n_embd, + bias=True, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "attn.bias" in name or "attn.masked_bias" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py new file mode 100644 index 0000000..23a1ca0 --- /dev/null +++ b/vllm/model_executor/models/gpt_neox.py @@ -0,0 +1,329 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import GPTNeoXConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class GPTNeoXAttention(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + self.bias = getattr(config, "attention_bias", True) + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=self.bias, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=self.bias, + quant_config=quant_config, + ) + scaling = self.head_size**-0.5 + rotary_dim = int(self.head_size * config.rotary_pct) + assert rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.dense(attn_output) + return output + + +class GPTNeoXMLP(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config, + ) + self.dense_4h_to_h = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + 
quant_config=quant_config, + ) + self.act = get_act_fn(config.hidden_act, quant_config, + config.intermediate_size) + + def forward(self, hidden_states): + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config, cache_config, quant_config) + self.mlp = GPTNeoXMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + attn_input = self.input_layernorm(hidden_states) + attn_output = self.attention( + position_ids=position_ids, + hidden_states=attn_input, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_input = self.post_attention_layernorm(attn_output) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + return hidden_states + + +class GPTNeoXModel(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + + self.embed_in = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_in(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class GPTNeoXForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: GPTNeoXConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = 
config + self.quant_config = quant_config + self.gpt_neox = GPTNeoXModel(config, cache_config, quant_config) + self.embed_out = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.embed_out.weight = self.gpt_neox.embed_in.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.gpt_neox.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.gpt_neox(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.embed_out, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if ("attention.bias" in name or "attention.masked_bias" in name + or "rotary_emb.inv_freq" in name): + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using OpenRLHF may include + # these tensors in the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py new file mode 100644 index 0000000..dcf4f5b --- /dev/null +++ b/vllm/model_executor/models/granite.py @@ -0,0 +1,545 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import GraniteConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +class GraniteMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + 
config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +class GraniteModel(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, 
IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + + hidden_states *= self.config.embedding_multiplier + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = GraniteModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + + if hasattr(config, "logits_scaling"): + logit_scale /= config.logits_scaling + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> 
Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py new file mode 100644 index 0000000..5266951 --- /dev/null +++ b/vllm/model_executor/models/granitemoe.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GraniteMoe model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.models.granitemoe import GraniteMoeConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . 
import mixtral +from .interfaces import SupportsLoRA, SupportsPP +from .utils import make_layers + + +class GraniteMoeMoE(nn.Module): + """A tensor-parallel MoE implementation for GraniteMoe that shards each + expert across all ranks. + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class GraniteMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attention_multiplier: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
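+ # In the replicated case each TP rank keeps one full copy of a KV head, so num_kv_heads is clamped to 1 per rank below.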
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = (attention_multiplier if attention_multiplier + is not None else self.head_dim**-1) + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +class GraniteMoeModel(nn.Module): + + def __init__( + self, + 
config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteMoeDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = GraniteMoeModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + 
attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + ".block_sparse_moe.experts.%d.w1.weight" % e) + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + ".block_sparse_moe.experts.%d.w3.weight" % e) + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + ".block_sparse_moe.experts.%d.w2.weight" % e) + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + elif n == 'lm_head.weight' and self.config.tie_word_embeddings: + pass + else: + new_weights[n] = p + mixtral.MixtralForCausalLM.load_weights(self, new_weights.items()) diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py new file mode 100644 index 0000000..3b0b6fe --- /dev/null +++ b/vllm/model_executor/models/idefics2_vision_model.py @@ -0,0 +1,302 @@ +# coding=utf-8 + +# adapted from https://github.com/huggingface/transformers/blob/v4.43.2/src/transformers/models/idefics2/modeling_idefics2.py +# Copyright 2024 The vLLM team. +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
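# --- Editor's note (illustrative sketch, not part of this patch) ----------
# The GraniteMoe load_weights above remaps the checkpoint's fused expert
# tensors into the per-expert parameter names that the Mixtral/FusedMoE
# loader expects: `input_linear` stacks w1 and w3 along dim 0 for every
# expert, so each expert slice is split in half, while `output_linear`
# maps one-to-one onto w2. A toy version of that reshaping, with made-up
# sizes and assuming the [num_experts, 2 * intermediate, hidden] layout
# implied by the chunk(2, dim=0) call, looks like this:
import torch

num_experts, hidden, inter = 4, 8, 16
input_linear = torch.randn(num_experts, 2 * inter, hidden)
output_linear = torch.randn(num_experts, hidden, inter)

remapped = {}
for e in range(num_experts):
    w1, w3 = input_linear[e].chunk(2, dim=0)   # each: [inter, hidden]
    remapped[f"experts.{e}.w1.weight"] = w1
    remapped[f"experts.{e}.w2.weight"] = output_linear[e]
    remapped[f"experts.{e}.w3.weight"] = w3

assert remapped["experts.0.w1.weight"].shape == (inter, hidden)
# ---------------------------------------------------------------------------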
+"""PyTorch Idefics2 model.""" + +from typing import Optional + +import torch +from torch import nn +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2Config, Idefics2VisionConfig) +from xformers import ops as xops + +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model(which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, + self.embed_dim) + + def forward(self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, + 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, + max_nb_patches_h * max_nb_patches_w), + fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize(fractional_coords_h, + boundaries, + right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, + boundaries, + right=True) + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = 
None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" # noqa: E501 + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + quant_config=quant_config, + ) + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + self.is_causal = False + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + batch_size, q_len, _ = hidden_states.size() + qkv, _ = self.qkv_proj( + hidden_states + ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim + query_states, key_states, value_states = qkv.chunk(3, dim=-1) + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + # see: https://facebookresearch.github.io/xformers/components/ops.html + out = xops.memory_efficient_attention_forward( + query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale, + ) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) + return attn_output + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: Idefics2Config, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__(self, config: Idefics2Config): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = Idefics2VisionAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. 
+ + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. Each layer is a + [`Idefics2EncoderLayer`]. + + Args: + config: Idefics2Config + """ + + def __init__(self, config: Idefics2Config): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + Idefics2EncoderLayer(config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectorsthan the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer(hidden_states) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionTransformer(nn.Module): + + def __init__(self, config: Idefics2VisionConfig): + super().__init__() + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes) + encoder_outputs = self.encoder(hidden_states) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py new file mode 100644 index 0000000..dcead65 --- /dev/null +++ b/vllm/model_executor/models/interfaces.py @@ -0,0 +1,352 @@ +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, + Protocol, Type, Union, overload, runtime_checkable) + +import torch +from typing_extensions import TypeIs + +from vllm.logger import init_logger +from vllm.utils import supports_kw + +if TYPE_CHECKING: + from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.sequence import IntermediateTensors + +logger = init_logger(__name__) + + +@runtime_checkable +class SupportsMultiModal(Protocol): + """The interface required for all multi-modal models.""" + + supports_multimodal: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports multi-modal inputs. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: + ... 
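# --- Editor's note (illustrative sketch, not part of this patch) ----------
# The runtime_checkable protocols in this file are consumed structurally: a
# model opts in simply by exposing the flag (e.g. by inheriting the support
# class), and callers narrow the type with a plain isinstance() check, even
# against the class object itself. The toy names below are made up for the
# example.
from typing import ClassVar, Literal, Protocol, runtime_checkable


@runtime_checkable
class _SupportsMultiModalLike(Protocol):
    supports_multimodal: Literal[True]


class ToyVisionLanguageModel:
    # Setting the class-level flag is enough for the structural check.
    supports_multimodal: ClassVar[Literal[True]] = True


class ToyTextOnlyModel:
    pass


# isinstance() on a runtime_checkable Protocol only looks for the declared
# members, so checking the *class* object works here as well.
assert isinstance(ToyVisionLanguageModel, _SupportsMultiModalLike)
assert not isinstance(ToyTextOnlyModel, _SupportsMultiModalLike)
# ---------------------------------------------------------------------------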
+ + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsMultiModalType(Protocol): + supports_multimodal: Literal[True] + + def __call__(self, *, multimodal_config: "MultiModalConfig") -> None: + ... + + +@overload +def supports_multimodal( + model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]: + ... + + +@overload +def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: + ... + + +def supports_multimodal( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]: + if isinstance(model, type): + return isinstance(model, _SupportsMultiModalType) + + return isinstance(model, SupportsMultiModal) + + +@runtime_checkable +class SupportsLoRA(Protocol): + """The interface required for all models that support LoRA.""" + + supports_lora: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports LoRA. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + packed_modules_mapping: ClassVar[Dict[str, List[str]]] + supported_lora_modules: ClassVar[List[str]] + embedding_modules: ClassVar[Dict[str, str]] + embedding_padding_modules: ClassVar[List[str]] + + # lora_config is None when LoRA is not enabled + def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsLoRAType(Protocol): + supports_lora: Literal[True] + + packed_modules_mapping: Dict[str, List[str]] + supported_lora_modules: List[str] + embedding_modules: Dict[str, str] + embedding_padding_modules: List[str] + + def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: + ... + + +@overload +def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]: + ... + + +@overload +def supports_lora(model: object) -> TypeIs[SupportsLoRA]: + ... + + +def supports_lora( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]: + result = _supports_lora(model) + + if not result: + lora_attrs = ( + "packed_modules_mapping", + "supported_lora_modules", + "embedding_modules", + "embedding_padding_modules", + ) + missing_attrs = tuple(attr for attr in lora_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_lora", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_lora=True`, " + "but is missing LoRA-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all LoRA-specific attributes, " + "but does not set `supports_lora=True`.", model) + + return result + + +def _supports_lora(model: Union[Type[object], object]) -> bool: + if isinstance(model, type): + return isinstance(model, _SupportsLoRAType) + + return isinstance(model, SupportsLoRA) + + +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. 
+ """ + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + + def forward( + self, + *, + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + """ + Accept :class:`IntermediateTensors` when PP rank > 0. + + Return :class:`IntermediateTensors` only for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... + + def forward( + self, + *, + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + ... + + +@overload +def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... + + +def supports_pp( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect + + +def _supports_pp_attributes(model: Union[Type[object], object]) -> bool: + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + +def _supports_pp_inspect(model: Union[Type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + return supports_kw(model_forward, "intermediate_tensors") + + +@runtime_checkable +class HasInnerState(Protocol): + """The interface required for all models that has inner state.""" + + has_inner_state: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has inner state. + Models that has inner state usually need access to the scheduler_config + for max_num_seqs, etc. True for e.g. both Mamba and Jamba. + """ + + def __init__(self, + *, + scheduler_config: Optional["SchedulerConfig"] = None) -> None: + ... + + +@runtime_checkable +class _HasInnerStateType(Protocol): + has_inner_state: ClassVar[Literal[True]] + + def __init__(self, + *, + scheduler_config: Optional["SchedulerConfig"] = None) -> None: + ... + + +@overload +def has_inner_state(model: object) -> TypeIs[HasInnerState]: + ... + + +@overload +def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]: + ... 
+ + +def has_inner_state( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]: + if isinstance(model, type): + return isinstance(model, _HasInnerStateType) + + return isinstance(model, HasInnerState) + + +@runtime_checkable +class IsAttentionFree(Protocol): + """The interface required for all models like Mamba that lack attention, + but do have state whose size is constant wrt the number of tokens.""" + + is_attention_free: ClassVar[Literal[True]] = True + """ + A flag that indicates this model has no attention. + Used for block manager and attention backend selection. + True for Mamba but not Jamba. + """ + + def __init__(self) -> None: + ... + + +@runtime_checkable +class _IsAttentionFreeType(Protocol): + is_attention_free: ClassVar[Literal[True]] + + def __init__(self) -> None: + ... + + +@overload +def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: + ... + + +@overload +def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]: + ... + + +def is_attention_free( + model: Union[Type[object], object] +) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]: + if isinstance(model, type): + return isinstance(model, _IsAttentionFreeType) + + return isinstance(model, IsAttentionFree) diff --git a/vllm/model_executor/models/interfaces_base.py b/vllm/model_executor/models/interfaces_base.py new file mode 100644 index 0000000..8d2d422 --- /dev/null +++ b/vllm/model_executor/models/interfaces_base.py @@ -0,0 +1,191 @@ +from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union, + overload, runtime_checkable) + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from typing_extensions import TypeIs, TypeVar + +from vllm.logger import init_logger +from vllm.utils import supports_kw + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.config import CacheConfig + from vllm.model_executor.layers.pooler import PoolerOutput + from vllm.model_executor.layers.quantization import QuantizationConfig + from vllm.model_executor.layers.sampler import SamplerOutput + from vllm.model_executor.pooling_metadata import PoolingMetadata + from vllm.model_executor.sampling_metadata import SamplingMetadata + +logger = init_logger(__name__) + +# The type of HF config +C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True) + +# The type of hidden states +# Currently, T = torch.Tensor for all models except for Medusa +# which has T = List[torch.Tensor] +T = TypeVar("T", default=torch.Tensor) +T_co = TypeVar("T_co", default=torch.Tensor, covariant=True) + +# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags +# for the base interfaces to avoid breaking OOT registration for existing models +# that don't inherit from the base interface classes + + +@runtime_checkable +class VllmModel(Protocol[C_co, T_co]): + + def __init__( + self, + config: C_co, + *, + cache_config: Optional["CacheConfig"], + quant_config: Optional["QuantizationConfig"], + ) -> None: + ... + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + ) -> T_co: + ... 
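# --- Editor's note (illustrative sketch, not part of this patch) ----------
# The _check_vllm_model_* helpers below rely on vllm.utils.supports_kw to
# ask whether a callable accepts a given keyword argument. A stand-alone
# approximation of that idea, built on inspect.signature, is sketched here;
# the real helper may differ in details (e.g. how it treats **kwargs).
import inspect
from typing import Callable


def _accepts_kw(func: Callable, kw_name: str) -> bool:
    """Return True if `func` can be called with `kw_name` as a keyword."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return False
    if kw_name in params:
        kind = params[kw_name].kind
        return kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                        inspect.Parameter.KEYWORD_ONLY)
    # A **kwargs catch-all also accepts arbitrary keywords.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD
               for p in params.values())


def toy_forward(input_ids, positions, kv_caches, attn_metadata):
    ...


assert _accepts_kw(toy_forward, "attn_metadata")
assert not _accepts_kw(toy_forward, "intermediate_tensors")
# ---------------------------------------------------------------------------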
+ + +def _check_vllm_model_init(model: Union[Type[object], object]) -> bool: + model_init = model.__init__ + vllm_kws = ("cache_config", "quant_config") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_init, kw)) + + if missing_kws and (isinstance(model, type) + and issubclass(model, nn.Module)): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its initializer: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata") + missing_kws = tuple(kw for kw in vllm_kws + if not supports_kw(model_forward, kw)) + + if missing_kws and (isinstance(model, type) + and issubclass(model, nn.Module)): + logger.warning( + "The model (%s) is missing " + "vLLM-specific keywords from its initializer: %s", + model, + missing_kws, + ) + + return len(missing_kws) == 0 + + +@overload +def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]: + ... + + +@overload +def is_vllm_model(model: object) -> TypeIs[VllmModel]: + ... + + +def is_vllm_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]: + return _check_vllm_model_init(model) and _check_vllm_model_forward(model) + + +@runtime_checkable +class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]): + + def compute_logits( + self, + hidden_states: T, + sampling_metadata: "SamplingMetadata", + ) -> Optional[T]: + """Return `None` if TP rank > 0.""" + ... + + def sample( + self, + logits: T, + sampling_metadata: "SamplingMetadata", + ) -> "SamplerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_text_generation_model( + model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]: + ... + + +@overload +def is_text_generation_model( + model: object) -> TypeIs[VllmModelForTextGeneration]: + ... + + +def is_text_generation_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForTextGeneration]], + TypeIs[VllmModelForTextGeneration]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForTextGeneration) + + return isinstance(model, VllmModelForTextGeneration) + + +@runtime_checkable +class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]): + + def pooler( + self, + hidden_states: T, + pooling_metadata: "PoolingMetadata", + ) -> "PoolerOutput": + """Only called on TP rank 0.""" + ... + + +@overload +def is_embedding_model( + model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]: + ... + + +@overload +def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]: + ... 
+ + +def is_embedding_model( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]: + if not is_vllm_model(model): + return False + + if isinstance(model, type): + return isinstance(model, VllmModelForEmbedding) + + return isinstance(model, VllmModelForEmbedding) diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py new file mode 100644 index 0000000..35be1ce --- /dev/null +++ b/vllm/model_executor/models/intern_vit.py @@ -0,0 +1,428 @@ +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from functools import partial +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.distributed import (divide, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.weight_utils import default_weight_loader + +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + +NORM2FN = { + 'rms_norm': RMSNorm, + 'layer_norm': nn.LayerNorm, +} + + +class InternVisionEmbeddings(nn.Module): + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim)) + + self.patch_embedding = nn.Conv2d(in_channels=3, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + 1 + + self.position_embedding = nn.Parameter( + torch.randn(1, self.num_positions, self.embed_dim)) + + def _get_pos_embed(self, pos_embed: torch.Tensor, H: int, W: int): + target_dtype = pos_embed.dtype + pos_embed = pos_embed.float().reshape( + 1, self.image_size // self.patch_size, + self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) + pos_embed = F.interpolate(pos_embed, + size=(H, W), + mode='bicubic', + align_corners=False) + return pos_embed.reshape(1, -1, H * W).permute(0, 2, + 1).to(target_dtype) + + def _get_position_embedding(self, H: int, W: int) -> torch.Tensor: + position_embedding = self.position_embedding + if self.num_patches == H * W: + return position_embedding + + return torch.cat( + [ + position_embedding[:, :1, :], + self._get_pos_embed(position_embedding[:, 1:, :], H, W), + ], + dim=1, + ) + + def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + target_dtype)) # shape = [*, channel, width, height] + batch_size, _, height, width = patch_embeds.shape + patch_embeds = 
patch_embeds.flatten(2).transpose(1, 2) + class_embeds = self.class_embedding.expand(batch_size, 1, + -1).to(target_dtype) + embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + position_embedding = self._get_position_embedding(height, width) + embeddings = embeddings + position_embedding.to(target_dtype) + return embeddings + + +class InternParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_dummy_heads: int = 0, + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + # Additional dummy heads are used to enable TP for common GPU counts. + self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim + self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads, + self.tp_size) + + self.scale = self.head_dim**-0.5 + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + num_dummy_heads + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + self.k_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + + self.proj = RowParallelLinear( + self.dummy_dim, + self.embed_dim, + quant_config=quant_config, + ) + + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, _ = x.shape + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + if self.qk_normalization: + q, k = self._apply_qk_norm(q, k) + + q = q.view(B, N, self.num_heads_per_partition, self.head_dim) + k = k.view(B, N, self.num_heads_per_partition, self.head_dim) + v = v.view(B, N, self.num_heads_per_partition, self.head_dim) + + x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale) + x = x.view(B, N, -1) + + x, _ = self.proj(x) + return x + + +class InternSdpaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + *, + num_dummy_heads: int = 0, + ) -> None: + super().__init__() + + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + # Additional dummy heads are used to enable TP for 
common GPU counts. + self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim + + self.scale = self.head_dim**-0.5 + self.qkv = nn.Linear(self.embed_dim, + 3 * self.dummy_dim, + bias=config.qkv_bias) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + self.k_norm = RMSNorm(self.dummy_dim, + eps=config.layer_norm_eps, + var_hidden_size=self.embed_dim) + + self.proj = nn.Linear(self.dummy_dim, self.embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(B, N, self.num_heads, self.head_dim) + k = k.view(B, N, self.num_heads, self.head_dim) + v = v.view(B, N, self.num_heads, self.head_dim) + + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) + x = x.transpose(1, 2).view(B, N, -1) + + x = self.proj(x) + return x + + +class InternMLP(nn.Module): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.fc2 = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + return hidden_states + + +class InternVisionEncoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_dummy_heads: int = 0, + ) -> None: + super().__init__() + + self.embed_dim = config.hidden_size + self.intermediate_size = config.intermediate_size + self.norm_type = config.norm_type + + self.attn = self._init_attn(config, + quant_config, + num_dummy_heads=num_dummy_heads) + + self.mlp = InternMLP(config, quant_config=quant_config) + self.norm1 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + self.norm2 = NORM2FN[self.norm_type](self.embed_dim, + eps=config.layer_norm_eps) + + self.ls1 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + self.ls2 = nn.Parameter(config.initializer_factor * + torch.ones(self.embed_dim)) + + def _init_attn( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + *, + num_dummy_heads: int, + ): + # fallback to sdpa attention if tp unavailable + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + + if USE_XFORMERS_OPS and (num_heads + num_dummy_heads) % tp_size == 0: + return InternParallelAttention(config, + quant_config=quant_config, + num_dummy_heads=num_dummy_heads) + + return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads) + + def forward( + self, + hidden_states: torch.Tensor, + ): + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states)) * self.ls1 + + hidden_states = hidden_states + self.mlp( + 
self.norm2(hidden_states)) * self.ls2 + + return hidden_states + + +class InternVisionEncoder(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + ): + super().__init__() + + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + InternVisionEncoderLayer(config, + quant_config, + num_dummy_heads=num_dummy_heads) + for _ in range(num_hidden_layers) + ]) + + def forward(self, inputs_embeds: torch.Tensor): + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states = encoder_layer(hidden_states) + + return hidden_states + + +class InternVisionModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + num_hidden_layers_override: Optional[int] = None, + num_dummy_heads: int = 0, + ): + super().__init__() + + self.config = config + + self.embeddings = InternVisionEmbeddings(config) + self.encoder = InternVisionEncoder( + config=config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + num_dummy_heads=num_dummy_heads, + ) + + def get_input_embeddings(self): + return self.embeddings + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + pixel_embeds: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + if pixel_values is None and pixel_embeds is None: + raise ValueError( + 'You have to specify pixel_values or pixel_embeds') + + if pixel_embeds is not None: + hidden_states = pixel_embeds + elif pixel_values is not None: + if pixel_values.ndim == 4: + hidden_states = self.embeddings(pixel_values) + else: + raise ValueError( + f'wrong pixel_values size: {pixel_values.shape}') + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + return encoder_outputs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py new file mode 100644 index 0000000..f6cde44 --- /dev/null +++ b/vllm/model_executor/models/internlm2.py @@ -0,0 +1,383 @@ +# -*- coding: utf-8 -*- +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, 
SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class InternLM2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.w2 = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class InternLM2Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.total_num_heads = num_heads + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= self.tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.key_value_groups = int(self.num_heads / self.num_kv_heads) + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.wqkv = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def split_qkv(self, qkv: torch.Tensor): + seq_len = qkv.shape[0] + if self.tp_size > 1: + qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size + qkv = tensor_model_parallel_all_gather(qkv) + qkv = torch.split(qkv, qkv_map, dim=-1) + qkv = qkv[::3] + qkv[1::3] + qkv[2::3] + qkv = torch.cat(qkv, dim=-1) + + qkv = qkv.view(seq_len, self.total_num_kv_heads, + self.key_value_groups + 2, self.head_dim) + q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) + q = q.reshape(seq_len, self.q_size * self.tp_size) + k = k.reshape(seq_len, self.kv_size * self.tp_size) + v = v.reshape(seq_len, self.kv_size * self.tp_size) + + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] + return q, k, v + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.wqkv(hidden_states) + q, k, v = self.split_qkv(qkv) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.wo(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.attention_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: 
torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class InternLM2Model(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: InternLMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.tok_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2ForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = InternLM2Model(config, cache_config, quant_config) + self.output = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.output.weight = self.model.tok_embeddings.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, 
+ intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.output, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py new file mode 100644 index 0000000..9024831 --- /dev/null +++ b/vllm/model_executor/models/internvl.py @@ -0,0 +1,612 @@ +# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2023 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +import re +from functools import cached_property, partial +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from transformers import PretrainedConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.models.intern_vit import InternVisionModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, + get_clip_num_patches) +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + merge_multimodal_embeddings) + +IMG_START = '' +IMG_END = '' +IMG_CONTEXT = '' + 
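# --- Editor's note (illustrative sketch, not part of this patch) ----------
# The three marker strings above are used when the input pipeline expands
# the text prompt: each image becomes one start marker, `feature_size`
# repetitions of the context token (one per visual embedding produced by
# the vision tower), and one end marker. The marker literals and the tile
# numbers below are made up for the example.
IMG_START_DEMO = "<img>"
IMG_END_DEMO = "</img>"
IMG_CONTEXT_DEMO = "<IMG_CONTEXT>"


def make_image_placeholder(feature_size: int) -> str:
    """Build the placeholder that stands in for one image in the prompt."""
    return IMG_START_DEMO + IMG_CONTEXT_DEMO * feature_size + IMG_END_DEMO


# E.g. with 256 context tokens per tile and 2 tiles, one image contributes
# 512 context tokens between its start/end markers.
placeholder = make_image_placeholder(feature_size=256 * 2)
assert placeholder.count(IMG_CONTEXT_DEMO) == 512
# ---------------------------------------------------------------------------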
+IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +class InternVLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: + `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` + """ + + +class InternVLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +InternVLImageInputs = Union[InternVLImagePixelInputs, + InternVLImageEmbeddingInputs] + + +# copied from https://huggingface.co/OpenGVLab/InternVL2-1B +def build_transform(input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), + interpolation=T.InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + + +# copied from https://huggingface.co/OpenGVLab/InternVL2-1B +def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, + image_size): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + +def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, + max_num: int, image_size: int, + use_thumbnail: bool) -> Tuple[int, int, int]: + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set((i, j) for n in range(min_num, max_num + 1) + for i in range(1, n + 1) for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, + target_ratios, orig_width, + orig_height, image_size) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # add thumbnail image if num_blocks > 1 + if use_thumbnail and blocks > 1: + blocks += 1 + return blocks, target_width, target_height + + +def calculate_num_blocks_wrapper(hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None): + if max_dynamic_patch is None: + max_dynamic_patch = hf_config.max_dynamic_patch + min_num = hf_config.min_dynamic_patch + image_size = hf_config.vision_config.image_size + use_thumbnail = hf_config.use_thumbnail + return partial(calculate_num_blocks, + min_num=min_num, + max_num=max_dynamic_patch, + image_size=image_size, + use_thumbnail=use_thumbnail) + + +# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B +def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, + image_size: int, + use_thumbnail: bool) -> List[Image.Image]: + orig_width, orig_height = image.size + + # calculate the number of blocks without thumbnail + blocks, target_width, target_height = calculate_num_blocks( + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False) + # resize the image + 
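+    # (Illustrative numbers: a 1000x750 image with image_size=448 and
+    # max_num=12 selects the 4x3 grid above, so the image is resized to
+    # 1792x1344 and then cropped into twelve 448x448 tiles below.)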
    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size,
+               (i // (target_width // image_size)) * image_size,
+               ((i % (target_width // image_size)) + 1) * image_size,
+               ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
+
+
+# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
+def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
+                          max_num: int, use_thumbnail: bool) -> torch.Tensor:
+    transform = build_transform(input_size=input_size)
+    images = dynamic_preprocess(image,
+                                min_num=min_num,
+                                max_num=max_num,
+                                image_size=input_size,
+                                use_thumbnail=use_thumbnail)
+    pixel_values = [transform(image) for image in images]
+    pixel_values = torch.stack(pixel_values)
+    return pixel_values
+
+
+def image_to_pixel_values_wrapper(hf_config: PretrainedConfig,
+                                  max_dynamic_patch: Optional[int] = None):
+    image_size = hf_config.vision_config.image_size
+    min_num = hf_config.min_dynamic_patch
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    return partial(image_to_pixel_values,
+                   input_size=image_size,
+                   min_num=min_num,
+                   max_num=max_dynamic_patch,
+                   use_thumbnail=use_thumbnail)
+
+
+def get_internvl_num_patches(hf_config: PretrainedConfig):
+    vision_config = hf_config.vision_config
+    downsample_ratio = hf_config.downsample_ratio
+    image_size = vision_config.image_size
+    patch_size = vision_config.patch_size
+    return int(
+        get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
+        (downsample_ratio**2))
+
+
+def get_max_internvl_image_tokens(ctx: InputContext,
+                                  *,
+                                  max_dynamic_patch: Optional[int] = None):
+    hf_config = ctx.get_hf_config()
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    if use_thumbnail and max_dynamic_patch > 1:
+        max_dynamic_patch += 1
+
+    num_patches = get_internvl_num_patches(hf_config)
+    return num_patches * max_dynamic_patch
+
+
+def get_max_internvl_image_size(ctx: InputContext,
+                                *,
+                                max_dynamic_patch: Optional[int] = None):
+    hf_config = ctx.get_hf_config()
+    image_size = hf_config.vision_config.image_size
+
+    if max_dynamic_patch is None:
+        max_dynamic_patch = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
+    if use_thumbnail and max_dynamic_patch > 1:
+        max_dynamic_patch += 1
+    width = image_size * max_dynamic_patch
+    height = image_size
+    return width, height
+
+
+class InternVLInputPipeline:
+
+    def __init__(
+        self,
+        img_start_token: str,
+        img_end_token: str,
+        img_context_token: str,
+    ) -> None:
+        super().__init__()
+
+        self.img_start_token = img_start_token
+        self.img_end_token = img_end_token
+        self.img_context_token = img_context_token
+
+    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
+        return (self.img_start_token + self.img_context_token * feature_size +
+                self.img_end_token)
+
+    def _expand_image_prompt(
+        self,
+        prompt: str,
+        feature_sizes: List[int],
+        num_patches: int,
+    ) -> str:
+        image_idx = sorted(
+            map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
+
+        new_prompt = prompt
+        for idx, feature_size in enumerate(feature_sizes, start=1):
+            image_prompt = self._create_image_prompt(feature_size, num_patches)
+            if not image_idx:
+                image_prompt = f"Image-{idx}: {image_prompt}"
+
+            new_prompt = new_prompt.replace('<image>', image_prompt, 1)
+
+        return new_prompt
+
+    def input_processor(
+        self,
+        ctx: InputContext,
+        llm_inputs: LLMInputs,
+        *,
+        max_dynamic_patch: Optional[int] = None,
+    ) -> LLMInputs:
+        multi_modal_data = llm_inputs.get("multi_modal_data")
+        if multi_modal_data is None or "image" not in multi_modal_data:
+            return llm_inputs
+
+        model_config = ctx.model_config
+        hf_config = ctx.get_hf_config()
+
+        image_data = multi_modal_data["image"]
+        num_patches = get_internvl_num_patches(hf_config)
+        num_blocks_calculator = calculate_num_blocks_wrapper(
+            hf_config, max_dynamic_patch)
+        if isinstance(image_data, Image.Image):
+            width, height = image_data.size
+            num_blocks, _, _ = num_blocks_calculator(width, height)
+            image_feature_sizes = [num_blocks * num_patches]
+        elif is_list_of(image_data, Image.Image):
+            image_feature_sizes = []
+            for image in image_data:
+                width, height = image.size
+                num_blocks, _, _ = num_blocks_calculator(width, height)
+                image_feature_sizes.append(num_blocks * num_patches)
+        elif isinstance(image_data, torch.Tensor):
+            num_images, image_feature_size, hidden_size = image_data.shape
+            image_feature_sizes = [image_feature_size]
+        else:
+            raise TypeError(f"Invalid image type: {type(image_data)}")
+
+        tokenizer = cached_get_tokenizer(
+            model_config.tokenizer,
+            trust_remote_code=model_config.trust_remote_code)
+
+        prompt = llm_inputs.get("prompt")
+        prompt_token_ids = llm_inputs["prompt_token_ids"]
+        if prompt is None:
+            prompt = tokenizer.decode(prompt_token_ids)
+
+        new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
+                                               num_patches)
+        new_prompt_token_ids = tokenizer.encode(new_prompt)
+
+        return LLMInputs(prompt=prompt,
+                         prompt_token_ids=new_prompt_token_ids,
+                         multi_modal_data=multi_modal_data)
+
+    def input_mapper(
+        self,
+        ctx: InputContext,
+        data: object,
+        *,
+        max_dynamic_patch: Optional[int] = None,
+    ):
+        hf_config = ctx.get_hf_config()
+
+        image_pixel_values_mapper = image_to_pixel_values_wrapper(
+            hf_config, max_dynamic_patch)
+        if isinstance(data, Image.Image):
+            data = image_pixel_values_mapper(data)
+            # Add an N dimension for number of images per prompt (currently 1).
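+            # (Illustrative shape: an image preprocessed into 13 patches of
+            # 448x448 arrives here as (13, 3, 448, 448) and becomes
+            # (1, 13, 3, 448, 448) after the unsqueeze.)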
+ data = data.unsqueeze(0) + elif is_list_of(data, Image.Image): + # we can't stack here because images may have different num_patches + data = [image_pixel_values_mapper(img) for img in data] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + image_token_id = tokenizer.encode(self.img_context_token, + add_special_tokens=False, + return_tensors="pt")[0] + + return MultiModalInputs({ + "pixel_values": data, + "image_token_id": image_token_id + }) + + def dummy_data( + self, + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + max_dynamic_patch: Optional[int] = None, + ): + num_images = mm_counts["image"] + + hf_config = ctx.get_hf_config() + + image_feature_size = get_max_internvl_image_tokens( + ctx, max_dynamic_patch=max_dynamic_patch) + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + + seq_data = dummy_seq_data_for_clip( + hf_config.vision_config, + seq_len, + num_images, + image_token_id=tokenizer.encode(self.img_context_token, + add_special_tokens=False)[0], + image_feature_size_override=image_feature_size, + ) + + max_image_width, max_image_height = get_max_internvl_image_size( + ctx, max_dynamic_patch=max_dynamic_patch) + + mm_data = dummy_image_for_clip( + hf_config.vision_config, + num_images, + image_width_override=max_image_width, + image_height_override=max_image_height, + ) + + return seq_data, mm_data + + +input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) +@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data) +@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + image_size = config.force_image_size or config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.patch_size = patch_size + self.select_layer = config.select_layer + self.num_image_token = int( + (image_size // patch_size)**2 * (config.downsample_ratio**2)) + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + + vision_feature_layer = self.select_layer + if vision_feature_layer < 0: + num_hidden_layers = config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + self.vision_model = self._init_vision_model(config, num_hidden_layers) + + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + self.mlp1 = self._init_mlp1(config) + + self.img_context_token_id = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _init_vision_model(self, config: PretrainedConfig, + num_hidden_layers: int): + return InternVisionModel(config.vision_config, + 
num_hidden_layers_override=num_hidden_layers) + + def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + vit_hidden_size = config.vision_config.hidden_size + llm_hidden_size = config.text_config.hidden_size + + return nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_hidden_size), + nn.GELU(), + nn.Linear(llm_hidden_size, llm_hidden_size), + ) + + def pixel_shuffle(self, x, scale_factor=0.5): + n, w, h, c = x.size() + # N, W, H, C --> N, W, H * scale, C // scale + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + # N, W, H * scale, C // scale --> N, H * scale, W, C // scale + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version == 'v1': + pass + else: + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: + vit_embeds = self.vision_model(pixel_values=pixel_values) + vit_embeds = vit_embeds[:, 1:, :] + + h = w = int(vit_embeds.shape[1]**0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, + scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, + vit_embeds.shape[-1]) + vit_embeds = self.mlp1(vit_embeds) + return vit_embeds + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[InternVLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_token_id = kwargs.pop("image_token_id", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return InternVLImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + self.img_context_token_id = image_token_id[0] + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + # We need to flatten (B, N, P) to (B*N*P), + # so we call flatten_bn twice. 
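+            # (Illustrative shapes: two prompts, each with one image split
+            # into 7 patches of 3x448x448, i.e. (2, 1, 7, 3, 448, 448),
+            # come out of the double flatten_bn as (14, 3, 448, 448).)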
+ return InternVLImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(flatten_bn(pixel_values), concat=True)), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: InternVLImageInputs, + ) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_model is not None + image_embeds = self.extract_feature(image_input["data"]) + + return image_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is not None: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.img_context_token_id) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py new file mode 100644 index 0000000..c5e5393 --- /dev/null +++ b/vllm/model_executor/models/jais.py @@ -0,0 +1,374 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights +# reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Jais model compatible with HuggingFace weights.""" + +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import JAISConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class SwiGLUActivation(nn.Module): + + def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + return x1 * nn.functional.silu(x2) + + +def _get_alibi_slopes(n): + + def get_slopes_power_of_2(n): + start = 2**(-(2**-(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2**math.floor(math.log2(n)) + return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes( + 2 * closest_power_of_2)[0::2][:n - closest_power_of_2]) + + +class JAISAttention(nn.Module): + + def __init__( + self, + config: JAISConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + if hasattr(config, "scale_qk_dot_by_d"): + config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d + self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5 + self.scale = self.head_dim**-self.attn_scale_power + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end] + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scale, + alibi_slopes=alibi_slopes, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output, _ = 
self.c_proj(attn_output) + return attn_output + + +class JAISMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: JAISConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.swiglu = config.activation_function == "swiglu" + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) + self.c_fc2 = (ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config, + ) if self.swiglu else None) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + self.act = SwiGLUActivation() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.swiglu: + hidden_states2, _ = self.c_fc2(hidden_states) + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = (self.act(hidden_states, hidden_states2) + if self.swiglu else self.act(hidden_states)) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class JAISBlock(nn.Module): + + def __init__( + self, + config: JAISConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = JAISAttention(config, cache_config, quant_config) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = JAISMLP(inner_dim, config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class JAISModel(nn.Module): + + def __init__( + self, + config: JAISConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = (nn.Embedding(config.max_position_embeddings, + self.embed_dim) + if config.position_embedding_type != "alibi" else None) + if hasattr(config, "embeddings_scale"): + self.embeddings_scale = config.embeddings_scale + else: + self.embeddings_scale = config.mup_embeddings_scale + + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: JAISBlock(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.h", + ) + + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) + + def forward( + self, + input_ids: torch.Tensor, + 
position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[IntermediateTensors, torch.Tensor]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + if self.wpe is not None: + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = inputs_embeds + hidden_states *= torch.tensor(float(self.embeddings_scale), + dtype=hidden_states.dtype) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class JAISLMHeadModel(nn.Module, SupportsPP): + + def __init__( + self, + config: JAISConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.transformer = JAISModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.lm_head = self.transformer.wte + else: + self.lm_head = ParallelLMHead(self.config.vocab_size, + self.config.hidden_size) + if hasattr(config, "width_scale"): + self.output_logits_scale = config.width_scale + else: + self.output_logits_scale = (config.mup_output_alpha * + config.mup_width_scale) + self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, + scale=self.output_logits_scale) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[IntermediateTensors, torch.Tensor]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if "relative_pe" in name: + continue + if not name.startswith("transformer."): + name = "transformer." + name + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. 
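+            # (For clarity: HF Conv1D stores weights as (in_features,
+            # out_features), so e.g. a c_fc weight arrives as
+            # (hidden_size, intermediate_size) and is transposed below into
+            # the (out, in) layout that vLLM's Linear weight loaders expect.)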
+ for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py new file mode 100644 index 0000000..ac251b8 --- /dev/null +++ b/vllm/model_executor/models/jamba.py @@ -0,0 +1,717 @@ +# coding=utf-8 +"""Inference-only Jamba model.""" +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import JambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.attention.layer import Attention +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +from .interfaces import HasInnerState, SupportsLoRA + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class JambaMambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. 
A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: JambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.mamba_d_state + self.conv_kernel_size = config.mamba_d_conv + self.intermediate_size = config.mamba_expand * config.hidden_size + self.time_step_rank = config.mamba_dt_rank + self.use_conv_bias = config.mamba_conv_bias + self.use_bias = config.mamba_proj_bias + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=self.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=self.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. + self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=self.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + self.dt_layernorm = RMSNorm(self.time_step_rank, + eps=config.rms_norm_eps) + self.b_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + self.c_layernorm = RMSNorm(self.ssm_state_size, + eps=config.rms_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, conv_state: torch.Tensor, + ssm_state: torch.Tensor): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + time_step = self.dt_layernorm(time_step.contiguous()) + B = self.b_layernorm(B.contiguous()) + C = self.c_layernorm(C.contiguous()) + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + ) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class JambaMoE(nn.Module): + + def __init__(self, + config: JambaConfig, + num_experts: Optional[int] = None, + top_k: Optional[int] = None, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.num_total_experts = num_experts or config.num_experts + self.top_k = top_k or config.num_experts_per_tok + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + if self.num_total_experts > 1: + self.router = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None, + params_dtype=params_dtype) + + self.experts = FusedMoE(self.num_total_experts, + self.top_k, + self.hidden_size, + self.intermediate_size, + tp_size=tp_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + use_grouped_topk=False, + quant_config=quant_config) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + if self.num_total_experts > 1: + router_logits, _ = self.router(hidden_states) + else: + router_logits = torch.ones((hidden_states.shape[0], 1), + device=hidden_states.device, + dtype=hidden_states.dtype) + hidden_states = self.experts(hidden_states, router_logits) + return hidden_states.view(orig_shape) + + +class JambaMLP(JambaMoE): + + def __init__(self, + config: JambaConfig, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__(config, + num_experts=1, + top_k=1, + params_dtype=params_dtype, + tp_size=tp_size, + quant_config=quant_config) + + +class JambaMambaDecoderLayer(nn.Module): + + def __init__(self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mamba = JambaMambaMixer(config, layer_idx) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.mamba(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class JambaAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: JambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + 
self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + ) + + num_experts = config.layers_num_experts[layer_idx] + ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP + self.feed_forward = ffn_layer_class(config, quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + **kwargs, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": JambaAttentionDecoderLayer, + "mamba": JambaMambaDecoderLayer +} + + +class JambaModel(nn.Module): + + def __init__( + self, + config: JambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = 
VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[i]] + decoder_layers.append( + layer_class(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.final_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + kv_cache = None + current_ssm_state = None + current_conv_state = None + if isinstance(layer, JambaAttentionDecoderLayer): + kv_cache = kv_caches[(i - self.config.attn_layer_offset) // + self.config.attn_layer_period] + if isinstance(layer, JambaMambaDecoderLayer): + current_state_layer = i - (1 + + (i - self.config.attn_layer_offset) + // self.config.attn_layer_period) + current_ssm_state = ssm_state[current_state_layer] + current_conv_state = conv_state[current_state_layer] + + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.final_layernorm(hidden_states, residual) + return hidden_states + + +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: JambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + assert not cache_config.enable_prefix_caching, \ + "Jamba currently does not support prefix caching" + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.model = JambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + # Used to track and store by the Mamba cache between steps. 
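+        # (Allocated lazily on the first forward pass; each Mamba layer gets
+        # a conv state and an SSM state sized by _get_mamba_cache_shape()
+        # below, roughly (expand * hidden_size / tp, d_conv - 1) and
+        # (expand * hidden_size / tp, d_state) per sequence slot.)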
+ self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + + layers_type = self.config.layers_block_type + num_mamba_layers = sum( + [layer_type == "mamba" for layer_type in layers_type]) + + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, num_mamba_layers, max_batch_size, + *self._get_mamba_cache_shape()) + + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) + + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + hidden_size = self.config.hidden_size + conv_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_conv - 1, + ) + temporal_state_shape = ( + self.config.mamba_expand * hidden_size // world_size, + self.config.mamba_d_state, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + if "feed_forward" in name and not _is_moe_layer(name): + ## map MLP layers to expert with ID=0 + name = name.replace("feed_forward", "feed_forward.experts.0") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if 'experts' in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for ( + param_name, + weight_name, + expert_id, + shard_id, + ) in expert_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +def _is_moe_layer(name: str): + return any( + [experts_name in name for experts_name in [ + "experts", + "router", + ]]) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py new file mode 100644 index 0000000..a6464fd --- /dev/null +++ b/vllm/model_executor/models/llama.py @@ -0,0 +1,637 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, Type + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = LlamaAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + 
prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + "positions": 0, + "inputs_embeds": 0, + "intermediate_tensors": 0, + }) +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + layer_type: Type[LlamaDecoderLayer] = LlamaDecoderLayer + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + # lambda prefix: LlamaDecoderLayer(config=config, + # cache_config=cache_config, + # quant_config=quant_config, + # prefix=prefix), + lambda prefix: layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + + # from ixformer.inference.overlap import create_vllm_llama_decoder_layer + from vllm.distributed import get_tp_group + + # for layer_idx in range(self.start_layer, self.end_layer): + # self.layers[layer_idx] = create_vllm_llama_decoder_layer( + # self.layers[layer_idx], + # model_id = id(self), + # layer_idx = layer_idx, + # enable_overlap = True, + # group = get_tp_group(), + # quant_config=quant_config, + # activation_dtype=torch.get_default_dtype() + # ) + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if 
get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") + + +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings" + } + embedding_padding_modules = ["lm_head"] + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm" + } + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = LlamaModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights( + 
self.model.embed_tokens) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + loader.load_weights( + self.maybe_remap_mistral(name, loaded_weight) + for name, loaded_weight in weights) + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + + # This function is used to remap the mistral format as + # used by Mistral and Llama <=2 + def maybe_remap_mistral( + self, + name: str, + loaded_weight: torch.Tensor, + ) -> Tuple[str, torch.Tensor]: + + def permute(w: torch.Tensor, n_heads: int): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + mapping = self.mistral_mapping + modules = name.split(".") + + # rotary embeds should be sliced + if "wk" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif "wq" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + for item in modules: + if item in mapping and mapping[item] not in name: + name = name.replace(item, mapping[item]) + + return name, loaded_weight \ No newline at end of file diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py new file mode 100644 index 0000000..13574e8 --- /dev/null +++ b/vllm/model_executor/models/llama_embedding.py @@ -0,0 +1,59 @@ +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsPP +from .llama import LlamaModel + + +class LlamaEmbeddingModel(nn.Module, SupportsPP): + """A model that uses Llama with additional embedding functionalities. + + This class encapsulates the LlamaModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. 
+ """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.model(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + self.model.load_weights(weights) + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py new file mode 100644 index 0000000..864b9ff --- /dev/null +++ b/vllm/model_executor/models/llava.py @@ -0,0 +1,411 @@ +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +from PIL import Image +from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import (CLIPVisionModel, dummy_image_for_clip, + dummy_seq_data_for_clip, get_max_clip_image_tokens, + input_processor_for_clip) +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens, + input_processor_for_siglip) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + merge_multimodal_embeddings) + + +class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class LlavaImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs] + + +# TODO(xwjiang): Run benchmark and decide if TP. 
+class LlavaMultiModalProjector(nn.Module): + + def __init__(self, vision_hidden_size: int, text_hidden_size: int, + projector_hidden_act: str): + super().__init__() + + self.linear_1 = nn.Linear(vision_hidden_size, + text_hidden_size, + bias=True) + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = nn.Linear(text_hidden_size, + text_hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +def get_max_llava_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + num_image_tokens = get_max_clip_image_tokens(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_image_tokens = get_max_siglip_image_tokens(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + return num_image_tokens - 1 + elif strategy == "full": + return num_image_tokens + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + +def dummy_data_for_llava(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + image_feature_size = get_max_llava_image_tokens(ctx) + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_clip(vision_config, num_images) + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip(vision_config, num_images) + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(LlavaConfig) + vision_config = hf_config.vision_config + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_max_llava_image_tokens(ctx) + elif is_list_of(image_data, Image.Image): + image_feature_size = [get_max_llava_image_tokens(ctx) + ] * len(image_data) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + if isinstance(vision_config, CLIPVisionConfig): + return input_processor_for_clip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return 
input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava) +class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, + config: LlavaConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. + self.vision_tower = _init_vision_tower(config) + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act) + + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return LlavaImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of image embeddings. 
" + f"Got type: {type(image_embeds)}") + + return LlavaImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds, concat=True), + ) + + raise AssertionError("This line should be unreachable.") + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + def _process_image_pixels(self, + inputs: LlavaImagePixelInputs) -> torch.Tensor: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + return self._image_pixels_to_features(self.vision_tower, pixel_values) + + def _process_image_input(self, + image_input: LlavaImageInputs) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + image_features = self._process_image_pixels(image_input) + return self.multi_modal_projector(image_features) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for LLaVA-1.5. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + + Concretely, consider a text prompt: + `"USER: \\nWhat's the content of the image?\\nASSISTANT:"`. + + Tokenizer outputs: + `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879, + 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`. + + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + additional image tokens (denoted as `32000`), resulting in: + `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618, + 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, + 29901]`. + + We insert 575 tokens so that including the original image token in the + input, there are a total of 576 (24 * 24) image tokens, which + corresponds to the number of image tokens inputted to the language + model, i.e. the number of image tokens outputted by the visual encoder. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each input image. 
+ + See also: + :class:`LlavaImageInputs` + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + else: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + input_ids = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py new file mode 100644 index 0000000..766f6a4 --- /dev/null +++ b/vllm/model_executor/models/llava_next.py @@ -0,0 +1,645 @@ +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn as nn +from PIL import Image +from transformers import CLIPVisionConfig, LlavaNextConfig, SiglipVisionConfig +from transformers.models.llava_next.modeling_llava_next import ( + get_anyres_image_grid_shape, unpad_image) +from typing_extensions import NotRequired + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import (CLIPVisionModel, dummy_image_for_clip, + dummy_seq_data_for_clip, get_clip_image_feature_size, + get_clip_patch_grid_length, input_processor_for_clip) +from .interfaces import SupportsMultiModal, SupportsPP +from .llava import LlavaMultiModalProjector +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_siglip_image_feature_size, + get_siglip_patch_grid_length, input_processor_for_siglip) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + merge_multimodal_embeddings) + +# Result in the max possible feature size (2x2 grid of 336x336px tiles) +MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 + + +class LlavaNextImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different 
per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: NotRequired[torch.Tensor] + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class LlavaNextImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +LlavaNextImageInputs = Union[LlavaNextImagePixelInputs, + LlavaNextImageEmbeddingInputs] + + +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L79 +def _get_llava_next_num_unpadded_features( + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, +) -> Tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + current_height -= 2 * padding + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + current_width -= 2 * padding + + unpadded_features = current_height * current_width + newline_features = current_height + return (unpadded_features, newline_features) + + +# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.2.0/server/text_generation_server/models/vlm_causal_lm.py#L106 +def get_llava_next_image_feature_size( + hf_config: LlavaNextConfig, + *, + input_height: int, + input_width: int, +) -> int: + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + num_patches = get_clip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, + ) + base_feature_size = get_clip_image_feature_size(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_patches = get_siglip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, + ) + base_feature_size = get_siglip_image_feature_size(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + base_feature_size -= 1 + elif strategy == "full": + pass + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(input_height, input_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + patch_size=vision_config.image_size, + ) + + ( + unpadded_feature_size, + newline_feature_size, + ) = _get_llava_next_num_unpadded_features(input_height, input_width, + num_patches, num_patch_height, + num_patch_width) + + return unpadded_feature_size + newline_feature_size + base_feature_size + + +def get_max_llava_next_image_tokens(ctx: InputContext): + return get_llava_next_image_feature_size( + ctx.get_hf_config(LlavaNextConfig), + input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + ) + + +def dummy_data_for_llava_next(ctx: 
InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(LlavaNextConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + image_feature_size = get_max_llava_next_image_tokens(ctx) + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_clip( + vision_config, + num_images, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + mm_data = dummy_image_for_siglip( + vision_config, + num_images, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(LlavaNextConfig) + vision_config = hf_config.vision_config + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + width, height = image_data.size + + image_feature_size = get_llava_next_image_feature_size( + hf_config, + input_height=height, + input_width=width, + ) + elif is_list_of(image_data, Image.Image): + image_feature_size = [ + get_llava_next_image_feature_size(hf_config, + input_height=img.height, + input_width=img.width) + for img in image_data + ] + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + return input_processor_for_clip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaNextConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + 
vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, + config: LlavaNextConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + # TODO: Optionally initializes this for supporting embeddings. + self.vision_tower = _init_vision_tower(config) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + self.multi_modal_projector = LlavaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act) + + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaNextImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. 
" + f"Got type: {type(image_sizes)}") + + return LlavaNextImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeds. " + f"Got type: {type(image_embeds)}") + + return LlavaNextImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, + patch_embeddings: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # Move to CPU to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + + # image_aspect_ratio == "anyres" + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ + .view(num_patch_height, num_patch_width, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + (orig_height, orig_width)) + other_patch_embeds = torch.cat(( + other_patch_embeds, + self.image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + 
.to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, + inputs: LlavaNextImagePixelInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + if isinstance(pixel_values, torch.Tensor): + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + stacked_patch_embeddings = self.multi_modal_projector( + stacked_image_features) + + return stacked_patch_embeddings.view( + b, num_patches, *stacked_patch_embeddings.shape[1:]) + + num_patches_per_batch = [v.shape[0] for v in pixel_values] + stacked_pixel_values = torch.cat(pixel_values) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + + return [ + self.multi_modal_projector(image_features) for image_features in + torch.split(stacked_image_features, num_patches_per_batch) + ] + + def _process_image_input( + self, + image_input: LlavaNextImageInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if image_input["type"] == "image_embeds": + return [image_input["data"]] + + patch_embeddings = self._process_image_pixels(image_input) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = len(image_input["data"]) + vision_config = self.config.vision_config + default_height = default_width = vision_config.image_size + image_sizes = torch.as_tensor([[default_height, default_width] + for _ in range(batch_size)]) + + return [ + self._merge_image_patch_embeddings(image_sizes[i], + patch_features_batch, + strategy="spatial_unpad") + for i, patch_features_batch in enumerate(patch_embeddings) + ] + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for LlaVA-NeXT. + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted image embeddings. + + Concretely, consider a text prompt: + `"A chat between a curious human and an artificial intelligence + assistant. The assistant gives helpful, detailed, and polite answers to + the human's questions. + USER: \\nWhat is shown in this image? ASSISTANT:"`. + + Tokenizer outputs: + `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255, + 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, + 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901, + 29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799, + 9047, 13566, 29901]`. + + To reserve space in KV cache, we have to insert placeholder tokens + before they are inputted to the model, so the input processor prepends + additional image tokens (denoted as `32000`), resulting in: + `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255, + 29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568, + 6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901, + 29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, + 319, 1799, 9047, 13566, 29901]`. 
+ + Unlike in LLaVA-1.5, the number of image tokens inputted to the language + model depends on the original size of the input image. Including the + original image token in the input, the required number of image tokens + is given by :func:`get_llava_next_image_feature_size`. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values: The pixels in each grid patch for each input image. + image_sizes: The original `(height, width)` for each input image. + + See also: + :class:`LlavaNextImageInputs` + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py new file mode 100644 index 0000000..e10c1f9 --- /dev/null +++ b/vllm/model_executor/models/llava_next_video.py @@ -0,0 +1,465 @@ +import math +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import numpy as np +import torch +import torch.nn as nn +from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, + SiglipVisionConfig) + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.models.clip import CLIPVisionModel +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip) +from .utils import (AutoWeightsLoader, init_vllm_registered_model, + merge_multimodal_embeddings) + +# For profile run +_MAX_FRAMES_PER_VIDEO = 32 +_MAX_NUM_VIDEOS = 1 + + +class LlavaNextVideoPixelInputs(TypedDict): + 
type: Literal["pixel_values_videos"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: `(batch_size, num_frames, num_channels, height, width)` + + Note that `num_frames` may be different for each batch, in which case + the data is passed as a list instead of a batched tensor. + + Note that it only supports one video input for one batch. + """ + + +def get_llava_next_video_frame_feature_size( + hf_config: LlavaNextVideoConfig) -> int: + # Support both CLIPVisionConfig and SiglipVisionConfig + image_size = hf_config.vision_config.image_size + patch_size = hf_config.vision_config.patch_size + spatial_pool_stride = hf_config.spatial_pool_stride + + return int((image_size / patch_size / spatial_pool_stride)**2) + + +def _get_max_llm_tokens(ctx: InputContext) -> int: + """ + Calculated from the maximum video frames under the context length + constraints of the language model. + """ + hf_text_config = ctx.model_config.hf_text_config + model_config = ctx.model_config + max_tokens = model_config.max_model_len + rope_scaling = model_config.rope_scaling + + if rope_scaling: + rope_scaling_factor = hf_text_config.rope_scaling["factor"] + else: + rope_scaling_factor = 1 + + max_tokens *= rope_scaling_factor + + return max_tokens + + +def get_max_llava_next_video_tokens(ctx: InputContext) -> int: + # Currently set to 32 frames + # TODO: max_tokens = _get_max_llm_tokens(ctx) + hf_config = ctx.get_hf_config(LlavaNextVideoConfig) + tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config) + return _MAX_FRAMES_PER_VIDEO * tokens_per_frame + + +def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(LlavaNextVideoConfig) + vision_config = hf_config.vision_config + + # TODO: support multiple videos + num_videos = mm_counts["video"] + if num_videos != _MAX_NUM_VIDEOS: + raise NotImplementedError( + f"Only {_MAX_NUM_VIDEOS} videos are supported") + + # TODO: support configuring the number of frames + frames_per_video = _MAX_FRAMES_PER_VIDEO + # num_images = num_videos * frames_per_video + + # fills the sequence with as longer video data as possible + tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config) + video_feature_size = frames_per_video * tokens_per_frame + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + num_videos, + image_token_id=hf_config.video_token_index, + image_feature_size_override=video_feature_size, + ) + + pil_frame = dummy_image_for_clip(vision_config, num_images=1) + np_frame = np.array(pil_frame["image"]) + mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) + mm_data = {"video": mm_data_per_video} + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + num_videos, + image_token_id=hf_config.video_token_index, + image_feature_size_override=video_feature_size, + ) + + pil_frame = dummy_image_for_siglip(vision_config, num_images=1) + np_frame = np.array(pil_frame["image"]) + mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) + mm_data = {"video": mm_data_per_video} + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_for_llava_next_video(ctx: InputContext, + llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "video" not in 
multi_modal_data: + return llm_inputs + video_data = multi_modal_data["video"] + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(LlavaNextVideoConfig) + vision_config = hf_config.vision_config + + if isinstance(video_data, np.ndarray): + # Supports both CLIP and Siglip + num_frames = video_data.shape[0] + frame_feature_size = \ + get_llava_next_video_frame_feature_size(hf_config) + video_feature_size = num_frames * frame_feature_size + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=hf_config.video_token_index, + repeat_count=video_feature_size, + ) + + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + elif is_list_of(video_data, np.ndarray): + raise NotImplementedError( + "Processing multiple videos is not supported") + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaNextVideoConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +# adopted from transformers modeling_llava_next_video.py +class LlavaNextVideoPooler(nn.Module): + + def __init__(self, config): + super().__init__() + + mode = config.spatial_pool_mode + stride = config.spatial_pool_stride + image_size = config.vision_config.image_size + patch_size = config.vision_config.patch_size + self.image_size = image_size // patch_size**2 + + if mode == "average": + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride) + elif mode == "max": + self.pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + else: + # TODO: Support Conv2d pooling layer, need to load weights + raise ValueError( + f"Unknown pooling mode: {mode}. 
Expected [`average`, `max`]") + + def forward(self, image_features): + ori_width = int( + math.sqrt(image_features.shape[1] * self.image_size // + self.image_size)) + ori_height = int(ori_width * self.image_size // self.image_size) + + batch_size, _, dim = image_features.shape + image_features_spatial = image_features \ + .view(batch_size, ori_height, ori_height, dim) \ + .permute(0, 3, 1, 2) + image_features_spatial = self.pool(image_features_spatial) + + return image_features_spatial.flatten(2).transpose(1, 2).contiguous() + + +class LlavaNextMultiModalProjector(nn.Module): + + def __init__(self, vision_hidden_size: int, text_hidden_size: int, + projector_hidden_act: str): + super().__init__() + + self.linear_1 = nn.Linear(vision_hidden_size, + text_hidden_size, + bias=True) + self.act = get_act_fn(projector_hidden_act) + self.linear_2 = nn.Linear(text_hidden_size, + text_hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_input_mapper("video") +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "video", get_max_llava_next_video_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video) +class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, + config: LlavaNextVideoConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + # Initialize the vision tower only up to the required feature layer + self.vision_tower = _init_vision_tower(config) + self.vision_resampler = LlavaNextVideoPooler(config) + self.multi_modal_projector = LlavaNextMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + text_hidden_size=config.text_config.hidden_size, + projector_hidden_act=config.projector_hidden_act) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _validate_video_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[2:]) + + if actual_dims != expected_dims: + expected_expr = ("num_frames", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values in each video frame " + f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[LlavaNextVideoPixelInputs]: + """ + A legal video input should have the following dimensions: + { + "pixel_values_videos" : + List[b, Tensor(nb_frames, nb_channels, height, width)] + } + """ + pixel_values = kwargs.pop("pixel_values_videos", None) + + if pixel_values is None: + return None + + if not (is_list_of(pixel_values, + (torch.Tensor)) # different shape videos + or isinstance(pixel_values, + torch.Tensor)): # same shape videos + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + return LlavaNextVideoPixelInputs( + type="pixel_values_videos", + data=pixel_values, + ) + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _video_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + image_features = self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + image_features = self.vision_resampler(image_features) + image_features = self.multi_modal_projector(image_features) + return image_features + + def _process_video_pixels(self, inputs: LlavaNextVideoPixelInputs): + assert self.vision_tower is not None + + video_pixels = inputs["data"] + + if isinstance(video_pixels, torch.Tensor): + # TODO: support multiple videos per input + b, num_videos, num_frames, c, h, w = video_pixels.shape + assert (num_videos == 1) + stacked_pixels = video_pixels.view(b * num_videos * num_frames, c, + h, w) + stacked_embeddings = self._video_pixels_to_features( + self.vision_tower, stacked_pixels) + return stacked_embeddings.view(b, num_frames, + *stacked_embeddings.shape[1:]) + + elif is_list_of(video_pixels, torch.Tensor): + frames_per_videos = [v.shape[0] for v in video_pixels] + stacked_pixels = torch.cat(video_pixels, dim=0) + stacked_embeddings = self._video_pixels_to_features( + self.vision_tower, stacked_pixels) + return torch.split(stacked_embeddings, frames_per_videos, dim=0) + + else: + raise ValueError( + f"Unsupported type of video input {type(video_pixels)}") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for LlaVA-NeXT-Video. + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values_videos: Pixels in each frames for each input videos. 
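+                Expected shape: `(batch_size, num_frames, num_channels,
+                height, width)`, or a list of per-video tensors when
+                `num_frames` differs; only one video per prompt is
+                currently supported.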
+ """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + video_input = self._parse_and_validate_video_input(**kwargs) + if video_input is not None: + video_embeddings = self._process_video_pixels(video_input) + inputs_embeds = self.language_model \ + .model.get_input_embeddings(input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, video_embeddings, + self.config.video_token_index) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + # This model doesn't support images for now + ignore_unexpected_prefixes=["image_newline"], + ) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py new file mode 100644 index 0000000..46e97e7 --- /dev/null +++ b/vllm/model_executor/models/llava_onevision.py @@ -0,0 +1,875 @@ +import math +from functools import cached_property +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from transformers import (CLIPVisionConfig, LlavaOnevisionConfig, + SiglipVisionConfig) +from transformers.models.llava_onevision.modeling_llava_onevision import ( + get_anyres_image_grid_shape, unpad_image) +from typing_extensions import NotRequired + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, + dummy_video_for_clip, get_clip_image_feature_size, + get_clip_patch_grid_length, input_processor_for_clip) +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, + dummy_video_for_siglip, get_siglip_image_feature_size, + get_siglip_patch_grid_length, input_processor_for_siglip) +from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + +# Result in the max possible feature size (2x2 grid of 336x336px tiles) +MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 + +# For profile run +_MAX_FRAMES_PER_VIDEO = 16 +_MAX_NUM_VIDEOS = 1 + + +class 
LlavaOnevisionVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: `(batch_size, num_frames, num_channels, height, width)` + + Note that `num_frames` may be different for each batch, in which case + the data is passed as a list instead of a batched tensor. + + Note that it only supports one video input for one batch. + """ + + +class LlavaOnevisionImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: NotRequired[torch.Tensor] + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class LlavaOnevisionImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +LlavaOnevisionImageInputs = Union[LlavaOnevisionImagePixelInputs, + LlavaOnevisionImageEmbeddingInputs] + +LlavaOnevisionMultiInputs = Union[LlavaOnevisionImageInputs, + LlavaOnevisionVideoPixelInputs] + + +def _get_llava_onevision_image_unppaded_feature_size(height, width, patches, + scale_height, + scale_width): + current_height = patches * scale_height + current_width = patches * scale_width + + original_aspect_ratio = width / height + current_aspect_ratio = current_width / current_height + if original_aspect_ratio > current_aspect_ratio: + new_height = int(height * (current_width / width)) + padding = (current_height - new_height) // 2 + current_height -= padding * 2 + else: + new_width = int(width * (current_height / height)) + padding = (current_width - new_width) // 2 + current_width -= padding * 2 + + unpadded_features = current_height * current_width + newline_features = current_height + + ratio = math.sqrt(current_height * current_width / (9 * patches**2)) + if ratio > 1.1: + unpadded_features = int(current_height // ratio) * int( + current_width // ratio) + newline_features = int(current_height // ratio) + + return (unpadded_features, newline_features) + + +def get_llava_onevision_image_feature_size( + hf_config: LlavaOnevisionConfig, + *, + input_height: int, + input_width: int, +) -> int: + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + num_patches = get_clip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, + ) + base_feature_size = get_clip_image_feature_size(vision_config) + elif isinstance(vision_config, SiglipVisionConfig): + num_patches = get_siglip_patch_grid_length( + image_size=vision_config.image_size, + patch_size=vision_config.patch_size, + ) + base_feature_size = get_siglip_image_feature_size(vision_config) + else: + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + strategy = hf_config.vision_feature_select_strategy + if strategy == "default": + base_feature_size -= 1 + elif strategy == "full": + pass + else: + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_size=(input_height, input_width), + grid_pinpoints=hf_config.image_grid_pinpoints, + 
patch_size=vision_config.image_size, + ) + + ( + unpadded_feature_size, + newline_feature_size, + ) = _get_llava_onevision_image_unppaded_feature_size( + input_height, input_width, num_patches, num_patch_height, + num_patch_width) + + return unpadded_feature_size + newline_feature_size + base_feature_size + + +def get_max_llava_onevision_image_tokens(ctx: InputContext): + return get_llava_onevision_image_feature_size( + ctx.get_hf_config(LlavaOnevisionConfig), + input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + ) + + +def get_llava_onevision_video_frame_feature_size( + hf_config: LlavaOnevisionConfig) -> int: + # Support both CLIPVisionConfig and SiglipVisionConfig + image_size = hf_config.vision_config.image_size + patch_size = hf_config.vision_config.patch_size + spatial_pool_stride = hf_config.spatial_pool_stride if hasattr( + hf_config, "spatial_pool_stride") else 2 + + height = width = image_size // patch_size + return math.ceil(height / spatial_pool_stride) * math.ceil( + width / spatial_pool_stride) + + +def get_llava_onevision_video_tokens(ctx: InputContext, + num_frames: int) -> int: + hf_config = ctx.get_hf_config(LlavaOnevisionConfig) + + # TODO: support configuring (not supported by HF right now) + num_token_image_newline = 1 + tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config) + video_feature_size = num_frames * tokens_per_frame + num_token_image_newline + + return video_feature_size + + +def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int: + return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO) + + +def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(LlavaOnevisionConfig) + vision_config = hf_config.vision_config + + # TODO: support multiple videos + num_videos = mm_counts["video"] + if num_videos > _MAX_NUM_VIDEOS: + raise NotImplementedError( + f"Only {_MAX_NUM_VIDEOS} videos are supported") + + # TODO: support configuring the number of frames + num_frames = _MAX_FRAMES_PER_VIDEO + video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) + + if isinstance(vision_config, CLIPVisionConfig): + seq_data = dummy_seq_data_for_clip( + vision_config, + seq_len, + num_videos, + image_token_id=hf_config.video_token_index, + image_feature_size_override=video_feature_size, + ) + + mm_data = dummy_video_for_clip(vision_config, num_frames=num_frames) + return seq_data, mm_data + elif isinstance(vision_config, SiglipVisionConfig): + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + num_videos, + image_token_id=hf_config.video_token_index, + image_feature_size_override=video_feature_size, + ) + + mm_data = dummy_video_for_siglip(vision_config, num_frames=num_frames) + return seq_data, mm_data + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_when_multimodal_input_image(ctx: InputContext, + llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(LlavaOnevisionConfig) + vision_config = hf_config.vision_config + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + width, height = image_data.size + + image_feature_size = get_llava_onevision_image_feature_size( + hf_config, + input_height=height, + input_width=width, 
+ ) + elif is_list_of(image_data, Image.Image): + image_feature_size = [ + get_llava_onevision_image_feature_size(hf_config, + input_height=img.height, + input_width=img.width) + for img in image_data + ] + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + vision_config = hf_config.vision_config + + if isinstance(vision_config, CLIPVisionConfig): + return input_processor_for_clip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return input_processor_for_siglip( + model_config, + vision_config, + llm_inputs, + image_token_id=hf_config.image_token_index, + image_feature_size_override=image_feature_size, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_when_multimodal_input_video(ctx: InputContext, + llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "video" not in multi_modal_data: + return llm_inputs + video_data = multi_modal_data["video"] + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(LlavaOnevisionConfig) + vision_config = hf_config.vision_config + + if isinstance(video_data, np.ndarray): + # Supports both CLIP and Siglip + num_frames = video_data.shape[0] + video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=hf_config.video_token_index, + repeat_count=video_feature_size, + ) + + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + elif is_list_of(video_data, np.ndarray): + raise NotImplementedError( + "Processing multiple videos is not supported") + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +def input_processor_for_llava_onevision(ctx: InputContext, + llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or ("video" not in multi_modal_data + and "image" not in multi_modal_data): + return llm_inputs + if "image" in multi_modal_data: + return input_processor_when_multimodal_input_image(ctx, llm_inputs) + if "video" in multi_modal_data: + return input_processor_when_multimodal_input_video(ctx, llm_inputs) + + msg = "Unsupported multi data type" + raise NotImplementedError(msg) + + +def _init_vision_tower(hf_config: LlavaOnevisionConfig): + vision_config = hf_config.vision_config + + # Initialize the vision tower only up to the required feature layer + vision_feature_layer = hf_config.vision_feature_layer + if vision_feature_layer < 0: + num_hidden_layers = hf_config.vision_config.num_hidden_layers \ + + vision_feature_layer + 1 + else: + num_hidden_layers = vision_feature_layer + 1 + + if isinstance(vision_config, CLIPVisionConfig): + return CLIPVisionModel( + vision_config, + num_hidden_layers_override=num_hidden_layers, + ) + elif isinstance(vision_config, SiglipVisionConfig): + return SiglipVisionModel( + vision_config, + 
num_hidden_layers_override=num_hidden_layers, + ) + + msg = f"Unsupported vision config: {type(vision_config)}" + raise NotImplementedError(msg) + + +class LlavaOnevisionMultiModalProjector(nn.Module): + + def __init__(self, config: LlavaOnevisionConfig): + super().__init__() + + self.linear_1 = nn.Linear(config.vision_config.hidden_size, + config.text_config.hidden_size, + bias=True) + self.act = get_act_fn(config.projector_hidden_act) + self.linear_2 = nn.Linear(config.text_config.hidden_size, + config.text_config.hidden_size, + bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_input_mapper("video") +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "image", get_max_llava_onevision_image_tokens) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "video", get_max_llava_onevision_video_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision) +@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision) +class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, + config: LlavaOnevisionConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + # Initialize the vision tower only up to the required feature layer + self.vision_tower = _init_vision_tower(config) + self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) + + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_image_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. 
You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[LlavaOnevisionImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. " + f"Got type: {type(image_sizes)}") + + return LlavaOnevisionImagePixelInputs( + type="pixel_values", + data=self._validate_image_pixel_values( + flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True)), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeds. " + f"Got type: {type(image_embeds)}") + + return LlavaOnevisionImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _validate_video_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[2:]) + + if actual_dims != expected_dims: + expected_expr = ("num_frames", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values in each video frame " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_video_input( + self, + **kwargs: object) -> Optional[LlavaOnevisionVideoPixelInputs]: + """ + A legal video input should have the following dimensions: + { + "pixel_values_videos" : + List[b, Tensor(nb_frames, nb_channels, height, width)] + } + """ + pixel_values = kwargs.pop("pixel_values_videos", None) + + if pixel_values is None: + return None + + if not (is_list_of(pixel_values, + (torch.Tensor)) # different shape videos + or isinstance(pixel_values, + torch.Tensor)): # same shape videos + raise ValueError("Incorrect type of pixel values. 
" + f"Got type: {type(pixel_values)}") + + return LlavaOnevisionVideoPixelInputs( + type="pixel_values_videos", + data=pixel_values, + ) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + if "pixel_values" in kwargs: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + + if "pixel_values_videos" in kwargs: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + + return modalities + + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + image_features = vision_tower(pixel_values) + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + + # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py + def _merge_image_patch_embeddings(self, + image_size: torch.Tensor, + patch_embeddings: torch.Tensor, + *, + image_newline=None, + vision_aspect_ratio="anyres_max_9", + strategy: str) -> torch.Tensor: + if strategy == "flat": + return patch_embeddings.flatten(0, 1) + + if strategy.startswith("spatial"): + height = width = self.config.vision_config.image_size \ + // self.config.vision_config.patch_size + + base_patch_embeds = patch_embeddings[0] + if height * width != base_patch_embeds.shape[0]: + raise ValueError( + "The number of patches is not consistent with the " + "image size.") + + if patch_embeddings.shape[0] > 1: + other_patch_embeds = patch_embeddings[1:] + + # Move to CPU to avoid floating-point errors + orig_height, orig_width = image_size.tolist() + + # image_aspect_ratio == "anyres" + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + (orig_height, orig_width), + self.config.image_grid_pinpoints, + self.config.vision_config.image_size, + ) + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ + .view(num_patch_height, num_patch_width, height, width, -1) + + if "unpad" in strategy: + other_patch_embeds = other_patch_embeds \ + .permute(4, 0, 2, 1, 3).contiguous() \ + .flatten(1, 2).flatten(2, 3) + other_patch_embeds = unpad_image(other_patch_embeds, + (orig_height, orig_width)) + max_num_patches = int( + vision_aspect_ratio.removeprefix("anyres_max_")) + channels, curr_height, curr_width = other_patch_embeds.shape + ratio = math.sqrt(curr_height * curr_width / + (max_num_patches * height**2)) + if ratio > 1.1: + other_patch_embeds = other_patch_embeds[None] + other_patch_embeds = nn.functional.interpolate( + other_patch_embeds, [ + int(curr_height // ratio), + int(curr_width // ratio) + ], + mode="bilinear")[0] + if image_newline is not None: + other_patch_embeds = torch.cat( + ( + other_patch_embeds, + image_newline[:, None, None] \ + .expand(*other_patch_embeds.shape[:-1], 1) \ + .to(other_patch_embeds.device), + ), + dim=-1) + other_patch_embeds = other_patch_embeds \ + .flatten(1, 2).transpose(0, 1) + else: + other_patch_embeds = other_patch_embeds \ + .permute(0, 2, 1, 3, 
4).contiguous() \ + .flatten(0, 3) + + merged_patch_embeddings = torch.cat( + (base_patch_embeds, other_patch_embeds), dim=0) + else: + if "unpad" in strategy: + merged_patch_embeddings = torch.cat( + (base_patch_embeds, + self.image_newline[None] \ + .to(base_patch_embeds.device) + ), dim=0) + else: + merged_patch_embeddings = base_patch_embeds + + return merged_patch_embeddings + + raise ValueError(f"Unexpected patch merge strategy: {strategy}") + + def _process_image_pixels( + self, + inputs: LlavaOnevisionImagePixelInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + assert self.vision_tower is not None + + pixel_values = inputs["data"] + + if isinstance(pixel_values, torch.Tensor): + b, num_patches, c, h, w = pixel_values.shape + stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + stacked_patch_embeddings = self.multi_modal_projector( + stacked_image_features) + + return stacked_patch_embeddings.view( + b, num_patches, *stacked_patch_embeddings.shape[1:]) + + num_patches_per_batch = [v.shape[0] for v in pixel_values] + stacked_pixel_values = torch.cat(pixel_values) + stacked_image_features = self._image_pixels_to_features( + self.vision_tower, stacked_pixel_values) + + return [ + self.multi_modal_projector(image_features) for image_features in + torch.split(stacked_image_features, num_patches_per_batch) + ] + + def _process_image_input( + self, + image_input: LlavaOnevisionImageInputs, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if image_input["type"] == "image_embeds": + return [image_input["data"]] + + patch_embeddings = self._process_image_pixels(image_input) + + image_sizes = image_input.get("image_sizes") + if image_sizes is None: + batch_size = len(image_input["data"]) + vision_config = self.config.vision_config + default_height = default_width = vision_config.image_size + image_sizes = torch.as_tensor([[default_height, default_width] + for _ in range(batch_size)]) + + return [ + self._merge_image_patch_embeddings( + image_sizes[i], + patch_features_batch, + image_newline=self.image_newline, + strategy="spatial_unpad") + for i, patch_features_batch in enumerate(patch_embeddings) + ] + + def _video_pixels_to_features( + self, + vision_tower: Union[CLIPVisionModel, SiglipVisionModel], + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the vision tower + b, num_videos, frames, c, h, w = pixel_values.shape + assert (num_videos == _MAX_NUM_VIDEOS) + pixel_values = pixel_values.reshape(b * num_videos * frames, c, h, w) + video_features = vision_tower(pixel_values) + video_features = self._select_image_features( + video_features, + strategy=self.config.vision_feature_select_strategy, + ) + video_features = self.multi_modal_projector(video_features) + video_features = self.apply_pooling(video_features) + video_features = video_features.reshape( + b, frames * video_features.shape[1], -1) + image_newline = self.image_newline[None, None, :].repeat(b, 1, 1).to( + video_features.device) + video_features = torch.cat((video_features, image_newline), dim=1) + video_features = video_features.flatten(0, 1) + + return video_features + + def _process_video_pixels(self, inputs: LlavaOnevisionVideoPixelInputs): + assert self.vision_tower is not None + + video_pixels = inputs["data"] + + # TODO: support multiple videos per input + if isinstance(video_pixels, torch.Tensor): 
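+            # Batched path: `video_pixels` is expected to have shape
+            # (batch_size, num_videos, num_frames, channels, height, width);
+            # `_video_pixels_to_features` flattens the frames, projects and
+            # pools the features, and appends the `image_newline` embedding.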
+ stacked_embeddings = self._video_pixels_to_features( + self.vision_tower, video_pixels) + return stacked_embeddings + else: + raise ValueError( + f"Unsupported type of video input {type(video_pixels)}") + + def apply_pooling(self, image_features, stride=2): + vision_config = self.config.vision_config + height = width = vision_config.image_size // vision_config.patch_size + batch_frames, _, dim = image_features.shape + image_features = image_features.view(batch_frames, height, width, -1) + image_features = image_features.permute(0, 3, 1, 2) + + # TODO support other pooling types config + height, width = image_features.shape[2:] + scaled_shape = [math.ceil(height / stride), math.ceil(width / stride)] + image_feature = nn.functional.interpolate(image_features, + size=scaled_shape, + mode='bilinear') + image_feature = image_feature.permute(0, 2, 3, 1) + image_feature = image_feature.view(batch_frames, -1, dim) + return image_feature + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for LlaVA-Onevision. + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + pixel_values_videos: Pixels in each frames for each input videos. + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if modalities: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + if "images" in modalities: + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + if "videos" in modalities: + video_input = modalities["videos"] + video_embeddings = self._process_video_pixels(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, video_embeddings, + self.config.video_token_index) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py new file mode 100644 index 0000000..1112a21 --- /dev/null +++ b/vllm/model_executor/models/mamba.py @@ -0,0 +1,499 @@ +# coding=utf-8 +"""PyTorch MAMBA model.""" +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import MambaConfig + +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.distributed import 
get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + composed_weight_loader, default_weight_loader, sharded_weight_loader) +from vllm.model_executor.models.interfaces import (HasInnerState, + IsAttentionFree) +from vllm.model_executor.models.mamba_cache import MambaCacheManager +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors +from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, + _get_graph_batch_size) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +@dataclass +class MambaCacheParams: + is_prompt: bool = False + conv_state: torch.Tensor = torch.Tensor() + ssm_state: torch.Tensor = torch.Tensor() + + +# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer +class MambaMixer(nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__(self, config: MambaConfig, layer_idx): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel + self.intermediate_size = config.intermediate_size + self.time_step_rank = int(config.time_step_rank) + + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.intermediate_size, + bias=config.use_conv_bias, + ) + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `set_weight_attrs` + # doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + self.in_proj = MergedColumnParallelLinear(self.hidden_size, + [self.intermediate_size] * 2, + bias=config.use_bias) + # selective projection used to make dt, B and C input dependent + self.x_proj = RowParallelLinear( + self.intermediate_size, + self.time_step_rank + self.ssm_state_size * 2, + bias=False, + ) + # time step projection (discretization) - + # In the forward we need to apply dt_proj without the bias, + # as the bias is added in the selective scan kernel. 
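+        # `skip_bias_add=True` makes the projection return its bias
+        # separately instead of adding it; the forward pass reads it via
+        # `self.dt_proj.bias` and passes it to the scan kernels as
+        # `time_proj_bias`.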
+ self.dt_proj = ColumnParallelLinear(self.time_step_rank, + self.intermediate_size, + bias=True, + skip_bias_add=True) + + tp_size = get_tensor_model_parallel_world_size() + self.A = nn.Parameter( + torch.empty( + self.intermediate_size // tp_size, + self.ssm_state_size, + dtype=torch.float32, + )) + self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size)) + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float())) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + + self.out_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=config.use_bias, + input_is_parallel=True, + ) + self.activation = config.hidden_act + + def forward(self, hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, conv_state: torch.Tensor, + ssm_state: torch.Tensor): + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1) + hidden_states, gate = projected_states.chunk(2, dim=-2) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), + self.conv1d.weight.size(2)) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + hidden_states = causal_conv1d_fn( + hidden_states, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + hidden_states = causal_conv1d_update( + hidden_states.transpose(0, 1), + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + ) + hidden_states = hidden_states.transpose(0, 1) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0] + + time_step, B, C = torch.split( + ssm_parameters, + [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], + dim=-1, + ) + + # Note that Jamba normalizes B, C, and time_step here but Mamba doesn't. + + discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = (self.dt_proj.bias.float() if hasattr( + self.dt_proj, "bias") else None) + + if attn_metadata.query_start_loc is not None \ + and attn_metadata.context_lens_tensor is not None: + scan_outputs = selective_scan_fn( + hidden_states, + ssm_state, + discrete_time_step, + self.A, + B.transpose(-2, -1), + C.transpose(-2, -1), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + has_initial_state=attn_metadata.context_lens_tensor > 0, + query_start_loc=attn_metadata.query_start_loc) + else: + scan_outputs = selective_state_update( + ssm_state, + hidden_states.transpose(0, 1), + discrete_time_step.transpose(0, 1), + self.A, + B, + C, + self.D, + gate.transpose(0, 1), + time_proj_bias, + dt_softplus=True, + ) + scan_outputs = scan_outputs.transpose(0, 1) + + # 4. 
Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(-2, + -1))[0] + return contextualized_states + + +class MambaMLP(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + hidden_act = config.hidden_act + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MambaDecoderLayer(nn.Module): + + def __init__(self, + config: MambaConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.mixer = MambaMixer(config, layer_idx) + + self.feed_forward = MambaMLP(config, quant_config=quant_config) + self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + **kwargs, + ): + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer(hidden_states, attn_metadata, conv_state, + ssm_state) + # Fully Connected + hidden_states, residual = self.pre_ff_layernorm( + hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class MambaModel(nn.Module): + + def __init__( + self, + config: MambaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embeddings = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + decoder_layers = [] + for i in range(config.num_hidden_layers): + decoder_layers.append( + MambaDecoderLayer(config, + layer_idx=i, + cache_config=cache_config, + quant_config=quant_config)) + self.layers = nn.ModuleList(decoder_layers) + self.norm_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + conv_state: torch.Tensor, + ssm_state: torch.Tensor, + ) -> torch.Tensor: + hidden_states = self.embeddings(input_ids) + residual = None + + for i in range(len(self.layers)): + layer = self.layers[i] + current_ssm_state = ssm_state[i] + current_conv_state = conv_state[i] + + hidden_states, 
residual = layer( + positions=positions, + hidden_states=hidden_states, + attn_metadata=attn_metadata, + residual=residual, + conv_state=current_conv_state, + ssm_state=current_ssm_state, + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + + return hidden_states + + +class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embeddings": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MambaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, + ) -> None: + assert not cache_config.enable_prefix_caching, \ + "Mamba does not support prefix caching" + + super().__init__() + self.config = config + self.scheduler_config = scheduler_config + self.backbone = MambaModel(config, + cache_config=cache_config, + quant_config=quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + + self.lm_head = self.backbone.embeddings + + # Used to track and store by the Mamba cache between steps. + self.mamba_cache: Optional[MambaCacheManager] = None + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs): + if self.mamba_cache is None: + max_batch_size = (_get_graph_batch_size( + self.scheduler_config.max_num_seqs) if self.scheduler_config + else max(_BATCH_SIZES_TO_CAPTURE) + 2) + self.mamba_cache = MambaCacheManager( + self.lm_head.weight.dtype, self.config.num_hidden_layers, + max_batch_size, *self._get_mamba_cache_shape()) + + mamba_cache_tensors = self.mamba_cache.current_run_tensors( + input_ids, attn_metadata, **kwargs) + + hidden_states = self.backbone(input_ids, positions, kv_caches, + attn_metadata, mamba_cache_tensors[0], + mamba_cache_tensors[1]) + + return hidden_states + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs( + input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def _get_mamba_cache_shape( + self) -> Tuple[Tuple[int, int], Tuple[int, int]]: + world_size = get_tensor_model_parallel_world_size() + conv_state_shape = ( + self.config.intermediate_size // world_size, + self.config.conv_kernel - 1, + ) + temporal_state_shape = ( + self.config.intermediate_size // world_size, + self.config.state_size, + ) + return conv_state_shape, temporal_state_shape + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, 
sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + if "A_log" in name: + name = name.replace("A_log", "A") + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mamba_cache.py b/vllm/model_executor/models/mamba_cache.py new file mode 100644 index 0000000..8d1ba37 --- /dev/null +++ b/vllm/model_executor/models/mamba_cache.py @@ -0,0 +1,222 @@ +from typing import Dict, List, Optional + +import torch + +from vllm.attention.backends.abstract import AttentionMetadata + + +class MambaCacheManager: + + def __init__(self, dtype, num_mamba_layers, max_batch_size, + conv_state_shape, temporal_state_shape): + + conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + conv_state_shape, + dtype=dtype, + device="cuda") + temporal_state = torch.empty(size=(num_mamba_layers, max_batch_size) + + temporal_state_shape, + dtype=dtype, + device="cuda") + + self.mamba_cache = (conv_state, temporal_state) + + # Maps between the request id and a dict that maps between the seq_id + # and its index inside the self.mamba_cache + self.mamba_cache_indices_mapping: Dict[str, Dict[int, int]] = {} + + def current_run_tensors(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, **kwargs): + """ + Return the tensors for the current run's conv and ssm state. + """ + if "seqlen_agnostic_capture_inputs" not in kwargs: + # We get here only on Prefill/Eager mode runs + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + finished_requests_ids = kwargs["finished_requests_ids"] + + self._release_finished_requests(finished_requests_ids) + mamba_cache_tensors = self._prepare_current_run_mamba_cache( + request_ids_to_seq_ids, finished_requests_ids) + + else: + # CUDA graph capturing runs + mamba_cache_tensors = kwargs["seqlen_agnostic_capture_inputs"] + + return mamba_cache_tensors + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + """ + Copy the relevant Mamba cache into the CUDA graph input buffer + that was provided during the capture runs + (JambaForCausalLM.mamba_gc_cache_buffer). 
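+        Expects `request_ids_to_seq_ids` and `finished_requests_ids` to be
+        present in `kwargs`.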
+ """ + assert all( + key in kwargs + for key in ["request_ids_to_seq_ids", "finished_requests_ids"]) + finished_requests_ids = kwargs["finished_requests_ids"] + request_ids_to_seq_ids = kwargs["request_ids_to_seq_ids"] + + self._release_finished_requests(finished_requests_ids) + self._prepare_current_run_mamba_cache(request_ids_to_seq_ids, + finished_requests_ids) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + """ + Provide the CUDA graph capture runs with a buffer in adjusted size. + The buffer is used to maintain the Mamba Cache during the CUDA graph + replay runs. + """ + return tuple(buffer[:, :batch_size] for buffer in self.mamba_cache) + + def _swap_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, [to_index,from_index]] = \ + cache_t[:, [from_index,to_index]] + + def _copy_mamba_cache(self, from_index: int, to_index: int): + assert len(self.mamba_cache) > 0 + for cache_t in self.mamba_cache: + cache_t[:, to_index].copy_(cache_t[:, from_index], + non_blocking=True) + + def _move_out_if_already_occupied(self, index: int, + all_occupied_indices: List[int]): + if index in all_occupied_indices: + first_free_index = self._first_free_index_in_mamba_cache() + # In case occupied, move the occupied to a new empty block + self._move_cache_index_and_mappings(from_index=index, + to_index=first_free_index) + + def _assign_seq_id_to_mamba_cache_in_specific_dest(self, cur_rid: str, + seq_id: int, + destination_index: int): + """ + Assign (req_id,seq_id) pair to a `destination_index` index, if + already occupied, move the occupying index to a free index. + """ + all_occupied_indices = self._get_all_occupied_indices() + if cur_rid not in self.mamba_cache_indices_mapping: + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + self.mamba_cache_indices_mapping[cur_rid] = { + seq_id: destination_index + } + elif seq_id not in (seq_ids2indices := + self.mamba_cache_indices_mapping[cur_rid]): + # parallel sampling , where n > 1, assume prefill have + # already happened now we only need to copy the already + # existing cache into the siblings seq_ids caches + self._move_out_if_already_occupied( + index=destination_index, + all_occupied_indices=all_occupied_indices) + index_exists = list(seq_ids2indices.values())[0] + # case of decoding n>1, copy prefill cache to decoding indices + self._copy_mamba_cache(from_index=index_exists, + to_index=destination_index) + self.mamba_cache_indices_mapping[cur_rid][ + seq_id] = destination_index + else: + # already exists + cache_index_already_exists = self.mamba_cache_indices_mapping[ + cur_rid][seq_id] + if cache_index_already_exists != destination_index: + # In case the seq id already exists but not in + # the right destination, swap it with what's occupying it + self._swap_pair_indices_and_mappings( + from_index=cache_index_already_exists, + to_index=destination_index) + + def _prepare_current_run_mamba_cache( + self, request_ids_to_seq_ids: Dict[str, list[int]], + finished_requests_ids: List[str]): + running_indices = [] + request_ids_to_seq_ids_flatten = [ + (req_id, seq_id) + for req_id, seq_ids in request_ids_to_seq_ids.items() + for seq_id in seq_ids + ] + batch_size = len(request_ids_to_seq_ids_flatten) + for dest_index, (request_id, + seq_id) in enumerate(request_ids_to_seq_ids_flatten): + if request_id in finished_requests_ids: + # Do not allocate cache index for requests that run + # and finish 
right after + continue + self._assign_seq_id_to_mamba_cache_in_specific_dest( + request_id, seq_id, dest_index) + running_indices.append(dest_index) + + self._clean_up_first_bs_blocks(batch_size, running_indices) + conv_state = self.mamba_cache[0][:, :batch_size] + temporal_state = self.mamba_cache[1][:, :batch_size] + + return (conv_state, temporal_state) + + def _get_all_occupied_indices(self): + return [ + cache_idx + for seq_ids2indices in self.mamba_cache_indices_mapping.values() + for cache_idx in seq_ids2indices.values() + ] + + def _clean_up_first_bs_blocks(self, batch_size: int, + indices_for_current_run: List[int]): + # move out all of the occupied but currently not running blocks + # outside of the first n blocks + destination_indices = range(batch_size) + max_possible_batch_size = self.mamba_cache[0].shape[1] + for destination_index in destination_indices: + if destination_index in self._get_all_occupied_indices() and \ + destination_index not in indices_for_current_run: + # move not running indices outside of the batch + all_other_indices = list( + range(batch_size, max_possible_batch_size)) + first_avail_index = self._first_free_index_in_mamba_cache( + all_other_indices) + self._swap_indices(from_index=destination_index, + to_index=first_avail_index) + + def _move_cache_index_and_mappings(self, from_index: int, to_index: int): + self._copy_mamba_cache(from_index=from_index, to_index=to_index) + self._update_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_pair_indices_and_mappings(self, from_index: int, to_index: int): + self._swap_mamba_cache(from_index=from_index, to_index=to_index) + self._swap_mapping_index(from_index=from_index, to_index=to_index) + + def _swap_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + elif to_index == index: + seq_ids2index.update({seq_id: from_index}) + + def _update_mapping_index(self, from_index: int, to_index: int): + for seq_ids2index in self.mamba_cache_indices_mapping.values(): + for seq_id, index in seq_ids2index.items(): + if from_index == index: + seq_ids2index.update({seq_id: to_index}) + return + + def _release_finished_requests(self, + finished_seq_groups_req_ids: List[str]): + for req_id in finished_seq_groups_req_ids: + if req_id in self.mamba_cache_indices_mapping: + self.mamba_cache_indices_mapping.pop(req_id) + + def _first_free_index_in_mamba_cache( + self, indices_range: Optional[List[int]] = None) -> int: + assert self.mamba_cache is not None + if indices_range is None: + max_possible_batch_size = self.mamba_cache[0].shape[1] + indices_range = list(range(max_possible_batch_size)) + all_occupied_indices = self._get_all_occupied_indices() + for i in indices_range: + if i not in all_occupied_indices: + return i + raise Exception("Couldn't find a free spot in the mamba cache! 
This" + "should never happen") diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py new file mode 100644 index 0000000..619a5cd --- /dev/null +++ b/vllm/model_executor/models/medusa.py @@ -0,0 +1,184 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.transformers_utils.configs.medusa import MedusaConfig + + +class ResidualBlock(nn.Module): + + def __init__(self, hidden_size: int, num_layers: int) -> None: + super().__init__() + + self.layers = nn.ModuleList([ + nn.Linear(hidden_size, hidden_size, bias=False) + for _ in range(num_layers) + ]) + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for layer in self.layers: + x = x + self.act(layer(x)) + return x + + +class Medusa(nn.Module): + """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774 + Reference implementation: https://github.com/FasterDecoding/Medusa + + Differences from reference implementation: + 1. Currently this only supports generating proposals from top-1 tokens. + 2. We have an optional token_map which reduces draft vocab to most + frequently used tokens to give some additional speed-up by reducing + sampling overhead. This is disabled unless the checkpoint file has + explicit token_map tensor and config has an optional attribute + truncated_vocab_size < vocab_size. To use this technique, one has to find + the top-k most frequent tokens in target dataset and add that as a tensor + in the draft checkpoint (using key token_map). Also, the draft config + needs to have truncated_vocab_size (=k) as an attribute.""" + + def __init__(self, config: MedusaConfig, **_) -> None: + super().__init__() + self.config = config + self.blocks = nn.ModuleList([ + ResidualBlock(hidden_size=self.config.hidden_size, + num_layers=self.config.num_hidden_layers) + for _ in range(self.config.num_heads) + ]) + self.orig_vocab_size = config.vocab_size + self.truncated_vocab_size = config.truncated_vocab_size + self.unpadded_vocab_size = self.truncated_vocab_size + + self.lm_heads = nn.ModuleList([ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) for _ in range(self.config.num_heads) + ]) + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.truncated_vocab_size, + logit_scale) + + # Token map is a idx to token mapping to reduce the vocab size for + # the draft model. Using smaller vocab size for draft, containing + # only most frequent tokens reduces the speculation overhead. This + # doesn't affect the acceptance rate much and thus gives more speed + # -up. By default, this is disabled and is only used if the EAGLE + # checkpoint file has token_map tensor. 
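+        # Example with hypothetical sizes: for an original vocab of 128k
+        # tokens and truncated_vocab_size = 32k, token_map is a 32k-long
+        # tensor mapping draft-vocab index i to the original token id
+        # token_map[i]; compute_logits() scatters the 32k draft logits back
+        # into a 128k-wide tensor pre-filled with -inf.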
+ self.token_map = None + + def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: + return [block(hidden_states) for block in self.blocks] + + def compute_logits( + self, hidden_states: List[torch.Tensor], + sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: + logits_lst: List[torch.Tensor] = [] + + for hs, lm_head in zip(hidden_states, self.lm_heads): + _logits = self.logits_processor(lm_head, hs, sampling_metadata) + + if _logits is None: + # _logits should only be None on rank > 0, in which case + # it should remain true for every lm_head + assert len(logits_lst) == 0 + continue + + if self.token_map is None: + logits_lst.append(_logits) + else: + logits_lst.append(-torch.inf * torch.ones( + size=(*_logits.shape[:-1], self.orig_vocab_size), + device=_logits.device, + dtype=_logits.dtype)) + + logits_lst[-1][..., self.token_map] = _logits + + return logits_lst + + def sample( + self, + logits: List[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + logits = torch.stack(logits, dim=0).float() + logprobs = torch.log_softmax(logits, dim=-1) + token_ids = logits.argmax(-1) # support only top-1 for now + probs = torch.softmax(logits, dim=-1) + + token_id_list = [] + token_prob_list = [] + token_logprob_list = [] + + for idx, seq_group in enumerate(sampling_metadata.seq_groups): + token_id_list.append(token_ids[:, seq_group.sample_indices]) + token_prob_list.append(probs[:, seq_group.sample_indices]) + token_logprob_list.append(logprobs[:, seq_group.sample_indices]) + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(sampling_metadata.seq_groups)): + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx].squeeze(1), + logprobs=token_logprob_list[idx].squeeze(1), + sampled_token_ids=token_id_list[idx].squeeze(1), + )) + + return outputs + + def generate_proposals( + self, + previous_hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + return self.sample( + logits=self.compute_logits( + hidden_states=self.forward(previous_hidden_states), + sampling_metadata=sampling_metadata, + ), + sampling_metadata=sampling_metadata, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + + weights_map = {} + + for name, loaded_weight in weights: + name = name.replace("medusa_heads.", "") + + if name == "token_map": + if self.truncated_vocab_size < self.orig_vocab_size: + self.token_map = nn.Parameter(loaded_weight, + requires_grad=False) + elif name in params_dict: + weights_map[name] = loaded_weight + + for name, loaded_weight in weights_map.items(): + if "lm_head" in name and self.token_map is not None and\ + loaded_weight.shape[0] > self.token_map.shape[0]: + + loaded_weight = loaded_weight[self.token_map] + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if self.token_map is not None: + self.token_map.to(device=self.lm_heads[0].weight.device) + + assert (self.truncated_vocab_size + == self.orig_vocab_size) or (self.token_map is not None) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py new file mode 100644 index 0000000..41c2877 --- /dev/null +++ b/vllm/model_executor/models/minicpm.py @@ -0,0 +1,599 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# 
Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM model compatible with HuggingFace weights.""" +import math +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class MiniCPMMoE(nn.Module): + """A tensor-parallel MoE implementation that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. 
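+    With tensor parallel size tp, each rank holds ws of shape
+    [num_experts, 2 * intermediate_size // tp, hidden_size] and w2s of shape
+    [num_experts, hidden_size, intermediate_size // tp]; the all-reduce at
+    the end of forward() sums the partial expert outputs across ranks.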
+ """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + quant_config=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_size) + + +class MiniCPMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MiniCPMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + # set rope as fp32 instead of bf16 + self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache( + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + orig_dtype = q.dtype + q, k = q.float(), k.float() + q, k = self.rotary_emb(positions, q, k) + q, k = q.to(orig_dtype), k.to(orig_dtype) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + self.hidden_size = config.hidden_size + self.rope_theta = getattr(config, "rope_theta", 10000) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self._init_attn_block() + self._init_ffn_block() + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + 
eps=self.config.rms_norm_eps) + self.self_attn = MiniCPMAttention( + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + num_kv_heads=self.config.num_key_value_heads, + rope_theta=self.rope_theta, + rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + ) + + def _init_ffn_block(self): + self.post_attention_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.num_experts = getattr(self.config, "num_experts", 0) + if self.num_experts == 0: + self.mlp = MiniCPMMLP( + hidden_size=self.hidden_size, + intermediate_size=self.config.intermediate_size, + hidden_act=self.config.hidden_act, + quant_config=self.quant_config, + ) + else: + self.mlp = MiniCPMMoE( + num_experts=self.config.num_experts, + top_k=self.config.num_experts_per_tok, + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * \ + (self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)) + + return hidden_states, None + + +class MiniCPMModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.cache_config = cache_config + self.quant_config = quant_config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self._init_layers(prefix, config, cache_config, quant_config) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) + + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + embedding = self.embed_tokens(input_ids) + return embedding * self.config.scale_emb + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: 
AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + self.cache_config = cache_config + self.quant_config = quant_config + + self.num_experts = getattr(self.config, "num_experts", 0) + self._init_model() + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + self.scale_width = self.config.hidden_size / self.config.dim_model_base + + self.logits_processor = LogitsProcessor(unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def _init_model(self): + self.model = MiniCPMModel(config=self.config, + cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + hidden_states = hidden_states / self.scale_width + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + 
sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.num_experts) + for weight_name in ["w1", "w2", "w3"] + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py new file mode 100644 index 0000000..e4aaf95 --- /dev/null +++ b/vllm/model_executor/models/minicpm3.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2024 The ModelBest team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM3 model compatible with HuggingFace weights.""" +from typing import Any, Dict, Optional, Union, List, Tuple +import math + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, + MiniCPMForCausalLM, + MiniCPMModel) +from vllm.sequence import IntermediateTensors +from vllm.distributed import get_pp_group + +from .utils import make_layers + + +class MiniCPM3Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + + tp_size = get_tensor_model_parallel_world_size() + assert self.num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config) + + self.kv_a_proj_with_mqa = ReplicatedLinear(self.hidden_size, + self.kv_lora_rank + + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config) + # O projection. 
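+        # Projects the concatenated per-head value outputs
+        # (num_heads * v_head_dim) back to hidden_size.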
+ self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config) + + self.rotary_emb = get_rope( + self.qk_rope_head_dim, + rotary_dim=self.qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_local_heads, + self.qk_head_dim, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config) + self.merge_q_kv_a = False + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + long_prompt_offset: torch.Tensor, + long_short_cos_sin_cache: torch.Tensor, + ) -> torch.Tensor: + import ixformer.inference.functions as ixf + if hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16: + if not self.merge_q_kv_a: + self.qkv_weight = torch.cat([self.q_a_proj.weight, self.kv_a_proj_with_mqa.weight], dim=0) + del self.q_a_proj + del self.kv_a_proj_with_mqa + self.merge_q_kv_a = True + q_latent_cache = ixf.linear(hidden_states, self.qkv_weight) + q, latent_cache = q_latent_cache.split([self.q_lora_rank, + self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1) + else: + q, _ = self.q_a_proj(hidden_states) + latent_cache, _ = self.kv_a_proj_with_mqa(hidden_states) + + q = self.q_a_layernorm(q) + q, _ = self.q_b_proj(q) + q = q.view(-1, self.num_local_heads, self.qk_head_dim) + _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a) + kv, _ = self.kv_b_proj(kv_a) + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + q_pe, k_pe = ixf.minicpm3_fused_rope( + positions, + long_prompt_offset, + long_short_cos_sin_cache, + q_pe, latent_cache[:, :, self.kv_lora_rank:], + out_query = q[..., self.qk_nope_head_dim:] + ) + + q = q.view(-1, self.num_local_heads * self.qk_head_dim) + + k, v = ixf.minicpm3_fused_copy_kv(k_nope, k_pe, v) + + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim) + new_attn_output = attn_output.new_empty([attn_output.shape[0], attn_output.shape[1], self.v_head_dim]) + new_attn_output[:, :, :] = attn_output[:, :, :self.v_head_dim] + attn_output = new_attn_output.view(-1, self.num_local_heads * self.v_head_dim) + + output, _ = self.o_proj(attn_output) + return output + + +class MiniCPM3DecoderLayer(MiniCPMDecoderLayer): + def __init__(self, config: PretrainedConfig, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None) -> None: + super().__init__(config, cache_config, quant_config) + self.hidden_scale = config.scale_depth / math.sqrt(config.num_hidden_layers) + + def _init_attn_block(self): + self.input_layernorm = RMSNorm(self.config.hidden_size, + eps=self.config.rms_norm_eps) + self.self_attn = MiniCPM3Attention( + config=self.config, + hidden_size=self.hidden_size, + num_heads=self.config.num_attention_heads, + qk_nope_head_dim=self.config.qk_nope_head_dim, + qk_rope_head_dim=self.config.qk_rope_head_dim, + v_head_dim=self.config.v_head_dim, + q_lora_rank=self.config.q_lora_rank, + kv_lora_rank=self.config.kv_lora_rank, + rope_theta=self.rope_theta, + 
rope_scaling=self.rope_scaling, + max_position_embeddings=self.max_position_embeddings, + cache_config=self.cache_config, + quant_config=self.quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + long_prompt_offset: Optional[torch.Tensor], + long_short_cos_sin_cache: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(residual, hidden_states, self.hidden_scale) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + long_prompt_offset=long_prompt_offset, + long_short_cos_sin_cache=long_short_cos_sin_cache, + ) + + hidden_states, residual = self.post_attention_layernorm(residual, hidden_states, self.hidden_scale) + + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +class MiniCPM3Model(MiniCPMModel): + + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPM3DecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + k = self.layers[self.start_layer].self_attn.rotary_emb.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + long_short_cos_sin_cache = ( + self.layers[self.start_layer].self_attn.rotary_emb.long_short_cos_sin_cache.to(input_ids.device)) + + + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + long_prompt_offset=long_prompt_offset, + long_short_cos_sin_cache=long_short_cos_sin_cache, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, residual = self.norm(residual, hidden_states, self.layers[self.start_layer].hidden_scale) + + return hidden_states + + +class MiniCPM3ForCausalLM(MiniCPMForCausalLM): + packed_modules_mapping = { + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "kv_a_proj_with_mqa", + "q_a_proj", + "q_b_proj", + "kv_b_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + + # `embedding_modules` and `embedding_padding_modules` + # are inherited from MiniCPMForCausalLM + + def _init_model(self): + self.model = MiniCPM3Model(config=self.config, + 
cache_config=self.cache_config, + quant_config=self.quant_config, + lora_config=self.lora_config) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py new file mode 100644 index 0000000..9ee4dd0 --- /dev/null +++ b/vllm/model_executor/models/minicpmv.py @@ -0,0 +1,1044 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" +import math +import re +from functools import partial +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, + Tuple, TypedDict, Union) + +import torch +import torch.types +from PIL import Image +from torch import nn +from transformers import PretrainedConfig +from typing_extensions import NotRequired + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import (BaseResampler, Resampler2, + get_2d_sincos_pos_embed) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.minicpm import MiniCPMModel +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.models.utils import LLMWrapper +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import IntermediateTensors, SequenceData + +from .idefics2_vision_model import Idefics2VisionTransformer +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .utils import is_pp_missing_parameter + +_KEYS_TO_MODIFY_MAPPING = { + "llm.lm_head": "lm_head", +} + +RawImageType = Union[Image.Image, torch.Tensor] + + +class MiniCPMVRawImageInput(TypedDict): + """Input 
mapper input with auxiliary data for computing image bounds.""" + image: RawImageType + + # Image bounds token ids in 0-dim scaler tensor. + im_start_id: torch.Tensor + im_end_id: torch.Tensor + slice_start_id: NotRequired[torch.Tensor] + slice_end_id: NotRequired[torch.Tensor] + + +class MiniCPMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: List[torch.Tensor] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class MiniCPMVImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + instead of a batched tensor. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, + MiniCPMVImageEmbeddingInputs] + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +class Resampler2_5(BaseResampler): + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + ) -> None: + super().__init__(num_queries, embed_dim, num_heads, kv_dim, norm_layer) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) + + self.apply(self._init_weights) + + def _set_2d_pos_cache(self, + max_size: Tuple[int, int], + device: torch.types.Device = "cpu") -> None: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + max_size, + version=(2, 5)) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) + self.register_buffer("pos_embed", pos_embed, persistent=False) + + def _adjust_pos_cache(self, tgt_sizes: torch.Tensor, + device: torch.types.Device) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + + if max_h > self.max_size[0] or max_w > self.max_size[1]: + self.max_size = ( + max(max_h, self.max_size[0]), + max(max_w, self.max_size[1]), + ) + self._set_2d_pos_cache(self.max_size, device) + + def forward(self, x: torch.Tensor, + tgt_sizes: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros((bs, max_patch_len), + dtype=torch.bool, + device=device) + + pos_embed = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i].tolist() + pos_embed.append(self.pos_embed[:tgt_h, :tgt_w, :].reshape( + (tgt_h * tgt_w, -1)).to(dtype)) # patches * D + key_padding_mask[i, patch_len[i]:] = True + pos_embed = torch.nn.utils.rnn.pad_sequence(pos_embed, + batch_first=True, + padding_value=0.0).permute( + 1, 0, + 2) # BLD => L * B * D + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + + q = self.ln_q(self.query) # Q * D + + out = self.attn( + 
self._repeat(q, bs), # Q * B * D + x + pos_embed, # L * B * D + L * B * D + x, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + +def _build_image_input(ctx: InputContext, + image: RawImageType) -> MiniCPMVRawImageInput: + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + trust_remote_code=ctx.model_config.trust_remote_code) + if hasattr(tokenizer, "slice_start_id"): + return MiniCPMVRawImageInput( + image=image, + im_start_id=torch.tensor(tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id), + slice_start_id=torch.tensor(tokenizer.slice_start_id), + slice_end_id=torch.tensor(tokenizer.slice_end_id)) + else: + return MiniCPMVRawImageInput( + image=image, + im_start_id=torch.tensor(tokenizer.im_start_id), + im_end_id=torch.tensor(tokenizer.im_end_id)) + + +def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: + version_float = getattr(config, "version", None) + + # The old configs do not include version number + # TODO: Remove this after the HF repos are updated + if version_float is None: + if config.hidden_size == 2304 and config.query_num == 64: + return (2, 0) + return (2, 5) + + version_str = str(version_float) + return tuple(int(x) for x in version_str.split(".")) + + +def get_max_minicpmv_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config() + return getattr(hf_config, "query_num", 64) + + +def dummy_seq_data_for_minicpmv(seq_len: int, num_images: int): + return SequenceData.from_token_counts((0, seq_len)) + + +def dummy_image_for_minicpmv(ctx: InputContext, hf_config: PretrainedConfig, + num_images: int): + width = height = hf_config.image_size + image = _build_image_input(ctx, + image=Image.new("RGB", (width, height), + color=0)) + return {"image": [image] if num_images == 1 else [image] * num_images} + + +def dummy_data_for_minicpmv(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config() + num_images = mm_counts["image"] + + seq_data = dummy_seq_data_for_minicpmv(seq_len, num_images) + mm_data = dummy_image_for_minicpmv(ctx, hf_config, num_images) + + return seq_data, mm_data + + +def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + model_config = ctx.model_config + version = get_version_by_config(model_config.hf_config) + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + image_processor = cached_get_image_processor(model_config.tokenizer) + + def get_placeholder(image_size: Tuple[int, int], num_image: int): + if version == (2, 0) or version == (2, 5): + return image_processor. \ + get_slice_image_placeholder(image_size) + return image_processor. 
\ + get_slice_image_placeholder(image_size, num_image) + + prompt = llm_inputs.get("prompt") + token_ids = llm_inputs.get("prompt_token_ids") + if prompt is None: + prompt = tokenizer.decode(token_ids) + + pattern = "(./)" + images = multi_modal_data["image"] + image_tags = re.findall(pattern, prompt) + if len(image_tags) == 0: + new_token_ids = token_ids + new_prompt = prompt + else: + if isinstance(images, dict): + image_size_list = images.get("image_size_list") + images = [images.get("image_embeds")] + else: + if isinstance(images, Image.Image): + images = [images] + image_size_list = [image.size for image in images] + + text_chunks = prompt.split(pattern) + new_prompt_chunks: List[str] = [] + for i in range(len(image_size_list)): + new_prompt_chunks += [ + text_chunks[i], + get_placeholder(image_size_list[i], i) + ] + new_prompt_chunks.append(text_chunks[-1]) + new_prompt = "".join(new_prompt_chunks) + new_token_ids = tokenizer.encode(new_prompt) + + multi_modal_data["image"] = [ + _build_image_input(ctx, image) for image in images + ] + + llm_inputs = LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) + return llm_inputs + + +def input_mapper_for_minicpmv(ctx: InputContext, data: object): + model_config = ctx.model_config + + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + if not isinstance(data, list): + raise ValueError( + "Image input must be list of MiniCPMVImageInput, got (%s)", data) + + if len(data) > 0 and isinstance(data[0]['image'], torch.Tensor): + batch_data = { + "image_embeds": data[0]['image'], + } + else: + batch_data = image_processor \ + .preprocess([img["image"] for img in data], return_tensors="pt") \ + .data + + if len(data) > 0: + batch_data["im_start_id"] = data[0]["im_start_id"] + batch_data["im_end_id"] = data[0]["im_end_id"] + if "slice_start_id" in data[0]: + batch_data["slice_start_id"] = data[0]["slice_start_id"] + batch_data["slice_end_id"] = data[0]["slice_end_id"] + + return MultiModalInputs(batch_data) + + +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): + """ + The abstract class of MiniCPMV can only be inherited, but cannot be + instantiated. 
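+    Version-specific subclasses supply the language model, vision module and
+    resampler through init_llm(), init_vision_module() and init_resampler(),
+    and implement the vision-embedding hooks that are left as
+    NotImplementedError stubs here.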
+ """ + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # and config class + self.config = config + self.multimodal_config = multimodal_config + + self.version = get_version_by_config(self.config) + self.llm = self.init_llm(config, cache_config, quant_config) + self.vpm = self.init_vision_module() + param_dtype = torch.get_default_dtype() + self.vpm.to(dtype=param_dtype) + self.vision_dim = (self.vpm.embed_dim if self.version == (2, 0) else + self.vpm.embeddings.embed_dim) + self.embed_dim = self.config.hidden_size + self.resampler = self.init_resampler(self.embed_dim, self.vision_dim) + self.resampler.to(device="cuda", dtype=param_dtype) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + self.make_empty_intermediate_tensors = ( + self.llm.make_empty_intermediate_tensors) + + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + vlm_embedding *= self.config.scale_emb + + if image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) + else: + if image_inputs["type"] == "image_embeds": + vision_hidden_states = (image_inputs["data"].type( + vlm_embedding.dtype).to(vlm_embedding.device)) + else: + vision_hidden_states = self.get_vision_hidden_states( + image_inputs) + + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack([ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ]).to(vlm_embedding.device) + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, + vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, + vision_hidden_states.shape[-1]), + ) + + return vlm_embedding, vision_hidden_states + + def _get_image_bounds( + self, + input_ids: torch.Tensor, + im_start_id: torch.Tensor, + im_end_id: torch.Tensor, + slice_start_id: Optional[torch.Tensor] = None, + slice_end_id: Optional[torch.Tensor] = None) -> torch.Tensor: + # All the images in the batch should share the same special image + # bound token ids. 
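+        # Indexing [0] picks that shared id from the per-image tensors; start
+        # positions are shifted by one so each (start, end) bound covers only
+        # the image feature tokens between the marker tokens.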
+ start_cond = input_ids == im_start_id[0] + end_cond = input_ids == im_end_id[0] + if slice_start_id is not None: + start_cond |= (input_ids == slice_start_id[0]) + end_cond |= (input_ids == slice_end_id[0]) + + image_start_tokens, = torch.where(start_cond) + image_start_tokens += 1 + image_end_tokens, = torch.where(end_cond) + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + + if valid_image_nums == 0: + return torch.zeros((0, 2), device=input_ids.device) + + return torch.hstack([ + image_start_tokens[:valid_image_nums].unsqueeze(-1), + image_end_tokens[:valid_image_nums].unsqueeze(-1), + ]) + + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", []) + tgt_sizes = kwargs.pop("tgt_sizes", []) + im_start_id = kwargs.pop("im_start_id", None) + im_end_id = kwargs.pop("im_end_id", None) + slice_start_id = kwargs.pop("slice_start_id", None) + slice_end_id = kwargs.pop("slice_end_id", None) + image_embeds = kwargs.pop("image_embeds", None) + + if image_embeds is not None: + return MiniCPMVImageEmbeddingInputs( + image_bounds=self._get_image_bounds(input_ids, im_start_id, + im_end_id, slice_start_id, + slice_end_id), + data=image_embeds, + type="image_embeds", + ) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of target sizes. " + f"Got type: {type(tgt_sizes)}") + + if len(pixel_values) != len(tgt_sizes): + raise ValueError("Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. {len(tgt_sizes)}") + + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + if len(pixel_b) != len(tgt_b): + raise ValueError("Inconsistent N lengths, found: " + f"{len(pixel_b)} vs {len(tgt_b)}") + + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + pixel_values_flat += pixel_n + tgt_sizes_flat += tgt_n + + # NOTE: Input IDs does not contain image tokens during memory profiling, + # so we allow it to be empty + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError("Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. 
" + f"{len(tgt_sizes_flat)}") + + if len(pixel_values_flat) == 0: + return None + + if im_start_id is None: + return None + + return MiniCPMVImagePixelInputs( + image_bounds=self._get_image_bounds(input_ids, im_start_id, + im_end_id, slice_start_id, + slice_end_id), + data=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), + type="pixel_values", + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: Any, + ) -> torch.Tensor: + if intermediate_tensors is not None: + vlm_embeddings = None + else: + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + + output = self.llm( + input_ids=None, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=vlm_embeddings, + ) + return output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): + if key_to_modify in name: + name = name.replace(key_to_modify, new_key) + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + use_default_weight_loading = False + if self.is_default_weight_loading(name): + use_default_weight_loading = True + else: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if is_pp_missing_parameter( + name.replace(weight_name, param_name), self): + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + use_default_weight_loading = True + if use_default_weight_loading: + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field(language_model="llm", + connector="resampler", + tower_model="vpm") + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module(self) -> nn.Module: + raise NotImplementedError + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + raise NotImplementedError + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + raise NotImplementedError + + def is_default_weight_loading(self, name: str) -> bool: + raise NotImplementedError + + +class MiniCPMV2_0(MiniCPMVBaseModel): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 0) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + + return LLMWrapper(MiniCPMModel(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") + + def init_vision_module(self) -> nn.Module: + # TODO :refactor this vision model + try: + import timm + except ImportError: + raise ImportError("Please install timm==0.9.10") from ImportError + with set_default_torch_dtype(torch.float16): + model = timm.create_model( + "vit_so400m_patch14_siglip_384.webli", + pretrained=False, + num_classes=0, + dynamic_img_size=True, + dynamic_img_pad=True, + ) + + if (isinstance(model, timm.models.VisionTransformer) + and model.attn_pool is not None): + model.attn_pool = torch.nn.Identity() + + if self.config.drop_vision_last_layer: + model.blocks = model.blocks[:-1] + + return model + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_tokens(input_ids) + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2( + embed_dim=embed_dim, + num_heads=embed_dim // 128, + grid_size=int(math.sqrt(self.config.query_num)), + kv_dim=vision_dim, + adaptive=False, + do_post_projection=True, + ) + + return resampler + + def get_vision_embedding( + self, + pixel_values: 
List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + res = [] + dtype = self.vpm.pos_embed.data.dtype + for pixel_value in pixel_values: + H, W = pixel_value[0].shape[-2:] + tgt_size = ( + math.ceil(H / self.vpm.patch_embed.patch_size[0]), + math.ceil(W / self.vpm.patch_embed.patch_size[0]), + ) + vision_embedding = self.vpm.forward_features( + pixel_value.unsqueeze(0).type(dtype)) + if (hasattr(self.vpm, "num_prefix_tokens") + and self.vpm.num_prefix_tokens > 0): + vision_embedding = vision_embedding[:, self.vpm. + num_prefix_tokens:] + res.append(self.resampler(vision_embedding, tgt_size)) + return torch.vstack(res) + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["data"] + + return self.get_vision_embedding(pixel_values) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name or "vpm" in name + + +class MiniCPMV2_5(MiniCPMVBaseModel, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # resampler + "kv_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__(config, multimodal_config, cache_config, quant_config) + assert self.version == (2, 5) + + def init_llm( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return LLMWrapper(LlamaModel(config, + cache_config=cache_config, + quant_config=quant_config), + name="model") + + def init_vision_module(self) -> nn.Module: + model = Idefics2VisionTransformer(self.config.vision_config) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + return model + + def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: + with set_default_torch_dtype(torch.float16): + resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm(pixel_values, + patch_attention_mask=patch_attn_mask) + vision_embedding = self.resampler(vision_embedding, tgt_sizes) + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["data"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, 
batch_first=True, padding_value=0.0)
+        B, L, _ = all_pixel_values.shape
+        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+        patch_attn_mask = torch.zeros((B, 1, max_patches),
+                                      dtype=torch.bool,
+                                      device=device)
+        for i in range(B):
+            patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+        return self.get_vision_embedding(all_pixel_values.type(dtype),
+                                         patch_attn_mask, tgt_sizes)
+
+    def is_default_weight_loading(self, name: str) -> bool:
+        return "resampler" in name
+
+
+class MiniCPMV2_6(MiniCPMVBaseModel, SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        "out_proj",
+        # language model
+        "qkv_proj",  # same name as vision encoder
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        # resampler
+        "kv_proj",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        multimodal_config: MultiModalConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__(config, multimodal_config, cache_config, quant_config)
+        assert self.version == (2, 6)
+
+    def init_llm(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> nn.Module:
+
+        return LLMWrapper(Qwen2Model(config,
+                                     cache_config=cache_config,
+                                     quant_config=quant_config),
+                          name="model")
+
+    def init_vision_module(self) -> nn.Module:
+
+        model = Idefics2VisionTransformer(self.config.vision_config)
+        if self.config.drop_vision_last_layer:
+            model.encoder.layers = model.encoder.layers[:-1]
+        return model
+
+    def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module:
+        with set_default_torch_dtype(torch.float16):
+            # The resampler in 2.6 remains consistent with the one in 2.5.
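+            # Resampler2_5 cross-attends a fixed set of config.query_num
+            # learnable queries over the variable-length vision features, so
+            # the language model always receives a fixed number of image
+            # tokens per slice.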
+ resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + ) + return resampler + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm( + pixel_values, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ) + return vision_embedding + + def get_vision_hidden_states(self, + data: MiniCPMVImageInputs) -> torch.Tensor: + pixel_values = data["data"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, + 1).reshape(B, 3, -1, L) + + patch_attn_mask = torch.zeros((B, 1, max_patches), + dtype=torch.bool, + device=device) + for i in range(B): + patch_attn_mask[i, 0, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes) + + def is_default_weight_loading(self, name: str) -> bool: + return "resampler" in name + + +_SUPPORT_VERSION = { + (2, 0): MiniCPMV2_0, + (2, 5): MiniCPMV2_5, + (2, 6): MiniCPMV2_6 +} + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_minicpmv) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_minicpmv_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_minicpmv) +@INPUT_REGISTRY.register_input_processor(input_processor_for_minicpmv) +class MiniCPMV(MiniCPMVBaseModel, SupportsLoRA): + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in vLLM. Therefore, it is necessary to separate them. + """ + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. 
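+    # The concrete values live on the per-version classes above
+    # (MiniCPMV2_5 / MiniCPMV2_6); this wrapper is never instantiated
+    # itself, __new__ below only dispatches to one of them.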
+ packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + def __new__(cls, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None): + if not hasattr(config, "version"): + if config.hidden_size == 2304 and config.query_num == 64: + version = (2, 0) + else: + version = (2, 5) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + instance_class = _SUPPORT_VERSION.get(version) + if instance_class is None: + raise ValueError( + "Currently, MiniCPMV only supports versions 2.0, 2.5, and 2.6") + return instance_class(config, multimodal_config, cache_config, + quant_config) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py new file mode 100644 index 0000000..ff5d03f --- /dev/null +++ b/vllm/model_executor/models/mixtral.py @@ -0,0 +1,493 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Mixtral model.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import MixtralConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import CompressedTensorsConfig, CompressionFormat + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class MixtralMoE(nn.Module): + """A tensor-parallel MoE implementation for Mixtral that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
+ orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class MixtralAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.block_sparse_moe = MixtralMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + 
intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + if (isinstance(quant_config, CompressedTensorsConfig) + and quant_config.quant_format == CompressionFormat.int_quantized.value + and quant_config.target_scheme_map['Linear']['input_activations'].num_bits == 8 + and quant_config.target_scheme_map['Linear']['weights'].num_bits == 8): + self.use_int_w8a8 = True + else: + self.use_int_w8a8 = False + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + from ixformer.contrib.vllm.layers import mixtral_decoder_layer_forward + if self.use_int_w8a8: + return mixtral_decoder_layer_forward( + self, + positions, + hidden_states, + kv_cache, + attn_metadata, + residual + ) + + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = 
self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "embed_tokens", "lm_head", "w1", "w2", "w3", + "gate" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = MixtralModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py new file mode 100644 index 0000000..63e2c60 --- /dev/null +++ b/vllm/model_executor/models/mixtral_quant.py @@ -0,0 +1,441 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Mixtral model.""" +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import MixtralConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class MixtralMLP(nn.Module): + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + quant_config=quant_config) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + quant_config=quant_config) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralMoE(nn.Module): + + def __init__( + self, + config: MixtralConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.num_total_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + self.gate = ReplicatedLinear(config.hidden_size, + self.num_total_experts, + bias=False, + quant_config=None) + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + num_tokens, hidden_dim) + + +class MixtralAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + quant_config: Optional[QuantizationConfig] = None, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
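+            # e.g. 2 KV heads with tp_size = 8: each rank keeps
+            # max(1, 2 // 8) = 1 local KV head, and each KV head is
+            # replicated across 8 / 2 = 4 ranks.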
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config) + self.block_sparse_moe = MixtralMoE(config=config, + quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda 
prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module, SupportsPP): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: MixtralConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = MixtralModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py new file mode 100644 index 0000000..034a018 --- /dev/null +++ b/vllm/model_executor/models/mllama.py @@ -0,0 +1,1151 @@ +# coding=utf-8 +# Copyright 2024 the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Mllama model.""" +import math +from array import array +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers.models.mllama.configuration_mllama as config_mllama +from PIL import Image +from torch import nn +from transformers.modeling_outputs import (BaseModelOutput, + CausalLMOutputWithPast) +from transformers.models.mllama.image_processing_mllama import ( + get_optimal_tiled_canvas) + +import vllm.distributed.parallel_state as ps +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData + +from .clip import CLIPMLP +from .interfaces import SupportsMultiModal +from .llama import LlamaDecoderLayer, LlamaMLP + +logger = init_logger(__name__) +MLLAMA_IMAGE_TOKEN_ID = 128256 +MLLAMA_IMAGE_TOKEN = "<|image|>" + + +class MllamaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: """ + 
"""(batch_size, max_num_image, max_num_chunk, num_channel, height, width)""" + aspect_ratio_ids: torch.Tensor + """Shape: `(batch_size, max_num_image)`""" + aspect_ratio_mask: torch.Tensor + """Shape: `(batch_size, max_num_image, max_num_tiles)`""" + + +# TODO: support LlamaImageEmbeddingInputs + + +def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs): + # move encoder_prompt to prompt + if llm_inputs.get("prompt") is None: + llm_inputs["prompt"] = llm_inputs["encoder_prompt"] + llm_inputs["prompt_token_ids"] = llm_inputs["encoder_prompt_token_ids"] + + # process multi-modal data + assert "decoder_multi_modal_data" not in llm_inputs, \ + "multi-modal data should be put in encoder message of mllama" + multi_modal_data = llm_inputs.get("encoder_multi_modal_data") + + if multi_modal_data is None or "image" not in multi_modal_data \ + or multi_modal_data["image"] is None: + # text-only + llm_inputs["encoder_prompt"] = "" + llm_inputs["encoder_prompt_token_ids"] = [] + llm_inputs["encoder_multi_modal_data"] = {} + return llm_inputs + + # get num_tiles + if isinstance(multi_modal_data['image'], Image.Image): + multi_modal_data['image'] = [multi_modal_data['image']] + hf_config = ctx.model_config.hf_config + num_tiles = 0 + for image in multi_modal_data["image"]: + width, height = image.size + tile_size = hf_config.vision_config.image_size + canvas_height, canvas_width = get_optimal_tiled_canvas( + image_height=height, + image_width=width, + max_image_tiles=hf_config.vision_config.max_num_tiles, + tile_size=tile_size, + ) + num_tiles_height = canvas_height // tile_size + num_tiles_width = canvas_width // tile_size + num_tiles += num_tiles_height * num_tiles_width + + # set encoder prompt based on num_tiles + assert hf_config.vision_config.image_size % 14 == 0, \ + "chunk size should be multiple of 14" + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + num_tokens = num_tiles * token_per_chunk + llm_inputs["encoder_prompt"] = MLLAMA_IMAGE_TOKEN * num_tokens + llm_inputs["encoder_prompt_token_ids"] = [MLLAMA_IMAGE_TOKEN_ID + ] * num_tokens + + return llm_inputs + + +def get_max_mllama_image_tokens(ctx: InputContext) -> int: + hf_config = ctx.model_config.hf_config + token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1 + return hf_config.vision_config.max_num_tiles * token_per_chunk + + +def dummy_decoder_seq_data(seq_len: int, num_images: int): + # <|image|> * num_images + 0 * (seq_len - num_images) + assert seq_len >= num_images, \ + "seq_len should be greater than or equal to num_images" + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_images + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - num_images) + return SequenceData(token_ids) + + +def dummy_encoder_seq_data(ctx: InputContext, num_images: int): + num_tokens = get_max_mllama_image_tokens(ctx) * num_images + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [MLLAMA_IMAGE_TOKEN_ID]) * num_tokens + return SequenceData(token_ids) + + +def dummy_image(num_images: int, ): + width = height = 1024 + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_decoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + return dummy_decoder_seq_data(seq_len, num_images), None + + +def dummy_encoder_data_for_mllama(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + num_images = mm_counts["image"] + 
return dummy_encoder_seq_data(ctx, num_images), dummy_image(num_images) + + +def _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask: torch.Tensor, + num_patches: int, + target_length: int, + dtype: torch.dtype, +) -> torch.Tensor: + # Expand aspect ratio mask to target_length + batch_size, max_num_tiles = aspect_ratio_mask.shape + attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, + 1).to(dtype) + attention_mask = attention_mask.repeat(1, 1, target_length, 1) + + # Mask padding patches + pad_patches = target_length - num_patches + attention_mask[:, :, -pad_patches:] = 0 + + # Invert the mask (0 -> 1, 1 -> 0) + attention_mask = 1 - attention_mask + + # Reshape to 2D and create 4D attention mask + # (batch_size, 1, max_num_tiles*target_length, max_num_tiles*target_length) + attention_mask = attention_mask.reshape(batch_size, + max_num_tiles * target_length, 1) + attention_mask = attention_mask @ attention_mask.transpose( + -1, -2) * torch.finfo(dtype).min + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + +class ColumnParallelConv2dPatch(torch.nn.Module): + """Conv2D Patching layer with model parallelism. + Column parallel over unfolded input. + Arguments: + in_channels: Input channels. + out_channels: Output channels. + kernel_size: Size of convolution kernel. + stride (default 1): Stride for convolution. + bias (default False): Use bias in Conv2d. + Input: (bsz, in_channels, width, height) + Output: (bsz, num_tokens, out_channels) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + bias: bool = False, + ) -> None: + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride) + self._linear = ColumnParallelLinear( + in_channels * kernel_size[0] * kernel_size[1], + out_channels, + bias=bias, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self._unfold(x) + x = x.permute(0, 2, 1).contiguous() + x, _ = self._linear(x) + return x + + +class MllamaPrecomputedAspectRatioEmbedding(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = True): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.is_gated = is_gated + + self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.hidden_size) + if is_gated: + self.gate = nn.Parameter(torch.zeros(1)) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + embeddings = self.embedding(aspect_ratio_ids) + embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, + self.hidden_size) + + if self.is_gated: + embeddings = embeddings * self.gate.tanh() + + hidden_state = hidden_state + embeddings + return hidden_state + + +class MllamaPrecomputedPositionEmbedding(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.max_num_tiles = config.max_num_tiles + self.max_aspect_ratio_id = config.max_aspect_ratio_id + self.num_patches = (config.image_size // config.patch_size)**2 + 1 + self.hidden_size = config.hidden_size + self.scale = config.hidden_size**-0.5 + + self.gate = nn.Parameter(torch.zeros(1)) + + # position embedding + position_embedding = torch.randn(self.num_patches, self.hidden_size) + 
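+        # forward() blends the two position signals through `gate`: the
+        # learned per-patch embedding below is weighted by (1 - tanh(gate)),
+        # the precomputed per-tile embedding by tanh(gate).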
self.embedding = nn.Parameter(self.scale * position_embedding) + + # tile position embedding + self.tile_embedding = nn.Embedding( + self.max_aspect_ratio_id + 1, + self.max_num_tiles * self.num_patches * self.hidden_size) + + def forward(self, hidden_state: torch.Tensor, + aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + # position embeddings + gated_position_embedding = (1 - self.gate.tanh()) * self.embedding + hidden_state = hidden_state + gated_position_embedding.view( + 1, 1, self.num_patches, self.hidden_size) + + # precomputed tile position embeddings + tile_position_embedding = self.tile_embedding(aspect_ratio_ids) + batch_size = hidden_state.shape[0] + tile_position_embedding = tile_position_embedding.reshape( + batch_size, self.max_num_tiles, self.num_patches, self.hidden_size) + gated_tile_position_embedding = self.gate.tanh( + ) * tile_position_embedding + hidden_state = hidden_state + gated_tile_position_embedding + + return hidden_state + + +# TODO: support other attention backends for attention in vision model +class MllamaVisionSdpaAttention(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + + model_parallel_size = get_tensor_model_parallel_world_size() + self.embed_dim = config.hidden_size + self.num_heads = config.attention_heads + self.head_dim = config.hidden_size // config.attention_heads + self.num_local_heads = self.num_heads // model_parallel_size + self.q_size = self.num_local_heads * self.head_dim + self.kv_size = self.num_local_heads * self.head_dim + self.scale = 1 / math.sqrt(self.head_dim) + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=False, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.embed_dim, + bias=False, + input_is_parallel=True, + ) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_state) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q = q.view(q.shape[0], q.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + k = k.view(k.shape[0], k.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + v = v.view(v.shape[0], v.shape[1], self.num_local_heads, + self.head_dim).transpose(1, 2) + + # TODO: remove padding in image encoder + attn_output = torch.empty_like(q) + for i in range(q.shape[0]): + attn = q[i] @ (k[i] * self.scale).permute(0,2,1) + if attention_mask is not None: + attn = attn + attention_mask[i] + attn = torch.softmax(attn, dim=-1) + + output = attn @ v[i] + attn_output[i] = output + # attn_output = F.scaled_dot_product_attention(q, + # k, + # v, + # attn_mask=attention_mask, + # dropout_p=0.0) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(attn_output.shape[0], + attn_output.shape[1], -1) + output, _ = self.o_proj(attn_output) + return output + + +class MllamaVisionEncoderLayer(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + is_gated: bool = False): + super().__init__() + + self.hidden_size = config.hidden_size + self.num_attention_heads = config.attention_heads + self.is_gated = is_gated + self.intermediate_size = config.intermediate_size + + self.self_attn = MllamaVisionSdpaAttention(config) + self.mlp = CLIPMLP(config) + + self.input_layernorm = nn.LayerNorm(self.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, + 
eps=config.norm_eps) + + # there used to be an if else here, no code path + if is_gated: + self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4) + self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4) + + def forward( + self, + hidden_state: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + # Self Attention + residual = hidden_state + hidden_state = self.input_layernorm(hidden_state) + hidden_state = self.self_attn(hidden_state, + attention_mask=attention_mask) + gate_attn = 1 if not self.is_gated else self.gate_attn.tanh() + hidden_state = residual + gate_attn * hidden_state + + # Feed forward + residual = hidden_state + hidden_state = self.post_attention_layernorm(hidden_state) + hidden_state = self.mlp(hidden_state) + gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh() + hidden_state = residual + gate_ffn * hidden_state + + return hidden_state + + +class MllamaVisionEncoder(nn.Module): + + def __init__(self, + config: config_mllama.MllamaVisionConfig, + num_layers=32, + is_gated=False, + output_hidden_states=None): + super().__init__() + self.config = config + self.layers = nn.ModuleList([ + MllamaVisionEncoderLayer(config, is_gated) + for _ in range(num_layers) + ]) + self.output_hidden_states = output_hidden_states or [] + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> Union[Tuple, BaseModelOutput]: + encoder_states = () + + for i, encoder_layer in enumerate(self.layers): + if i in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + hidden_states = encoder_layer( + hidden_states, + attention_mask, + ) + + if len(self.layers) - 1 in self.output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + return hidden_states, encoder_states + + +class MllamaVisionModel(nn.Module): + + def __init__(self, config: config_mllama.MllamaVisionConfig): + super().__init__() + self.image_size = config.image_size + self.patch_size = config.patch_size + self.max_num_tiles = config.max_num_tiles + self.hidden_size = config.hidden_size + self.in_channels = config.num_channels + self.intermediate_layers_indices = config.intermediate_layers_indices + + self.num_patches = (self.image_size // self.patch_size)**2 + 1 + self.scale = config.hidden_size**-0.5 + + self.patch_embedding = ColumnParallelConv2dPatch( + in_channels=config.num_channels, + out_channels=self.hidden_size, + kernel_size=self.patch_size, + stride=self.patch_size, + bias=False, + ) + + self.class_embedding = nn.Parameter(self.scale * + torch.randn(self.hidden_size)) + self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding( + config) + + self.pre_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + self.post_tile_positional_embedding = \ + MllamaPrecomputedAspectRatioEmbedding(config, is_gated=True) + + # layer norms + self.layernorm_pre = nn.LayerNorm(self.hidden_size) + self.layernorm_post = nn.LayerNorm(self.hidden_size) + + # encoders + self.transformer = MllamaVisionEncoder( + config, + config.num_hidden_layers, + is_gated=False, + output_hidden_states=config.intermediate_layers_indices) + self.global_transformer = MllamaVisionEncoder(config, + config.num_global_layers, + is_gated=True) + + def apply_class_embedding(self, + hidden_state: torch.Tensor) -> torch.Tensor: + batch_size, _, hidden_size = hidden_state.shape + class_embedding = self.class_embedding.expand(batch_size, 1, + hidden_size) + hidden_state = 
torch.cat([class_embedding, hidden_state], dim=1) + return hidden_state + + def forward(self, pixel_values: torch.Tensor, + aspect_ratio_ids: torch.Tensor, + aspect_ratio_mask: torch.Tensor) -> torch.Tensor: + batch_size, num_concurrent_media, num_tiles, num_channels, \ + height, width = pixel_values.shape + + pixel_values = pixel_values.reshape( + batch_size * num_concurrent_media * num_tiles, num_channels, + height, width) + aspect_ratio_ids = aspect_ratio_ids.reshape( + batch_size * num_concurrent_media, -1) + + # patch embedding + patch_embeds = self.patch_embedding( + pixel_values.to(self.layernorm_pre.weight.dtype)) + hidden_state = patch_embeds + hidden_state = ps.get_tp_group().all_gather(hidden_state) + + # tile embeddings + _, num_patches, dim = hidden_state.shape + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, -1, dim) + hidden_state = self.pre_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + + # apply cls token + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media * num_tiles, num_patches, dim) + hidden_state = self.apply_class_embedding(hidden_state) + num_patches += 1 + + # apply position embeddings + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, num_patches, dim) + hidden_state = self.gated_positional_embedding(hidden_state, + aspect_ratio_ids) + + # apply encoder + hidden_state = self.layernorm_pre(hidden_state) + + # Compute the number of tokens to pad + num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8 + # Compute padding tuple for pad function + padding = ( + 0, 0, 0, num_padding_patches + ) # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2) + # Pad the tensor + hidden_state = F.pad(hidden_state, padding, mode="constant", value=0) + slice_index = -num_padding_patches if num_padding_patches > 0 else None + + attention_mask = aspect_ratio_mask.reshape( + batch_size * num_concurrent_media, -1) + attention_mask = _prepare_aspect_ratio_attention_mask( + aspect_ratio_mask=attention_mask, + num_patches=self.num_patches, + target_length=hidden_state.shape[2], + dtype=self.layernorm_pre.weight.dtype, + ) + + hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, + dim) + output = self.transformer( + hidden_state, + attention_mask=attention_mask, + ) + hidden_state, intermediate_hidden_states = output[0], output[1] + intermediate_hidden_states = torch.stack(intermediate_hidden_states, + dim=-1) + + # apply global encoder + hidden_state = self.layernorm_post(hidden_state) + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = self.post_tile_positional_embedding( + hidden_state, aspect_ratio_ids) + hidden_state = hidden_state.reshape( + batch_size * num_concurrent_media, + num_tiles * (num_patches + num_padding_patches), dim) + hidden_state = self.global_transformer( + hidden_state, attention_mask=attention_mask)[0] + hidden_state = hidden_state.reshape(batch_size * num_concurrent_media, + num_tiles, + num_patches + num_padding_patches, + dim) + hidden_state = hidden_state[:, :, :slice_index] + + # adding intermediate layer outputs + hidden_state = hidden_state.reshape(batch_size, num_concurrent_media, + num_tiles, num_patches, dim) + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size * num_concurrent_media, num_tiles, + num_patches + num_padding_patches, -1) + intermediate_hidden_states = 
intermediate_hidden_states[:, :, : + slice_index] + intermediate_hidden_states = intermediate_hidden_states.reshape( + batch_size, num_concurrent_media, num_tiles, num_patches, -1) + hidden_state = torch.cat([hidden_state, intermediate_hidden_states], + dim=-1) + return hidden_state + + +class MllamaTextRMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + """ + MllamaTextRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class MllamaTextCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: Optional[config_mllama.MllamaTextConfig] = None, + layer_idx: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.model_parallel_size = get_tensor_model_parallel_world_size() + self.num_heads = self.config.num_attention_heads + self.num_local_heads = self.num_heads // self.model_parallel_size + self.num_key_value_heads = self.config.num_key_value_heads + self.num_local_key_value_heads = \ + self.num_key_value_heads // self.model_parallel_size + self.dropout = config.dropout + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // self.num_heads + self.layer_idx = layer_idx + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.q_local_size = self.num_local_heads * self.head_dim + self.kv_local_size = self.num_local_key_value_heads * self.head_dim + + # TODO: change to Q/KV separate linear after #7448 is merged + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.num_heads * self.head_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + quant_config=quant_config, + ) + # vllm.model_executor.layers.layernorm.RMSNorm has precision issue, + # use huggingface's instead + self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.scaling = self.head_dim**-0.5 + + self.attn = Attention( + self.num_local_heads, + self.head_dim, + self.scaling, + self.num_local_key_value_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + cross_attention_states: Optional[torch.Tensor], + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv_dec, _ = self.qkv_proj(hidden_states) + q, _, _ = qkv_dec.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + if cross_attention_states is None: + k = None + v = None + else: + qkv_enc, _ = self.qkv_proj(cross_attention_states) + _, k, v = qkv_enc.split( + [self.q_local_size, self.kv_local_size, self.kv_local_size], + dim=-1) + k = k.view(-1, self.num_local_key_value_heads, self.head_dim) + v = v.view(-1, self.num_local_key_value_heads, self.head_dim) + k = self.k_norm(k) + q = q.view(-1, 
self.num_local_heads, self.head_dim) + q = self.q_norm(q) + + output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=AttentionType.ENCODER_DECODER) + out, _ = self.o_proj(output) + return out + + +class MllamaCrossAttentionDecoderLayer(torch.nn.Module): + """Cross-attention transformer block with tanh-gated attention + and feedforward.""" + + def __init__(self, config: config_mllama.MllamaTextConfig, layer_idx: int, + quant_config: Optional[QuantizationConfig]) \ + -> None: + super().__init__() + self.layer_idx = layer_idx + self.cross_attn = MllamaTextCrossAttention( + config=config, + layer_idx=layer_idx, + quant_config=quant_config, + ) + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1)) + + self.mlp = LlamaMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + + def forward( + self, + hidden_states: torch.Tensor, + cross_attention_states: torch.Tensor, + cross_attention_mask: torch.Tensor, + full_text_row_masked_out_mask: torch.Tensor, + kv_cache: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.cross_attn( + hidden_states=hidden_states, + attention_mask=cross_attention_mask, + cross_attention_states=cross_attention_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_attn_gate.tanh( + ) * hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = full_text_row_masked_out_mask * hidden_states + hidden_states = residual + self.cross_attn_mlp_gate.tanh( + ) * hidden_states + return hidden_states + + +class MllamaTextModel(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "model" + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding(config.vocab_size + 8, + config.hidden_size) + self.cross_attention_layers = config.cross_attention_layers + + layers = [] + for layer_idx in range(config.num_hidden_layers): + if layer_idx in self.cross_attention_layers: + layers.append( + MllamaCrossAttentionDecoderLayer( + config, layer_idx, quant_config=quant_config)) + else: + # TODO: force LlamaDecoderLayer to config.attention_bias=False + layers.append( + LlamaDecoderLayer(config, + cache_config=cache_config, + quant_config=quant_config)) + + self.layers = nn.ModuleList(layers) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> 
torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + hidden_states = inputs_embeds + + for idx, decoder_layer in enumerate(self.layers): + if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): + if not skip_cross_attention: + hidden_states = decoder_layer( + hidden_states=hidden_states, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask= + full_text_row_masked_out_mask, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + ) + elif isinstance(decoder_layer, LlamaDecoderLayer): + hidden_states, residual = decoder_layer( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_caches[idx], + attn_metadata=attn_metadata, + residual=None, + ) + hidden_states = hidden_states + residual + else: + raise ValueError( + f"Unknown decoder layer type {type(decoder_layer)}") + hidden_states = self.norm(hidden_states) + return hidden_states + + +class MllamaForCausalLM(nn.Module): + config_class = config_mllama.MllamaTextConfig + base_model_prefix = "language_model" + _no_split_modules = [ + "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer" + ] + + def __init__(self, config: config_mllama.MllamaTextConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig]): + super().__init__() + self.vocab_size = config.vocab_size + self.model = MllamaTextModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + cross_attention_states: Optional[torch.LongTensor], + cross_attention_mask: Optional[torch.LongTensor], + full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, + torch.Tensor]], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + skip_cross_attention: bool, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, + ) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_mllama_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama) +@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama) +@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) +class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): + + def __init__(self, + config: config_mllama.MllamaConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.vocab_size = config.text_config.vocab_size + self.hidden_size = config.text_config.hidden_size + self.max_num_tiles = config.vision_config.max_num_tiles + self.vision_output_dim = config.vision_config.vision_output_dim + self.pad_token_id = \ + config.pad_token_id if config.pad_token_id is not None else -1 + self.image_size = config.vision_config.image_size + + self.vision_model = MllamaVisionModel(config.vision_config) + self.language_model = MllamaForCausalLM( + config.text_config, + 
cache_config=cache_config, + quant_config=quant_config, + ) + self.multi_modal_projector = nn.Linear( + config.vision_config.vision_output_dim, + config.text_config.hidden_size, + bias=True, + ) + self.logits_processor = LogitsProcessor(config.output_hidden_states, + config.text_config.vocab_size) + self.sampler = Sampler() + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.language_model.lm_head, + hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def _parse_and_validate_image_input(self, **kwargs: object): + # tensor with the same shape will be batched together by + # MultiModalInputs.batch, so pixel_values here can be: + # - List[List[torch.Tensor]]: + # with shape (num_tiles, 3, image_res, image_res) + # - List[torch.Tensor]: + # with shape (num_image, num_tiles, 3, image_res, image_res) + # - torch.Tensor: + # with shape (bs, num_image, num_tiles, 3, image_res, image_res) + pixel_values: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "pixel_values", None) + image_embeds: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "image_embeds", None) + aspect_ratio_ids: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_ids", None) + aspect_ratio_mask: Optional[Union[List[List[torch.Tensor]], + List[torch.Tensor], + torch.Tensor]] = kwargs.pop( + "aspect_ratio_mask", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None and image_embeds is not None: + raise ValueError( + "Both pixel values and image embeds are provided.") + + if pixel_values is not None: + assert aspect_ratio_ids is not None + assert aspect_ratio_mask is not None + max_num_images = max([len(x[0]) for x in pixel_values]) + if max_num_images == 0: + raise ValueError("No images provided.") + max_num_tiles = max( + max([len(x) for x in y[0]]) for y in pixel_values) + device = self.multi_modal_projector.weight.device + bsz = len(pixel_values) + out_num_tiles = [] + out_images = torch.zeros( + bsz, + max_num_images, + max_num_tiles, + 3, + self.image_size, + self.image_size, + dtype=torch.float32, + device=device, + ) + out_ar_ids = torch.ones(bsz, + max_num_images, + dtype=torch.int64, + device=device) + out_ar_mask = torch.zeros(bsz, + max_num_images, + max_num_tiles, + dtype=torch.int64, + device=device) + for b in range(len(pixel_values)): + _num_tiles = [] + for i in range(len(pixel_values[b][0])): + img = pixel_values[b][0][i] + out_images[b, i, :img.shape[0]] = img + out_ar_ids[b, i] = aspect_ratio_ids[b][0][i] + out_ar_mask[b, i] = aspect_ratio_mask[b][0][i] + _num_tiles.append(img.shape[0]) + out_num_tiles.append(_num_tiles) + + return MllamaImagePixelInputs( + type="pixel_values", + data=out_images, + aspect_ratio_ids=out_ar_ids, + aspect_ratio_mask=out_ar_mask, + ) + + if image_embeds is not None: + raise NotImplementedError + + raise AssertionError("This line should be unreachable.") + + def flat_encoder_result(self, cross_attention_states: torch.Tensor, + attn_metadata: AttentionMetadata): + + cross_attention_states_flat = torch.zeros( + sum(attn_metadata.encoder_seq_lens), + 
cross_attention_states.shape[-1], + device=cross_attention_states.device, + dtype=cross_attention_states.dtype) + start_pos = 0 + for seq_len, vision_token_in_batch in zip( + attn_metadata.encoder_seq_lens, cross_attention_states): + end_pos = start_pos + seq_len + cross_attention_states_flat[ + start_pos:end_pos] = vision_token_in_batch[:seq_len] + start_pos = end_pos + cross_attention_states = cross_attention_states_flat + + full_text_row_masked_out_mask = torch.ones( + (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) + start_pos = 0 + for seq_len, encoder_seq_len in zip( + attn_metadata.seq_lens_tensor.cpu(), + attn_metadata.encoder_seq_lens): + if encoder_seq_len == 0: + full_text_row_masked_out_mask[start_pos:start_pos + + seq_len] = False + start_pos += seq_len + full_text_row_masked_out_mask = full_text_row_masked_out_mask.to( + cross_attention_states.device) + + return cross_attention_states, full_text_row_masked_out_mask + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> Union[Tuple, CausalLMOutputWithPast]: + if attn_metadata.num_prefill_tokens > 0 and \ + attn_metadata.num_decode_tokens > 0: + raise ValueError("Chunk prefill not supported") + image_inputs = self._parse_and_validate_image_input(**kwargs) + if image_inputs is None: + cross_attention_mask = None + full_text_row_masked_out_mask = ( + attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to( + input_ids.device) + cross_attention_states = None + skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 + else: + # NOTE: llama's reference implementation runs vision model on CPU + pixel_values = image_inputs['data'] + aspect_ratio_ids = image_inputs['aspect_ratio_ids'] + aspect_ratio_mask = image_inputs['aspect_ratio_mask'] + cross_attention_states = self.vision_model(pixel_values, + aspect_ratio_ids, + aspect_ratio_mask) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + + bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + + cross_attention_states, full_text_row_masked_out_mask = \ + self.flat_encoder_result(cross_attention_states, attn_metadata) + skip_cross_attention = False + # TODO: support multi-image by this mask + cross_attention_mask = None + + outputs = self.language_model( + input_ids=input_ids, + positions=positions, + cross_attention_states=cross_attention_states, + cross_attention_mask=cross_attention_mask, + full_text_row_masked_out_mask=full_text_row_masked_out_mask, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + skip_cross_attention=skip_cross_attention, + ) + + return outputs + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + updated_params = set() + for name, loaded_weight in weights: + if 'patch_embedding.weight' in name: + name = name.replace('patch_embedding.weight', + 'patch_embedding._linear.weight') + loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, 
param_name) + param = params_dict[name] + updated_params.add(name) + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.pop(name) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py new file mode 100644 index 0000000..42ccd01 --- /dev/null +++ b/vllm/model_executor/models/mlp_speculator.py @@ -0,0 +1,197 @@ +import math +from typing import Iterable, List, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.transformers_utils.configs import MLPSpeculatorConfig + +SQRT2 = 2**0.5 + + +class MLPSpeculatorLayerNorm(nn.Module): + """ + A L2 normalization implementation + ... + Args + ---- + normalized_shape : int + Dimensionality of input data (size of final tensor axis) + eps : float + Safety term to prevent division by zero. Make sure the chosen value + fits in the range of your encoding scheme + (i.e. fp16 requires eps >= 6e-8). + elementwise_scale_and_shift : bool + Include a learned scaling and shift term after normalization. + """ + + def __init__( + self, + normalized_shape, + eps=1e-06, + elementwise_scale_and_shift=True, + ): + super(MLPSpeculatorLayerNorm, self).__init__() + self.elementwise_scale_and_shift = elementwise_scale_and_shift + if self.elementwise_scale_and_shift: + self.weight = nn.Parameter(torch.empty(normalized_shape)) + self.bias = nn.Parameter(torch.empty(normalized_shape)) + self.eps = eps + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + if self.elementwise_scale_and_shift: + x = self.weight * x + x = x + self.bias + return x + + +class MLPSpeculator(nn.Module): + """ + An implementation of the speculative models introduced in + "Accelerating Production LLMs with Combined Token/Embedding + Speculators" + https://arxiv.org/pdf/2404.19124 + + Trained speculators of this type are available on HF hub at: + https://huggingface.co/ibm-fms and https://huggingface.co/ibm-granite + """ + + def __init__(self, config: MLPSpeculatorConfig, **kwargs) -> None: + super().__init__() + self.n_predict = config.n_predict + self.vocab_size = config.vocab_size + self.emb_dim = config.emb_dim + self.inner_dim = config.inner_dim if config.inner_dim != 0 \ + else config.emb_dim + + self.max_speculative_tokens = config.num_lookahead_tokens + + self.tie_weights = config.tie_weights + self.scale_input = config.scale_input + + if self.tie_weights: + assert ( + self.n_predict > + 1), "You cannot tie weights between stages when only 1 exists" + embedding = VocabParallelEmbedding( + config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size) + self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens) + + # the initial projection from the base model may + # have a different size, so that stays separate. 
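+            # Note on the tying below: `nn.ModuleList([m] * k)` registers the same
+            # module object k times, so every speculative stage shares one set of
+            # weights; only the first projection stays distinct because it maps
+            # from the base model's emb_dim into inner_dim.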
+ proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False) + proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False) + self.proj = nn.ModuleList([proj_first] + [proj_tied] * + (self.max_speculative_tokens - 1)) + + head = ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) + self.head = nn.ModuleList([head] * self.max_speculative_tokens) + + ln = MLPSpeculatorLayerNorm(self.inner_dim, + elementwise_scale_and_shift=True) + self.ln = nn.ModuleList([ln] * self.max_speculative_tokens) + + else: + self.emb = nn.ModuleList([ + VocabParallelEmbedding(config.vocab_size, + self.inner_dim, + org_num_embeddings=config.vocab_size) + for _ in range(self.max_speculative_tokens) + ]) + + self.proj = nn.ModuleList([ + nn.Linear((self.emb_dim if i == 0 else self.inner_dim), + self.inner_dim, + bias=False) + for i in range(self.max_speculative_tokens) + ]) + + self.head = nn.ModuleList([ + ParallelLMHead(self.vocab_size, self.inner_dim, bias=False) + for _ in range(self.max_speculative_tokens) + ]) + self.ln = nn.ModuleList([ + MLPSpeculatorLayerNorm(self.inner_dim, + elementwise_scale_and_shift=True) + for _ in range(self.max_speculative_tokens) + ]) + if self.scale_input: + self.ln0 = MLPSpeculatorLayerNorm( + self.emb_dim, elementwise_scale_and_shift=False) + + self.state_weight = 0.5**(0.5 / config.n_predict) + self.emb_weight = math.sqrt( + (1 - self.state_weight**2) * (self.inner_dim / 2)) + self.activation = nn.GELU() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + config.vocab_size, 1.0) + self.sampler = Sampler() + + def generate_proposals( + self, + input_ids: torch.Tensor, + previous_hidden_states: torch.Tensor, + num_predict_tokens: int, + sampling_metadata: SamplingMetadata, + ) -> List[SamplerOutput]: + if num_predict_tokens > self.max_speculative_tokens: + raise ValueError(f"Max speculative tokens for model is " + f"{self.max_speculative_tokens}, but " + f"{num_predict_tokens} were requested") + + # b x 1 x d + previous_hidden_states = previous_hidden_states.unsqueeze(1) + + if self.scale_input: + previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2 + + # b x 1 + last_tokens = input_ids.unsqueeze(1) + + next_tokens = [] + + for head_index in range(num_predict_tokens): + + # Project and predict + z = self.emb[head_index](last_tokens) # b k d + states = self.proj[head_index](previous_hidden_states) + + # Weighted add of state_weight*state and emb_weight*z + # Let subsequent LN take care of denominator + # state_weight is close to 1, so shouldn't be any precision issues + states.add_(z, alpha=self.emb_weight / self.state_weight) + + states = self.activation(self.ln[head_index](states)) # b k d + previous_hidden_states = states + # TODO: not yet supporting top_k_tokens_per_head + states = states.flatten(0, 1) + + logits = self.logits_processor(self.head[head_index], states, + sampling_metadata) + + output = self.sampler(logits, sampling_metadata) + last_tokens = output.sampled_token_ids + next_tokens.append(output) + + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + param = params_dict.get(name.replace("speculator.", "")) + if param is not None: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/module_mapping.py b/vllm/model_executor/models/module_mapping.py new file mode 100644 index 
0000000..a9102a6 --- /dev/null +++ b/vllm/model_executor/models/module_mapping.py @@ -0,0 +1,69 @@ +# Adapted from +# https://github.com/modelscope/ms-swift/blob/v2.4.2/swift/utils/module_mapping.py + +from dataclasses import dataclass, field +from typing import List, Union + + +@dataclass +class ModelKeys: + model_type: str = None + + module_list: str = None + + embedding: str = None + + mlp: str = None + + down_proj: str = None + + attention: str = None + + o_proj: str = None + + q_proj: str = None + + k_proj: str = None + + v_proj: str = None + + qkv_proj: str = None + + qk_proj: str = None + + qa_proj: str = None + + qb_proj: str = None + + kva_proj: str = None + + kvb_proj: str = None + + output: str = None + + +@dataclass +class MultiModelKeys(ModelKeys): + language_model: List[str] = field(default_factory=list) + connector: List[str] = field(default_factory=list) + # vision tower and audio tower + tower_model: List[str] = field(default_factory=list) + generator: List[str] = field(default_factory=list) + + @staticmethod + def from_string_field(language_model: Union[str, List[str]] = None, + connector: Union[str, List[str]] = None, + tower_model: Union[str, List[str]] = None, + generator: Union[str, List[str]] = None, + **kwargs) -> 'MultiModelKeys': + + def to_list(value): + if value is None: + return [] + return [value] if isinstance(value, str) else list(value) + + return MultiModelKeys(language_model=to_list(language_model), + connector=to_list(connector), + tower_model=to_list(tower_model), + generator=to_list(generator), + **kwargs) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py new file mode 100644 index 0000000..ccfee16 --- /dev/null +++ b/vllm/model_executor/models/molmo.py @@ -0,0 +1,1290 @@ +import logging +import math +import re +from array import array +from dataclasses import dataclass +from functools import lru_cache, partial +from typing import (Any, Iterable, List, Mapping, Optional, Tuple, TypedDict, + Union) + +import torch +from einops import rearrange +from PIL import Image +from torch import nn +from torch.nn import functional as F +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.attention import Attention, AttentionMetadata +from vllm.attention.selector import (_Backend, backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal 
+from vllm.model_executor.models.utils import make_layers +from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs +from vllm.platforms import current_platform +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) +from vllm.transformers_utils.processor import get_processor + +log = logging.getLogger(__name__) + +# TODO: hard-coded for now. Consider making it configurable. +VIT_LAYERS = [-2, -9] +NUM_PREFIX_TOKENS = 1 +ADDITIONAL_VOCAB_SIZE = 128 + + +class MolmoImageInputs(TypedDict): + images: torch.Tensor + """Shape: + `(batch_size, num_crops, num_patch, patch_dim)` + """ + + image_input_idx: torch.Tensor + """Shape: + `(batch_size, num_crops, num_patch)` + """ + + seq_len: torch.Tensor + """Shape: + `(batch_size, )` + """ + + image_masks: Optional[torch.Tensor] + """Shape: + `(batch_size, num_crops, num_patch)` + """ + + +@dataclass +class VisionBackboneConfig: + image_default_input_size: Tuple[int, int] = (336, 336) + image_patch_size: int = 14 + image_pos_patch_size: int = 14 + image_emb_dim: int = 1024 + image_num_heads: int = 16 + image_num_key_value_heads: int = 16 + image_num_layers: int = 23 + image_mlp_dim: int = 4096 + image_mlp_activations: str = "quick_gelu" + image_num_pos: int = 577 + image_norm_eps: float = 1e-5 + + def __post_init__(self): + self.image_default_input_size = tuple( + self.image_default_input_size) # type: ignore[assignment] + + @property + def image_num_patch(self): + h, w = self.image_default_input_size + return h // self.image_patch_size, w // self.image_patch_size + + +class ViTMLP(nn.Module): + """MLP used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.w1 = ColumnParallelLinear( + config.image_emb_dim, + config.image_mlp_dim, + bias=True, + quant_config=quant_config, + ) + # Activation function. 
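+        # The vision config is expected to request "quick_gelu", i.e. roughly
+        # x * sigmoid(1.702 * x); the assert below fails fast instead of silently
+        # substituting a different activation.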
+ assert config.image_mlp_activations == "quick_gelu" + self.act = QuickGELU() + self.w2 = RowParallelLinear( + config.image_mlp_dim, + config.image_emb_dim, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.w1(x) + x = self.act(x) + x, _ = self.w2(x) + return x + + +class MultiHeadDotProductAttention(nn.Module): + """Multi-head attention used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + use_bias: bool = True, + nlayers: int = 1, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.hidden_size = config.image_emb_dim + self.total_num_heads = config.image_num_heads + tp_size = get_tensor_model_parallel_world_size() + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tp_size == 0 + + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + + self.total_num_kv_heads = config.image_num_key_value_heads + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0 + else: + assert tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.wq = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wk = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wv = ColumnParallelLinear( + nlayers * self.hidden_size, + self.total_num_kv_heads * self.head_dim, + bias=use_bias, + quant_config=quant_config, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=use_bias, + quant_config=quant_config, + ) + + # Detect attention implementation. + selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + if selected_backend is None: + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + # For Volta and Turing GPUs, use xformers instead. + device_available = current_platform.get_device_capability()[0] >= 8 + if device_available: + from transformers.utils import is_flash_attn_2_available + if is_flash_attn_2_available(): + self._use_flash_attn = True + else: + log.warning( + "Current Molmo implementation has a bug with " + "`vllm-flash-attn` inside vision module, so we use " + "xformers backend instead. 
You can run `pip install " + "flash-attn to use flash-attention backend.") + self._use_flash_attn = False + else: + self._use_flash_attn = False + else: + if selected_backend == _Backend.FLASH_ATTN: + self._use_flash_attn = True + elif selected_backend == _Backend.XFORMERS: + self._use_flash_attn = False + else: + raise RuntimeError( + f"Molmo does not support {selected_backend} backend now.") + + def forward(self, + inputs_q: torch.Tensor, + inputs_kv: Optional[torch.Tensor] = None) -> torch.Tensor: + + if inputs_kv is not None: + inputs_k = inputs_kv + inputs_v = inputs_kv + else: + inputs_k = inputs_q + inputs_v = inputs_q + + xq, _ = self.wq(inputs_q) + xk, _ = self.wk(inputs_k) + xv, _ = self.wv(inputs_v) + q_shape = xq.size()[:-1] + (self.num_heads, self.head_dim) + kv_shape = xk.size()[:-1] + (self.num_kv_heads, self.head_dim) + xq = xq.view(*q_shape) + xk = xk.view(*kv_shape) + xv = xv.view(*kv_shape) + + if self._use_flash_attn: + from flash_attn import flash_attn_func + output = flash_attn_func(xq, xk, xv, dropout_p=0.0, causal=False) + else: + from xformers import ops as xops + output = xops.memory_efficient_attention_forward(xq, xk, xv, p=0) + + output = rearrange(output, "b s h d -> b s (h d)").contiguous() + output, _ = self.wo(output) + + return output + + +class ResidualAttentionBlock(nn.Module): + """Residual attention block used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.attention = MultiHeadDotProductAttention( + config, quant_config=quant_config) + self.feed_forward = ViTMLP(config, quant_config) + self.attention_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + self.ffn_norm = nn.LayerNorm( + config.image_emb_dim, + eps=config.image_norm_eps, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attention(self.attention_norm(x)) + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + +class BlockCollection(nn.Module): + """Collection of residual attention blocks used in Vision Transformer.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.resblocks = nn.ModuleList([ + ResidualAttentionBlock(config, quant_config) + for _ in range(config.image_num_layers) + ]) + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + hidden_states = [] + for r in self.resblocks: + x = r(x) + hidden_states.append(x) + return hidden_states + + +def _expand_token(token: torch.Tensor, batch_size: int) -> torch.Tensor: + return token.view(1, 1, -1).expand(batch_size, -1, -1) + + +class VisionTransformer(nn.Module): + """Vision Transformer used in Vision Backbone.""" + + def __init__( + self, + config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + scale = config.image_emb_dim**-0.5 + self.patch_num = config.image_num_patch + self.class_embedding = nn.Parameter( + torch.randn(config.image_emb_dim) * scale) + self.num_prefix_tokens: int = NUM_PREFIX_TOKENS + self.positional_embedding = nn.Parameter( + torch.randn(config.image_num_pos, config.image_emb_dim) * scale) + image_patch_size = config.image_patch_size + self.patch_embedding = nn.Linear( + image_patch_size * image_patch_size * 3, + config.image_emb_dim, + bias=False, + ) + self.pre_ln = nn.LayerNorm(config.image_emb_dim, + eps=config.image_norm_eps) + self.transformer = BlockCollection(config, 
quant_config) + + def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor: + cls_emb = self.positional_embedding[0:1] + pos_emb = self.positional_embedding[1:] + + pos_emb = pos_emb.reshape( + (int(math.sqrt(pos_emb.shape[0])), + int(math.sqrt(pos_emb.shape[0])), pos_emb.shape[1])) + + (patch_num_0, patch_num_1) = patch_num + + if pos_emb.shape[0] != patch_num_0 or pos_emb.shape[1] != patch_num_1: + # from https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + pos_emb = pos_emb.unsqueeze(0).permute(0, 3, 1, 2) + pos_emb = F.interpolate( + pos_emb, + size=(patch_num_0, patch_num_1), + mode="bicubic", + align_corners=False, + antialias=True, + ) + pos_emb = pos_emb.permute(0, 2, 3, 1).squeeze(0) + + pos_emb = pos_emb.reshape(-1, pos_emb.shape[-1]) + x = x + torch.cat([cls_emb[None, :, :], pos_emb[None, :, :]], + dim=1).to(x.dtype) + return x + + def forward(self, + x: torch.Tensor, + patch_num: int = None) -> List[torch.Tensor]: + """ + : param x: (batch_size, num_patch, n_pixels) + """ + if patch_num is None: + patch_num = self.patch_num + B, N, D = x.shape + + x = self.patch_embedding(x) + + # class embeddings and positional embeddings + x = torch.cat( + [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], + dim=1) + x = self.add_pos_emb(x, patch_num) + + x = self.pre_ln(x) + + hidden_states = self.transformer(x) + return hidden_states + + +class MolmoAttention(nn.Module): + """Molmo's LLM attention.""" + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % self.tp_size == 0 + + self.num_heads = self.total_num_heads // self.tp_size + self.total_num_kv_heads = config.num_key_value_heads \ + or self.total_num_heads + if self.total_num_kv_heads >= self.tp_size: + assert self.total_num_kv_heads % self.tp_size == 0 + else: + assert self.tp_size % self.total_num_kv_heads == 0 + + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.tp_rank: Optional[int] = None + self.k_norm: Optional[nn.Module] = None + self.q_norm: Optional[nn.Module] = None + if config.attention_layer_norm: + self.tp_rank = get_tensor_model_parallel_rank() + self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim, + eps=config.layer_norm_eps) + self.q_norm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + + # Rotary embeddings. 
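+        # get_rope builds standard rotary position embeddings: q and k are rotated
+        # per position over head_dim with frequency base rope_theta, which is how
+        # relative position information enters the attention scores.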
+ self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def _apply_qk_norm(self, q: torch.Tensor, + k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + return q, k + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.q_norm is not None and self.k_norm is not None: + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MolmoMLP(nn.Module): + """Molmo's LLM mlp.""" + + def __init__( + self, + config: PretrainedConfig, + input_dim: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + input_dim or self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class MolmoDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # Attention block. + self.self_attn = MolmoAttention(config, cache_config, quant_config) + + # MLP block. 
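+        # MolmoMLP (defined above) is a SwiGLU feed-forward: a fused gate/up
+        # projection followed by SiluAndMul, then a down projection back to
+        # hidden_size; config.intermediate_size is halved there so the fused
+        # gate+up output keeps the original total width.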
+ self.mlp = MolmoMLP(config, quant_config=quant_config) + + # LayerNorm + assert config.layer_norm_type == "rms" + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class MolmoDecoderNormAfterLayer(MolmoDecoderLayer): + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Self Attention + residual = hidden_states + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states = self.input_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = hidden_states + + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = hidden_states + residual + residual = None + return hidden_states, residual + + +class MolmoVisionBackbone(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + vision_config: VisionBackboneConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.vit_layers = VIT_LAYERS + self.image_num_patch = vision_config.image_num_patch + self.llm_patches_per_crop = ( + (self.image_num_patch[0] + 1) // 2, + (self.image_num_patch[1] + 1) // 2, + ) + self.image_vit = VisionTransformer(vision_config, + quant_config=quant_config) + self.num_prefix_tokens = self.image_vit.num_prefix_tokens + assert self.num_prefix_tokens in { + 0, 1 + }, "Only 0 or 1 prefix tokens are supported" + self.image_pooling_2d = MultiHeadDotProductAttention( + vision_config, + nlayers=len(self.vit_layers), + quant_config=quant_config) + self.image_projector = MolmoMLP( + config, + input_dim=vision_config.image_emb_dim, + quant_config=quant_config, + ) + + image_dim = vision_config.image_emb_dim * len(self.vit_layers) + self.pad_embed = nn.Parameter(torch.zeros((2, image_dim))) + + @property + def dtype(self) -> torch.dtype: + return self.image_vit.patch_embedding.weight.dtype + + @property + def device(self) -> torch.device: + return self.image_vit.patch_embedding.weight.device + + def encode_image(self, images: torch.Tensor) -> torch.Tensor: + """ + : param images: (batch_size, num_crops, num_patch, n_pixels) + """ + B, T, N, D = images.shape + + mask = ~torch.all( + images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True) + + images = images.view(B * T, N, D) + image_features = self.image_vit(images) + + if self.vit_layers is not None: + features = [] + for layer in self.vit_layers: + 
features.append(image_features[layer]) + image_features = torch.cat(features, dim=-1) + else: + image_features = image_features[-1] + + if self.num_prefix_tokens > 0: + image_features = image_features[:, 1:] + + image_features = image_features * mask + image_features = image_features.view(B, T, N, -1) + + return image_features + + def forward( + self, images: torch.Tensor, image_masks: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + + # image_features: (batch_size, num_crops(=num_image), num_patch, nximage_emb_dim) # noqa: E501 + batch_size, num_image = images.shape[:2] + images = images.to(device=self.device, dtype=self.dtype) + image_features = self.encode_image(images) + + og_dtype = image_features.dtype + assert image_masks is not None + pad_embed = self.pad_embed[:, None, None, None, :] + all_pad = image_masks == 0 + partial_pad = torch.logical_and( + image_masks < 1, + torch.logical_not(all_pad)).to(dtype=torch.float32) + all_pad = all_pad.to(dtype=torch.float32) + image_features = image_features + pad_embed[0] * torch.unsqueeze( + all_pad, -1) + image_features = image_features + pad_embed[1] * torch.unsqueeze( + partial_pad, -1) + + image_features = image_features.to(og_dtype) + + image_features = image_features.reshape( + (batch_size, num_image) + self.image_num_patch + (-1, ), ) + + if self.image_num_patch[0] % 2 == 1: + # Pad so we can still pool 2x2 patches + image_features = F.pad( + image_features, + (0, 0, 0, 1, 0, 1, 0, 0, 0, 0), + ) + + # image pooling + image_features = rearrange( + image_features, + 'b n (h dh) (w dw) c -> (b n h w) (dh dw) c', + dh=2, + dw=2, + ) + + query = image_features.mean(-2, keepdim=True) + image_features = self.image_pooling_2d(query, image_features) + + h, w = self.llm_patches_per_crop + image_features = image_features.view(batch_size, num_image, h * w, -1) + + image_features = self.image_projector(image_features) + + # image_features: (batch_size, num_image, num_patch, d_model) + return image_features + + +class MolmoModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.embedding_size = config.embedding_size or config.vocab_size + self.embedding_size += ADDITIONAL_VOCAB_SIZE + self.embed_tokens = VocabParallelEmbedding( + self.embedding_size, + config.hidden_size, + quant_config=quant_config, + ) + + decoder_layer = MolmoDecoderNormAfterLayer if config.norm_after \ + else MolmoDecoderLayer + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) + + assert config.layer_norm_type == "rms" + self.norm = RMSNorm(config.hidden_size, config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + # Apply blocks one-by-one. 
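+        # Only the layers assigned to this pipeline-parallel rank run here
+        # (indices [start_layer, end_layer)); ranks other than the last hand
+        # hidden_states and the pending residual to the next stage through
+        # IntermediateTensors instead of applying the final norm.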
+ for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + return hidden_states + + +cached_get_processor = lru_cache(get_processor) + + +def get_num_patches(num_tiles: int, crop_patches: int, left_margin: int, + right_margin: int, pooling_size: int) -> int: + crop_window_patches = crop_patches - (left_margin + right_margin) + if num_tiles > 1: + left_crop_window_patches = (crop_window_patches + left_margin + + pooling_size - + 1) // pooling_size * pooling_size + middle_crop_window_patches = (crop_window_patches + pooling_size - + 1) // pooling_size * pooling_size + right_crop_window_patches = (crop_window_patches + right_margin + + pooling_size - + 1) // pooling_size * pooling_size + return left_crop_window_patches + ( + num_tiles - + 2) * middle_crop_window_patches + right_crop_window_patches + else: + single_crop_window_patches = (crop_patches + pooling_size - + 1) // pooling_size * pooling_size + return single_crop_window_patches + + +def get_tokens(tiling_h: int, tiling_w: int, crop_patches: int, + left_margin: int, right_margin: int, pooling_size: int) -> int: + h = get_num_patches(tiling_h, crop_patches, left_margin, right_margin, + pooling_size) + w = get_num_patches(tiling_w, crop_patches, left_margin, right_margin, + pooling_size) + per_row = w // pooling_size + 1 + joint = per_row * (h // pooling_size) + 2 + image_token_length = (crop_patches + pooling_size - 1) // pooling_size + resize = (image_token_length + 1) * image_token_length + 2 + return resize + joint + + +def get_max_tokens(max_crops: int, crop_patches: int, left_margin: int, + right_margin: int, pooling_size: int) -> int: + tilings = [] + for i in range(1, max_crops + 1): + for j in range(1, max_crops + 1): + if i * j <= max_crops: + tilings.append((i, j)) + tokens = [ + get_tokens(tilings[i][0], tilings[i][1], crop_patches, left_margin, + right_margin, pooling_size) for i in range(len(tilings)) + ] + return max(tokens) + + +def get_max_molmo_image_tokens(ctx: InputContext) -> int: + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) + image_processor = processor.image_processor + max_llm_image_tokens = get_max_tokens( + image_processor.max_crops, + image_processor.base_image_input_size[0] // + image_processor.image_patch_size, + image_processor.overlap_margins[0], + image_processor.overlap_margins[1], + 2, + ) + return max_llm_image_tokens + + +# NOTE: preprocessing for the image data has been included in the +# 'input_processor_for_molmo' function +def image_input_mapper_for_molmo( + ctx: InputContext, + data: object, +): + return MultiModalInputs(data) + + +def dummy_data_for_molmo(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) + image_processor = processor.image_processor + + base_image_input_d = image_processor.image_patch_size + left_margin, right_margin = image_processor.overlap_margins + max_crops = image_processor.max_crops + + # Assume: prompt_token_ids always starts with 
bos_token_id followed image tokens # noqa: E501 + max_llm_image_tokens = get_max_molmo_image_tokens(ctx) + if seq_len - max_llm_image_tokens - 1 < 0: + raise RuntimeError( + f"Molmo cannot process {max_crops} crops in a prompt, " + "please increase max_model_len or reduce number of crops") + + # The vertical image has the maximum number of image tokens due to column tokens. # noqa: E501 + tiling = (max_crops, 1) + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = image_processor.base_image_input_size[ + 0] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + + h = crop_window_size * tiling[0] + total_margin_pixels + w = crop_window_size * tiling[1] + total_margin_pixels + + dummy_image = Image.new("RGB", (w, h), color="red") + + out = processor.process("dummy prompt", dummy_image) + + token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + out["input_ids"][:1 + max_llm_image_tokens]) + token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - max_llm_image_tokens - 1) + dummy_seqdata = SequenceData(token_ids) + dummy_imgdata = { + "images": out["images"], + "image_input_idx": out["image_input_idx"], + } + if "image_masks" in out: + dummy_imgdata["image_masks"] = out["image_masks"] + dummy_imgdata["seq_len"] = torch.tensor(seq_len, dtype=torch.long) + return dummy_seqdata, {"image": dummy_imgdata} + + +def pad_images( + max_total_crops: int, + images: torch.Tensor, + image_input_idx: torch.Tensor, + image_masks: Optional[torch.Tensor] = None, +): + n = max_total_crops - images.shape[0] + images = F.pad(images, (0, 0, 0, 0, 0, n), value=-1) + image_input_idx = F.pad(image_input_idx, (0, 0, 0, n), value=-1) + if image_masks is not None: + image_masks = F.pad(image_masks, (0, 0, 0, n), value=-1) + return images, image_input_idx, image_masks + + +def input_processor_for_molmo(ctx: InputContext, llm_inputs: LLMInputs): + prompt = llm_inputs["prompt"] + multi_modal_data = llm_inputs.get("multi_modal_data") + image = multi_modal_data.get("image") + processor = cached_get_processor(ctx.model_config.model, + trust_remote_code=True, + revision=ctx.model_config.code_revision) + + # NOTE: message formatting for raw text prompt is only applied for + # offline inference; for online inference, the prompt is always in + # instruction format and tokenized. 
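+    # Prompts that already follow Molmo's chat template, e.g.
+    #   "User: Describe this image. Assistant:"
+    # are matched by the regex below and passed through with
+    # message_format="none" so the processor does not wrap them again;
+    # plain prompts are wrapped by the processor, and token-id-only inputs
+    # are passed via `tokens=` instead of raw text.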
+ if prompt is not None and re.match(r"^User:[\s\S]*?(Assistant:)*$", + prompt): + out = processor.process(prompt, image, message_format="none") + elif prompt is not None: + out = processor.process(prompt, image) + else: + out = processor.process(None, + image, + tokens=llm_inputs["prompt_token_ids"]) + + image_processor = processor.image_processor + max_total_crops = 1 + image_processor.max_crops + if image is not None: + images, image_input_idx, image_masks = pad_images( + max_total_crops, + out["images"], + out["image_input_idx"], + out.get("image_masks"), + ) + else: + base_image_input_size = image_processor.base_image_input_size + image_patch_size = image_processor.image_patch_size + image_num_patch = ( + base_image_input_size[0] // image_patch_size, + base_image_input_size[1] // image_patch_size, + ) + n_pixels = image_patch_size * image_patch_size * 3 + n_patches = image_num_patch[0] * image_num_patch[1] + + image_length_w = image_processor.image_token_length_w + image_length_h = image_processor.image_token_length_h + tokens_per_image = image_length_w * image_length_h + images = torch.full( + (max_total_crops, n_patches, n_pixels), + -1, + dtype=torch.float32, + ) + image_input_idx = torch.full( + (max_total_crops, tokens_per_image), + -1, + dtype=torch.int32, + ) + if image_processor.image_padding_mask: + image_masks = torch.full( + (max_total_crops, n_patches), + -1, + dtype=torch.float32, + ) + + image_data = dict( + images=images, + image_input_idx=image_input_idx, + ) + if image_masks is not None: + image_data["image_masks"] = image_masks + + image_data["seq_len"] = torch.tensor(len(out["input_ids"]), + dtype=torch.long) + + multi_modal_data = dict(image=image_data) + + return LLMInputs( + prompt_token_ids=out["input_ids"], + prompt=llm_inputs["prompt"], + multi_modal_data=multi_modal_data, + ) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(image_input_mapper_for_molmo) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_molmo_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_molmo) +@INPUT_REGISTRY.register_input_processor(input_processor_for_molmo) +class MolmoForCausalLM(nn.Module, SupportsMultiModal): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: Optional[MultiModalConfig] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[Mapping[str, Any]] = None, + ) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + vision_config = VisionBackboneConfig() + self.vision_backbone = MolmoVisionBackbone(config, vision_config, + quant_config) + self.model = MolmoModel(config, cache_config, quant_config) + + if self.config.weight_tying: + self.lm_head = self.model.transformer.wte + else: + self.lm_head = ParallelLMHead( + config.embedding_size or config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + + self.logits_processor = LogitsProcessor(config.embedding_size + or config.vocab_size) + self.sampler = Sampler() + + def _parse_and_validate_image_input( + self, + **kwargs: object, + ) -> Optional[MolmoImageInputs]: + images = kwargs.pop("images", None) + image_masks = kwargs.pop("image_masks", None) + if images is None: + return None + + image_input_idx = kwargs.pop("image_input_idx", None) + seq_len = kwargs.pop("seq_len", None) + if image_input_idx is None: + raise ValueError("image_input_idx is required for Molmo model.") + if seq_len is None: + raise ValueError("seq_len is required for Molmo model.") + if not isinstance(seq_len, 
torch.Tensor): + seq_len = torch.tensor(seq_len) + + return MolmoImageInputs( + images=images, + image_input_idx=image_input_idx, + seq_len=seq_len, + image_masks=image_masks, + ) + + def _process_image_input( + self, + image_input: MolmoImageInputs, + ) -> torch.Tensor: + + image_features = self.vision_backbone( + images=image_input["images"], + image_masks=image_input["image_masks"], + ) + + return image_features + + def _merge_multimodal_embeddings( + self, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor, + image_input_idx: torch.Tensor, + seq_len: Union[torch.Tensor, List[torch.Tensor]], + ) -> torch.Tensor: + batch_size, num_image, num_patch = image_features.shape[:3] + assert image_input_idx.shape == (batch_size, num_image, num_patch) + + image_features = image_features.to(inputs_embeds.device) + seq_len = seq_len.to(inputs_embeds.device) + + # insert the image feature into the embedding. + image_features = image_features.view(batch_size, num_image * num_patch, + -1) + image_input_idx = image_input_idx.view(batch_size, + num_image * num_patch) + + valid = image_input_idx >= 0 + image_features = image_features * valid[:, :, None].to( + image_features.dtype) + image_features = image_features.view( + batch_size * num_image * num_patch, -1).contiguous() + + image_input_idx = image_input_idx * valid.to(image_input_idx.dtype) + offset = torch.cat( + [seq_len.new_zeros( + (1)), seq_len.cumsum(dim=0)[:-1]], dim=0)[:, None] + image_input_idx = image_input_idx + offset.to(image_input_idx.dtype) + image_input_idx = image_input_idx.flatten()[:, None] + mat = image_input_idx == torch.arange( + seq_len.sum().item(), device=inputs_embeds.device)[None, :] + mat = mat.to(image_features.dtype) + + inputs_embeds = inputs_embeds + torch.einsum('nd,nm->md', + image_features, mat) + + return inputs_embeds + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.LongTensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + **kwargs: object, + ) -> SamplerOutput: + + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + inputs_embeds = self.model.embed_tokens(input_ids) + image_features = self._process_image_input(image_input) + + inputs_embeds = self._merge_multimodal_embeddings( + inputs_embeds, + image_features, + image_input["image_input_idx"], + image_input["seq_len"], + ) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_mapping = [ + ("model.transformer.ln_f.weight", "model.norm.weight"), + ("attn_out", "self_attn.o_proj"), + ("att_proj", "self_attn.qkv_proj"), + ("q_norm", "self_attn.q_norm"), + ("k_norm", "self_attn.k_norm"), + ("attn_norm", "input_layernorm"), + ("ff_norm", "post_attention_layernorm"), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + embedding_weight = dict() + projector_weight = dict() 
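+        # These two dicts collect shards that cannot be copied one-to-one:
+        # wte.embedding and wte.new_embedding are concatenated into
+        # model.embed_tokens.weight, and the projector's w1/w3 (gate/up)
+        # weights are fused into image_projector.gate_up_proj after the loop.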
+ for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + + if "wte.embedding" in name: + embedding_weight["embedding"] = loaded_weight + continue + + if "wte.new_embedding" in name: + embedding_weight["new_embedding"] = loaded_weight + continue + + if "vision_backbone" in name: + if name.startswith("model"): + name = name[len("model."):] + if 'image_projector' in name: + if 'w1' in name: + projector_weight['gate_proj'] = loaded_weight + elif 'w3' in name: + projector_weight['up_proj'] = loaded_weight + elif 'w2' in name: + projector_weight['down_proj'] = loaded_weight + else: + raise ValueError( + f"Unexpected projector weight: {name}") + continue + else: + if "transformer.blocks" in name: + name = name.replace("transformer.blocks", "layers") + + if "ff_proj" in name: + name = name.replace("ff_proj", "mlp.gate_up_proj") + assert 'weight' in name + up_weight, gate_weight = loaded_weight.chunk(2, dim=0) + loaded_weight = torch.cat([gate_weight, up_weight], dim=0) + + elif "ff_out" in name: + if "layers" in name: + name = name.replace("ff_out", "mlp.down_proj") + else: + # lm head + name = name.replace("model.transformer.ff_out", + "lm_head") + + else: + for (param_name, weight_name) in params_mapping: + if param_name in name: + name = name.replace(param_name, weight_name) + break + + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + except KeyError: + raise ValueError(f"Unexpected weight: {name}") from None + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + gate_up_proj_weight = torch.cat( + [projector_weight["gate_proj"], projector_weight["up_proj"]], + dim=0) + name = "vision_backbone.image_projector.gate_up_proj.weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, gate_up_proj_weight) + + down_proj_weight = projector_weight["down_proj"] + name = "vision_backbone.image_projector.down_proj.weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, down_proj_weight) + + embedding_weight = torch.cat( + [embedding_weight["embedding"], embedding_weight["new_embedding"]], + dim=0) + name = "model.embed_tokens.weight" + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, embedding_weight) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py new file mode 100644 index 0000000..e3d3937 --- /dev/null +++ b/vllm/model_executor/models/mpt.py @@ -0,0 +1,327 @@ +# coding=utf-8 +# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from 
vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.mpt import MPTConfig + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +def _get_alibi_slopes( + total_num_heads: int, + alibi_bias_max: int, +) -> torch.Tensor: + next_power_of_2 = 2**math.ceil(math.log2(total_num_heads)) + m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32) + m = m.mul(alibi_bias_max / next_power_of_2) + slopes = 1.0 / torch.pow(2, m) + if next_power_of_2 != total_num_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads] + return slopes + + +class MPTAttention(nn.Module): + + def __init__( + self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.clip_qkv = config.attn_config["clip_qkv"] + self.qk_ln = config.attn_config["qk_ln"] + self.alibi_bias_max = config.attn_config["alibi_bias_max"] + if "kv_n_heads" in config.attn_config: + self.total_num_kv_heads = config.attn_config['kv_n_heads'] + else: + self.total_num_kv_heads = self.total_num_heads + assert not config.attn_config["prefix_lm"] + assert config.attn_config["alibi"] + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=not config.no_bias, + quant_config=quant_config, + ) + if self.qk_ln: + self.q_ln = nn.LayerNorm(self.d_model) + self.k_ln = nn.LayerNorm(self.d_model) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=not config.no_bias, + quant_config=quant_config, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Create the alibi slopes and slice them. 
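+        # ALiBi slopes are computed once for all heads and then sliced to the
+        # heads owned by this tensor-parallel rank. Rough worked example (not
+        # from any particular checkpoint): with total_num_heads=8 and
+        # alibi_bias_max=8, _get_alibi_slopes returns
+        # [1/2, 1/4, 1/8, ..., 1/256]; with tp_size=2, rank 0 keeps the first
+        # four slopes and rank 1 the last four.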
+ tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads, + self.alibi_bias_max) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + self.head_dim = self.d_model // self.total_num_heads + scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + del position_ids # unused. + qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.qk_ln: + q = self.q_ln(q) + k = self.k_ln(k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class MPTMLP(nn.Module): + + def __init__( + self, + config: MPTConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.d_model + expansion_ratio = config.expansion_ratio + intermediate_size = expansion_ratio * hidden_size + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=not config.no_bias, + quant_config=quant_config, + ) + self.act = get_act_fn("gelu", quant_config, intermediate_size) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=not config.no_bias, + quant_config=quant_config, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.up_proj(x) + x = self.act(x) + x, _ = self.down_proj(x) + return x + + +class MPTBlock(nn.Module): + + def __init__( + self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + hidden_size = config.d_model + self.norm_1 = nn.LayerNorm(hidden_size) + self.attn = MPTAttention(config, cache_config, quant_config) + self.norm_2 = nn.LayerNorm(hidden_size) + self.ffn = MPTMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + x = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=x, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + x + x = self.norm_2(hidden_states) + x = self.ffn(x) + hidden_states = hidden_states + x + return hidden_states + + +class MPTModel(nn.Module): + + def __init__( + self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + assert config.embedding_fraction == 1.0 + assert config.norm_type == "low_precision_layernorm" + + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: MPTBlock(config, cache_config, quant_config), + prefix=f"{prefix}.blocks") + self.norm_f = nn.LayerNorm(config.d_model) + if config.no_bias: + for module in self.modules(): + if hasattr(module, "bias") and isinstance( + module.bias, nn.Parameter): + # Remove the bias term in Linear and 
LayerNorm. + module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + block = self.blocks[i] + hidden_states = block( + position_ids, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + +class MPTForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: MPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + assert config.tie_word_embeddings + self.quant_config = quant_config + + self.transformer = MPTModel(config, cache_config, quant_config) + self.lm_head = self.transformer.wte + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py new file mode 100644 index 0000000..14515e1 --- /dev/null +++ b/vllm/model_executor/models/nemotron.py @@ -0,0 +1,525 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Nemotron model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import NemotronConfig + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +# The architecture is pretty similar to Llama, with these changes: +# - There is no gate_proj, just up_proj +# - Normal LayerNorm (with a +1 to the weights) instead of RMSNorm +# - Squared ReLU instead of SwiGLU +# - Adds a partial_rotary_factor to RoPE + + +def _cast_if_autocast_enabled(*args): + if not torch.is_autocast_enabled(): + return args + else: + return torch.cuda.amp.autocast_mode._cast( + args, torch.get_autocast_gpu_dtype()) + + +class NemotronLayerNorm1P(nn.LayerNorm): + + def __init__(self, + normalized_shape: Union[int, List[int], torch.Size], + eps: float = 1e-5, + elementwise_affine: bool = True, + bias: bool = True, + device=None, + dtype=None): + super().__init__(normalized_shape, eps, elementwise_affine, bias, + device, dtype) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if residual is not None: + x = x + residual + residual = x + args = _cast_if_autocast_enabled(x, self.normalized_shape, + self.weight + 1, self.bias, self.eps) + with torch.cuda.amp.autocast(enabled=False): + x = torch.nn.functional.layer_norm(*args) + return x if residual is None else (x, residual) + + +class NemotronMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> 
None: + super().__init__() + self.up_proj = ColumnParallelLinear(input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_fn = get_act_fn(hidden_act) + + def forward(self, x): + up, _ = self.up_proj(x) + x = self.act_fn(up) + x, _ = self.down_proj(x) + return x + + +class NemotronAttention(nn.Module): + + def __init__( + self, + config: NemotronConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + partial_rotary_factor=self.partial_rotary_factor, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class NemotronDecoderLayer(nn.Module): + + def __init__( + self, + config: NemotronConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] 
= None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = NemotronAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = NemotronMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, + eps=config.norm_eps) + self.post_attention_layernorm = NemotronLayerNorm1P( + config.hidden_size, eps=config.norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class NemotronModel(nn.Module): + + def __init__( + self, + config: NemotronConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: NemotronDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = NemotronLayerNorm1P(config.hidden_size, + eps=config.norm_eps) + else: + self.norm = PPMissingLayer() + 
self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "up_proj", "down_proj", "embed_tokens", "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + } + + def __init__( + self, + config: NemotronConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + assert isinstance(config, NemotronConfig) + + self.config = config + self.lora_config = lora_config + + self.model = NemotronModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = 
self.model(input_ids, positions, kv_caches,
+                                  attn_metadata, intermediate_tensors)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if ("rotary_emb.cos_cached" in name
+                    or "rotary_emb.sin_cached" in name):
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+                if is_pp_missing_parameter(name, self):
+                    continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
new file mode 100644
index 0000000..a52e3cb
--- /dev/null
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -0,0 +1,64 @@
+# adapted from https://huggingface.co/nvidia/NVLM-D-72B/blob/main/modeling_nvlm_d.py
+# --------------------------------------------------------
+# NVLM-D
+# Copyright (c) 2024 NVIDIA
+# Licensed under Apache 2.0 License [see LICENSE for details]
+# --------------------------------------------------------
+import torch.nn as nn
+from transformers import PretrainedConfig
+
+from vllm.inputs import INPUT_REGISTRY
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from .intern_vit import InternVisionModel
+from .internvl import (InternVLChatModel, InternVLInputPipeline,
+                       get_max_internvl_image_tokens)
+
+IMG_START = '<|vision_start|>'
+IMG_END = '<|vision_end|>'
+IMG_CONTEXT = '<|vision_pad|>'
+
+
+class NVLMInputPipeline(InternVLInputPipeline):
+
+    def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
+        tile_pos_identifiers = ([f"<tile_{i}>"
+                                 for i in range(1, num_patches)] +
+                                ["<tile_global_thumbnail>"])
+        context_size = feature_size // num_patches
+
+        return '<Image>' + ''.join(
+            tile_pos_identifier + self.img_context_token * context_size
+            for tile_pos_identifier in tile_pos_identifiers) + '</Image>'
+
+
+input_pipeline = NVLMInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data) +@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor) +class NVLM_D_Model(InternVLChatModel): + + def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential: + vit_hidden_size = config.vision_config.hidden_size + llm_intermediate_size = config.text_config.intermediate_size + llm_hidden_size = config.text_config.hidden_size + + return nn.Sequential( + nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2), + nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2, + llm_intermediate_size, + bias=False), + nn.GELU(), + nn.Linear(llm_intermediate_size, llm_hidden_size, bias=False), + ) + + def _init_vision_model(self, config: PretrainedConfig, + num_hidden_layers: int): + # We added additional dummy heads to the original num of heads to make + # the number of heads divisible by 8. + return InternVisionModel(config.vision_config, + num_hidden_layers_override=num_hidden_layers, + num_dummy_heads=7) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py new file mode 100644 index 0000000..5ca7c66 --- /dev/null +++ b/vllm/model_executor/models/olmo.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. 
+ self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. 
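+        # Same pre-norm residual pattern as the attention block above:
+        # MLP(LN(x)) is added back onto x, giving the
+        # MLP(LN(x + Attention(LN(x)))) structure described in the class
+        # docstrings; the LayerNorms here are non-parametric
+        # (elementwise_affine=False, bias=False).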
+ residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # Apply blocks one-by-one. + for i in range(self.start_layer, self.end_layer): + # shape: (batch_size, seq_len, d_model) + hidden_states = self.layers[i]( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoForCausalLM(nn.Module, SupportsPP): + """ + Extremely barebones HF model wrapper. 
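+    If config.tie_word_embeddings is set, the LM head reuses the token
+    embedding matrix; otherwise a separate ParallelLMHead is created.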
+ """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.model = OlmoModel(config, cache_config, quant_config) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py new file mode 100644 index 0000000..a1ba80e --- /dev/null +++ b/vllm/model_executor/models/olmoe.py @@ -0,0 +1,445 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import print_warning_once + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class OlmoeMoE(nn.Module): + """A tensor-parallel MoE implementation for Olmoe that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + quant_config=None) + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
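+        # For example, a (num_tokens, hidden_size) input is viewed as rows,
+        # the gate scores every token over all num_experts experts, and
+        # FusedMoE selects top_k experts per token (renormalize=False keeps
+        # the raw softmax weights); the output is viewed back to orig_shape.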
+ orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + return final_hidden_states.view(orig_shape) + + +class OlmoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 4096, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.q_norm = RMSNorm(hidden_size, eps=1e-5) + self.k_norm = RMSNorm(hidden_size, eps=1e-5) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 4096) + + self.self_attn = OlmoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + 
num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + + self.mlp = OlmoeMoE( + num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class OlmoeModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoeDecoderLayer(config, int( + prefix.split(".")[-1]), cache_config, quant_config), + prefix=f"{prefix}.layers") + self.norm = RMSNorm(config.hidden_size, eps=1e-5) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class OlmoeForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = OlmoeModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + 
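+        # Unlike the OLMo model above, there is no tie_word_embeddings branch
+        # here; the LM head is always a separate ParallelLMHead.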
self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
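+                    # e.g. "model.layers.0.self_attn.kv_scale" is remapped to
+                    # "model.layers.0.self_attn.attn.kv_scale", the name under
+                    # which the Attention module registers its scale.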
+ if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py new file mode 100644 index 0000000..3bcdb0d --- /dev/null +++ b/vllm/model_executor/models/opt.py @@ -0,0 +1,416 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py +# Copyright 2023 The vLLM team. +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OPT model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OPTConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the + # embedding ids by 2 and adjust num_embeddings appropriately. 
Other + # models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, positions: torch.Tensor): + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + total_num_heads = num_heads + assert num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = embed_dim // total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + embed_dim, + self.head_dim, + total_num_heads, + bias=bias, + quant_config=quant_config, + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + quant_config=quant_config, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class OPTDecoderLayer(nn.Module): + + def __init__( + self, + config: OPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + bias=config.enable_bias, + cache_config=cache_config, + quant_config=quant_config, + ) + self.do_layer_norm_before = config.do_layer_norm_before + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, + elementwise_affine=config.layer_norm_elementwise_affine) + self.fc1 = ColumnParallelLinear( + self.embed_dim, + config.ffn_dim, + bias=config.enable_bias, + quant_config=quant_config, + ) + self.activation_fn = get_act_fn(config.activation_function, + quant_config, config.ffn_dim) + self.fc2 = RowParallelLinear( + config.ffn_dim, + self.embed_dim, + bias=config.enable_bias, + quant_config=quant_config, + ) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, + elementwise_affine=config.layer_norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = 
self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class OPTDecoder(nn.Module): + + def __init__( + self, + config: OPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.word_embed_proj_dim, + ) + # Positional embeddings are replicated (not sharded). + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + + # Project out & in will be replicated if they exist. + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = ReplicatedLinear(config.hidden_size, + config.word_embed_proj_dim, + bias=False, + quant_config=quant_config) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = ReplicatedLinear(config.word_embed_proj_dim, + config.hidden_size, + bias=False, + quant_config=quant_config) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to + # keep backward compatibility with checkpoints that have been fine-tuned + # before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, + elementwise_affine=config.layer_norm_elementwise_affine) + else: + self.final_layer_norm = None + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OPTDecoderLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers") + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + pos_embeds = self.embed_positions(positions) + if self.project_in is not None: + inputs_embeds, _ = self.project_in(inputs_embeds) + hidden_states = inputs_embeds + pos_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + if self.project_out is not None: + hidden_states, _ = self.project_out(hidden_states) + return hidden_states + + +class OPTModel(nn.Module): + + def __init__( + self, + config: OPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: 
Optional[QuantizationConfig] = None, + ): + super().__init__() + self.decoder = OPTDecoder(config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.decoder.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + return self.decoder(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + +class OPTForCausalLM(nn.Module, SupportsPP): + + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + } + default_bitsandbytes_target_modules = [ + ".q_proj.", ".k_proj.", ".v_proj.", ".out_proj.", ".fc1.", ".fc2." + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".out_proj.", ".fc2."] + + def __init__( + self, + config: OPTConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = OPTModel(config, cache_config, quant_config) + if self.config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.word_embed_proj_dim) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "lm_head.weight" in name and self.config.tie_word_embeddings: + continue + if name.startswith("decoder."): + name = "model." + name + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py new file mode 100644 index 0000000..0913193 --- /dev/null +++ b/vllm/model_executor/models/orion.py @@ -0,0 +1,363 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/modeling_orion.py +# Copyright (c) OrionStar Inc. +# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE +"""Inference-only Orion-14B model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class OrionMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OrionAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OrionDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = OrionAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + self.mlp = OrionMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, 
+ ) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states, None + + +class OrionModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OrionDecoderLayer( + config, + cache_config, + quant_config, + ), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OrionForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = OrionModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + 
intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py new file mode 100644 index 0000000..99d000e --- /dev/null +++ b/vllm/model_executor/models/paligemma.py @@ -0,0 +1,295 @@ +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) + +import torch +from torch import nn +from transformers import PaliGemmaConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.models.gemma import GemmaForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsMultiModal, SupportsPP +from .siglip import (SiglipVisionModel, dummy_image_for_siglip, + dummy_seq_data_for_siglip, get_max_siglip_image_tokens) +from .utils import AutoWeightsLoader, merge_multimodal_embeddings + +logger = init_logger(__name__) + + +class PaliGemmaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size * num_images, num_channels, height, width)`""" + + +class PaliGemmaImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. 
+ """ + + +PaliGemmaImageInputs = Union[PaliGemmaImagePixelInputs, + PaliGemmaImageEmbeddingInputs] + + +def get_max_paligemma_image_tokens(ctx: InputContext): + hf_config = ctx.get_hf_config(PaliGemmaConfig) + vision_config = hf_config.vision_config + + return get_max_siglip_image_tokens(vision_config) + + +def dummy_data_for_paligemma(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + hf_config = ctx.get_hf_config(PaliGemmaConfig) + vision_config = hf_config.vision_config + num_images = mm_counts["image"] + + seq_data = dummy_seq_data_for_siglip( + vision_config, + seq_len, + num_images, + image_token_id=hf_config.image_token_index, + ) + + mm_data = dummy_image_for_siglip(vision_config, num_images) + return seq_data, mm_data + + +def input_processor_for_paligemma(ctx: InputContext, llm_inputs: LLMInputs): + + """ + The correct prompt format needs to be: + '' * image_feature_size + '' + prompt + '\n' + + See https://github.com/huggingface/transformers/blob/25245ec26dc29bcf6102e1b4ddd0dfd02e720cf5/src/transformers/models/paligemma/processing_paligemma.py#L55 + """ # noqa + + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_config(PaliGemmaConfig) + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + image_feature_size = hf_config.text_config.num_image_tokens + image_token_str = tokenizer.decode(hf_config.image_token_index) + bos_token = tokenizer.decode(hf_config.bos_token_id) + image_token_str_pad = image_token_str * image_feature_size + image_token_ids_pad = [hf_config.image_token_index] * image_feature_size + + orig_prompt = llm_inputs.get("prompt") + orig_prompt_ids = llm_inputs.get("prompt_token_ids") + + if orig_prompt is not None and image_token_str in orig_prompt: + logger.warning( + "The image token '%s' was detected in the prompt and " + "will be removed. 
Please follow the proper prompt format" + " documented on HuggingFace.", image_token_str) + orig_prompt = orig_prompt.replace(image_token_str, "") + orig_prompt_ids.remove(hf_config.image_token_index) + + new_prompt = f"{image_token_str_pad}{bos_token}{orig_prompt}\n" + new_token_ids = image_token_ids_pad + orig_prompt_ids + [108] #newline + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class PaliGemmaMultiModalProjector(nn.Module): + + def __init__(self, vision_hidden_size: int, projection_dim: int): + super().__init__() + + self.linear = nn.Linear(vision_hidden_size, projection_dim, bias=True) + + def forward(self, image_features: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear(image_features) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) +@INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, + config: PaliGemmaConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + self.vision_tower = SiglipVisionModel(config.vision_config) + self.multi_modal_projector = PaliGemmaMultiModalProjector( + vision_hidden_size=config.vision_config.hidden_size, + projection_dim=config.vision_config.projection_dim) + + self.quant_config = quant_config + self.language_model = GemmaForCausalLM(config.text_config, + cache_config, quant_config) + logit_scale = getattr(config, "logit_scale", 1.0) + self.language_model.logits_processor.scale *= logit_scale + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler + + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: + h = w = self.config.vision_config.image_size + expected_dims = (3, h, w) + actual_dims = tuple(data.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("batch_size", *map(str, expected_dims)) + raise ValueError( + f"The expected shape of pixel values is {expected_expr}. " + f"You supplied {tuple(data.shape)}.") + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[PaliGemmaImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, torch.Tensor): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + + return PaliGemmaImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(pixel_values), + ) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. 
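+            # e.g. [batch_size, 1, image_feature_size, hidden_size] becomes
+            # [batch_size, image_feature_size, hidden_size].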
+ image_embeds = image_embeds.squeeze(1) + + return PaliGemmaImageEmbeddingInputs( + type="image_embeds", + data=image_embeds, + ) + + raise AssertionError("This line should be unreachable.") + + def _image_pixels_to_features( + self, + vision_tower: SiglipVisionModel, + pixel_values: torch.Tensor, + ) -> torch.Tensor: + + target_dtype = vision_tower.get_input_embeddings().weight.dtype + image_features = vision_tower(pixel_values.to(dtype=target_dtype)) + + return image_features + + def _process_image_input( + self, + image_input: PaliGemmaImageInputs, + ) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + return image_input["data"] + + assert self.vision_tower is not None + pixel_values = image_input["data"] + image_features = self._image_pixels_to_features( + self.vision_tower, + pixel_values, + ) + + return self.multi_modal_projector(image_features) + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + parsed_image_input = self._parse_and_validate_image_input(**kwargs) + + if parsed_image_input is not None: + vision_embeddings = self._process_image_input( + parsed_image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * ( + self.config.hidden_size**-0.5) + + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py new file mode 100644 index 0000000..b625d19 --- /dev/null +++ b/vllm/model_executor/models/persimmon.py @@ -0,0 +1,354 @@ +# coding=utf-8 +# adapted from https://github.com/huggingface/transformers/blob/v4.39.3/src/transformers/models/persimmon/modeling_persimmon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only persimmon model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PersimmonConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class PersimmonMLP(nn.Module): + + def __init__(self, + config: PersimmonConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size, + config.intermediate_size, + quant_config=quant_config) + self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, + config.hidden_size, + quant_config=quant_config) + self.act = get_act_fn(config.hidden_act, quant_config) + + def forward(self, hidden_states) -> torch.Tensor: + hidden_states, _ = self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class PersimmonAttention(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + tensor_parallel_world_size = get_tensor_model_parallel_world_size() + + self.hidden_size = config.hidden_size + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tensor_parallel_world_size + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + assert (self.head_dim * self.total_num_heads) == self.hidden_size + assert self.total_num_heads % tensor_parallel_world_size == 0 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + 
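+        # When config.qk_layernorm is set, Persimmon applies per-head
+        # LayerNorm to the query and key vectors before the rotary embedding
+        # (see the _split_heads/_merge_heads path in forward below).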
self.is_qk_layernorm = config.qk_layernorm + + if self.is_qk_layernorm: + self.q_layernorm = nn.LayerNorm(self.head_dim) + self.k_layernorm = nn.LayerNorm(self.head_dim) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=int(self.partial_rotary_factor * self.head_dim), + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + def _split_heads(self, x: torch.Tensor) -> torch.Tensor: + # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads, self.head_dim) + + def _merge_heads(self, x: torch.Tensor) -> torch.Tensor: + # [seq_length, num_heads, head_dim] -> [seq_length, hidden_size] + seq_length = x.shape[0] + return x.view(seq_length, self.num_heads * self.head_dim) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # [seq_length, 3 x hidden_size] + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + + if self.is_qk_layernorm: + # [seq_length, num_heads, head_dim] + q = self._split_heads(q) + k = self._split_heads(k) + + q = self.q_layernorm(q) + k = self.k_layernorm(k) + + q = self._merge_heads(q) + k = self._merge_heads(k) + + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.dense(attn_output) + return output + + +class PersimmonDecoderLayer(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = PersimmonAttention(config=config, + cache_config=cache_config, + quant_config=quant_config) + self.mlp = PersimmonMLP(config, quant_config=quant_config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = hidden_states + residual + + outputs = hidden_states + return outputs + + +class PersimmonModel(nn.Module): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PersimmonDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") + self.final_layernorm = 
nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + hidden_states = self.layers[i]( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class PersimmonForCausalLM(nn.Module, SupportsPP): + + def __init__(self, + config: PersimmonConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.model = PersimmonModel(config, + cache_config=cache_config, + quant_config=quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=False) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ): + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + + if "query_key_value" in name: + # copy from vllm/model_executor/models/bloom.py + # NOTE: Persimmon's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. 
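+                # e.g. with num_heads=2 the checkpoint rows are laid out as
+                # [q0, k0, v0, q1, k1, v1]; the view/transpose/reshape below
+                # reorders them to [q0, q1, k0, k1, v0, v1] so the fused QKV
+                # weight loader can split the tensor into equal q/k/v chunks.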
+ output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py new file mode 100644 index 0000000..0918f21 --- /dev/null +++ b/vllm/model_executor/models/phi.py @@ -0,0 +1,374 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py +# Copyright 2023 The vLLM team. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# +# BSD 3-Clause License +# +# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PhiConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class PhiAttention(nn.Module): + + def __init__(self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + # pylint: disable=C0103 + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_size, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + quant_config=quant_config, + ) + + scaling = self.head_size**-0.5 + rotary_dim = int(config.partial_rotary_factor * + (config.hidden_size // config.num_attention_heads)) + assert rotary_dim % 2 == 0 + + # pylint: disable=C0301 + # Refer to: + # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 + rope_theta = 10000 + max_position_embeddings = getattr(config, "n_positions", 2048) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_size, + scaling, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.dense(attn_output) + return output + + +class PhiMLP(nn.Module): + + def __init__(self, + config: PhiConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + + n_inner = getattr(config, "n_inner", None) + n_inner = n_inner if n_inner is not None else 4 * config.hidden_size + + self.fc1 = 
ColumnParallelLinear( + config.hidden_size, + n_inner, + quant_config=quant_config, + ) + self.fc2 = RowParallelLinear( + n_inner, + config.hidden_size, + quant_config=quant_config, + ) + self.act = get_act_fn(config.hidden_act, quant_config, n_inner) + + def forward(self, hidden_states): + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class PhiLayer(nn.Module): + + def __init__(self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.self_attn = PhiAttention(config, cache_config, quant_config) + self.mlp = PhiMLP(config, quant_config) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_outputs = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + return hidden_states + + +class PhiModel(nn.Module): + + def __init__(self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.quant_config = quant_config + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers") + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "dense", + "fc1", + "fc2", + ] + + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + } + default_bitsandbytes_target_modules = [ + ".q_proj.", ".k_proj.", ".v_proj.", ".fc1.", ".fc2.", ".dense." 
+ ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".fc2.", ".dense."] + + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PhiConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + + self.config = config + # lm_head use bias, cannot share word embeddings + assert not config.tie_word_embeddings + self.lora_config = lora_config + + self.quant_config = quant_config + + self.model = PhiModel(config, cache_config, quant_config) + + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata, self.lm_head.bias) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v") + ] + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
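+                # (GPTQ checkpoints can carry bias tensors for layers that the
+                # model defined here does not have; such ".bias" names are
+                # skipped instead of raising a KeyError from params_dict below.)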
+ if name.endswith(".bias") and name not in params_dict: + continue + # pylint: disable=E1136 + + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/phi3.py b/vllm/model_executor/models/phi3.py new file mode 100644 index 0000000..02b2ff0 --- /dev/null +++ b/vllm/model_executor/models/phi3.py @@ -0,0 +1,17 @@ +# coding=utf-8 +# Adapted from llama.py +"""Inference-only Phi3 model code inherit from Llama.py""" + +from vllm.model_executor.models.llama import LlamaForCausalLM + + +class Phi3ForCausalLM(LlamaForCausalLM): + + packed_modules_mapping = { + "qkv_proj": [ + "qkv_proj", + ], + "gate_up_proj": [ + "gate_up_proj", + ], + } diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py new file mode 100644 index 0000000..c9af975 --- /dev/null +++ b/vllm/model_executor/models/phi3_small.py @@ -0,0 +1,473 @@ +import math +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +def load_column_parallel_weight(param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + tp = get_tensor_model_parallel_world_size() + rk = get_tensor_model_parallel_rank() + assert param.size(0) * tp == loaded_weight.size(0) + s = rk * param.size(0) + e = (rk + 1) * param.size(0) + loaded_weight = loaded_weight[s:e] + assert param.shape == loaded_weight.shape + param.data.copy_(loaded_weight) + + +class HeadMajorQKVParallelLinear(QKVParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +class HeadMajorColumnParallelLinear(MergedColumnParallelLinear): + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor): + return load_column_parallel_weight(param, loaded_weight) + + +# @torch.jit.script +def quick_gelu(x): + return x * torch.sigmoid(1.702 * x) + + +# @torch.jit.script +def gegelu(input, limit: Optional[float] = None): + a_gelu, a_linear = input[..., ::2], input[..., 1::2] + if limit is not None: + a_gelu = torch.where(torch.isinf(a_gelu), a_gelu, + a_gelu.clamp(min=None, max=limit)) + a_linear = torch.where( + torch.isinf(a_linear), + a_linear, + a_linear.clamp(min=-limit, 
max=limit), + ) + out_gelu = quick_gelu(a_gelu) + return out_gelu * (a_linear + 1) + + +class Phi3SmallMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + assert (self.config.hidden_act == "gegelu" + ), "Only `gegelu` is supported for the 4.7 series of models .." + self.hidden_size = config.hidden_size + self.gegelu_limit = config.gegelu_limit + self.intermediate_size = config.intermediate_size + + self.up_proj = HeadMajorColumnParallelLinear( + self.hidden_size, + 2 * [self.intermediate_size], + bias=True, + quant_config=quant_config, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + gate_up, _ = self.up_proj(x) + x = gegelu(gate_up) + x, _ = self.down_proj(x) + return x + + +class Phi3SmallSelfAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.config = config + self.sparse_block_size = config.blocksparse_block_size + self.homo_heads = config.blocksparse_homo_head_pattern + self.local_blocks = config.blocksparse_num_local_blocks + self.vert_stride = config.blocksparse_vert_stride + + assert (config.blocksparse_block_size == + config.blocksparse_triton_kernel_block_size) + + self.hidden_size = config.hidden_size + # Number of Query Heads + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // self.num_heads + self.tp_size = get_tensor_model_parallel_world_size() + # Number of total Key Value Heads before tensor parallel + self.num_key_value_heads = config.num_key_value_heads + self.num_q_per_kv = self.num_heads // self.num_key_value_heads + if self.tp_size > 1: + assert self.num_key_value_heads % self.tp_size == 0 + self.num_kv_heads_per_partion = max( + 1, self.num_key_value_heads // self.tp_size) + self.num_heads_per_partition = self.num_heads // self.tp_size + + self.max_position_embeddings = config.max_position_embeddings + self.rope_embedding_base = config.rope_embedding_base + self.rope_position_scale = config.rope_position_scale + self.is_causal = True + + norm_factor = None + if config.mup_use_scaling: + norm_factor = self.head_dim / config.mup_attn_multiplier + else: + norm_factor = math.sqrt(self.head_dim) + self.scale = 1 / norm_factor + + self.query_key_value = HeadMajorQKVParallelLinear( + self.hidden_size, + self.head_dim, + self.num_heads, + self.num_key_value_heads, + bias=True, + quant_config=quant_config, + ) + + self.dense = RowParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config) + + if getattr(self.config, "rope_scaling", None) is not None: + rope_scaling = self.config.rope_scaling + for key in rope_scaling: + if isinstance(rope_scaling[key], list): + rope_scaling[key] = tuple(rope_scaling[key]) + + if "factor" not in rope_scaling: + rope_scaling["factor"] = self.rope_position_scale + else: + rope_scaling = { + "type": "linear", + "factor": self.rope_position_scale, + } + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_embedding_base, + rope_scaling=rope_scaling, + ) + + # blocksparse params + self.blocksparse_block_size = config.blocksparse_block_size + 
self.blocksparse_num_local_blocks = config.blocksparse_num_local_blocks + self.blocksparse_vert_stride = config.blocksparse_vert_stride + + use_dense_attn = (getattr(self.config, + "dense_attention_every_n_layers", None) + and (self.layer_idx + 1) % + self.config.dense_attention_every_n_layers == 0) + + bs_params = None + if not use_dense_attn: + bs_params = { + 'max_seqlen': self.max_position_embeddings, + 'num_heads': self.num_heads_per_partition, + "num_kv_heads": self.num_kv_heads_per_partion, + "block_size": self.sparse_block_size, + "local_blocks": self.local_blocks, + "vert_stride": self.vert_stride, + "homo_head": self.homo_heads + } + + self.attn = Attention( + self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], + Optional[Tuple[torch.Tensor]]]: + qkv, _ = self.query_key_value(hidden_states) + + qkv = qkv.view(qkv.shape[:-1] + + (-1, (self.num_q_per_kv + 2), self.head_dim)) + q, k, v = qkv.split([self.num_q_per_kv, 1, 1], dim=-2) + + # NOTE: this is required by RotaryEmbed, which indeed does not have to + # TODO: allow 3D QK for rotary forward + q = q.reshape(-1, self.head_dim * self.num_heads_per_partition) + k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion) + + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata) + output, _ = self.dense(attn_output) + + return output + + +class Phi3SmallDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Phi3SmallSelfAttention(config, + layer_idx, + cache_config=cache_config, + quant_config=quant_config) + self.mlp = Phi3SmallMLP(config, quant_config) + + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Phi3SmallModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.mup_embedding_multiplier = config.mup_embedding_multiplier + self.start_layer, self.end_layer, self.layers = make_layers( + 
config.num_hidden_layers, + lambda prefix: Phi3SmallDecoderLayer(config, + int(prefix.split('.')[-1]), + cache_config, quant_config), + prefix=f"{prefix}.layers") + + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids: torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class Phi3SmallForCausalLM(nn.Module, SupportsPP): + _tied_weights_keys = ["lm_head.weight"] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ): + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Phi3SmallModel(config, cache_config, quant_config) + self.vocab_size = config.vocab_size + self.mup_width_multiplier = config.mup_width_multiplier + self.lm_head = ParallelLMHead( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + # tokens in tiktoken but not used + if hasattr(config, 'dummy_token_indices'): + device = self.lm_head.weight.device + self.register_buffer('dummy_token_indices', + torch.LongTensor( + config.dummy_token_indices).to(device), + persistent=False) + else: + self.dummy_token_indices = None + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, value): + self.lm_head = value + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + if self.dummy_token_indices is not None and logits is not None: + logits.index_fill_(-1, self.dummy_token_indices, -torch.inf) + return logits + + def forward( + self, + input_ids: 
torch.LongTensor, + positions: Optional[torch.LongTensor], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + output_hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + ) + output_hidden_states = output_hidden_states + return output_hidden_states + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits / self.mup_width_multiplier, + sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py new file mode 100644 index 0000000..00a04da --- /dev/null +++ b/vllm/model_executor/models/phi3v.py @@ -0,0 +1,694 @@ +# coding=utf-8 +# Copyright 2024 The vLLM team. +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import itertools +import re +from functools import cached_property, lru_cache +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, + Tuple, TypedDict, Union) + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from transformers import CLIPVisionConfig, PretrainedConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.models.clip import CLIPVisionModel +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of + +from .clip import dummy_image_for_clip, dummy_seq_data_for_clip +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + merge_multimodal_embeddings) + +logger = init_logger(__name__) + +# Cannot find the following 2 numbers from hf config. 
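+# _IMAGE_TOKEN_ID is the hard-coded id used to mark image placeholder
+# positions in the prompt; the MAX_IMAGE_FEATURE_SIZE_* values below give the
+# image shape that produces the largest possible number of image tokens and
+# are only used for profiling and dummy-data generation.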
+_IMAGE_TOKEN_ID = 32044 + +# Result in the max possible feature size (h:w = 16:1) +MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000 +MAX_IMAGE_FEATURE_SIZE_WIDTH = 50 + +CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0, + hidden_act="quick_gelu", + hidden_size=1024, + image_size=336, + intermediate_size=4096, + num_attention_heads=16, + num_channels=3, + num_hidden_layers=24, + patch_size=14, + projection_dim=768) + + +def _init_img_processor(hf_config: PretrainedConfig): + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + layer_idx = hf_config.img_processor.get('layer_idx', -2) + + # Initialize the CLIP only up to the required feature layer + if layer_idx < 0: + num_hidden_layers = clip_config.num_hidden_layers + \ + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + + img_processor = CLIPVisionModel( + clip_config, num_hidden_layers_override=num_hidden_layers) + + return img_processor + + +class Phi3VImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: Union[torch.Tensor, List[torch.Tensor]] + """ + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` + + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. + """ + + image_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class Phi3VImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + """ + + +Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs] + + +class Phi3ImageEmbeddingBase(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.layer_idx: int + self.type_feature: str + self.img_processor: CLIPVisionModel + + def get_img_features(self, + img_embeds: torch.FloatTensor) -> torch.FloatTensor: + TYPE_FEATURE = self.type_feature + + # NOTE: we skip the step to select the vision feature layer since + # this is already done inside the img_processor + img_feature = self.img_processor(img_embeds) + + if TYPE_FEATURE == "patch": + patch_feature = img_feature[:, 1:] + return patch_feature + + if TYPE_FEATURE == "cls_patch": + return img_feature + + raise NotImplementedError + + +# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py +class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase): + """Phi3 Image embedding with HD transform.""" + + def __init__(self, config: PretrainedConfig) -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr( + config, 'n_embd') else config.hidden_size + + self.img_processor = _init_img_processor(config) + + image_dim_out = config.img_processor['image_dim_out'] + self.num_img_tokens = config.img_processor['num_img_tokens'] + + self.image_dim_out = image_dim_out + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = config.embd_layer.get('use_hd_transform', + False) + self.with_learnable_separator = config.embd_layer.get( + 'with_learnable_separator', False) + self.hd_transform_order = config.embd_layer.get( + 'hd_transform_order', 'glb_sub') + # with_hd_transform and with_learnable_separator should have same value + assert self.use_hd_transform and self.with_learnable_separator + + # 1024 * 4, merge spatial to channel 
dimension + self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4])) + self.sub_GN = nn.Parameter( + torch.empty([1, 1, 1, self.image_dim_out * 4])) + + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * 4, dim_projection)] + for _ in range(1, depth): + layers.extend( + [nn.GELU(), + nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + + self.type_feature = config.img_processor.get('type_feature', 'patch') + + def forward(self, pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor) -> torch.FloatTensor: + """ + process image and return vision embeddings. + + pixel_values: (num_images, num_crops, c, h, w) + output: (num_images, num_img_tokens, hidden_size) + """ + num_images, num_crops, c, h, w = pixel_values.shape + pixel_values = pixel_values.flatten(0, 1) + img_features = self.get_img_features(pixel_values) + img_features = img_features.reshape(num_images, num_crops, -1, + self.image_dim_out) + image_features_proj = self.hd_feature_transform( + img_features, image_sizes) + return image_features_proj + + def hd_feature_transform(self, image_features, image_sizes): + """ + image_features: (num_images, num_crops+1, 24*24, 1024) + """ + assert ( + self.hd_transform_order == 'sub_glb' + ), f'hd_transform_order `{self.hd_transform_order}` not implemented' + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + global_image_features = image_features[:, + 0] # (num_images, 24*24, 1024) + # global feature can be viewed as a special HD case with num_crops 1x1 + global_image_features_hd = self.reshape_hd_patches_2x2merge( + global_image_features, 1, 1) + global_image_features_hd_newline = self.add_image_newline( + global_image_features_hd) + + batch_image_features_proj = [] + # need a for loop to process each image because of different image sizes + # (patch arrangement is different for each image) + for i, img_size in enumerate(image_sizes): + h, w = img_size + h_crop = h // 336 + w_crop = w // 336 + num_crops = h_crop * w_crop + + # NOTE: real num_crops is padded + # (num_crops, 24*24, 1024) + sub_image_features = image_features[i, 1:1 + num_crops] + sub_image_features_hd = self.reshape_hd_patches_2x2merge( + sub_image_features, h_crop, w_crop) + sub_image_features_hd_newline = self.add_image_newline( + sub_image_features_hd) + + # [sub features, separator, global features] + image_embeddings = torch.cat([ + sub_image_features_hd_newline.squeeze( + 0), # (h_crop*12*(w_crop*12+1), 4096) + self.glb_GN.squeeze(0), + global_image_features_hd_newline[i], + ]) + img_proj = self.img_projection( + image_embeddings.to(target_device, target_dtype)) + batch_image_features_proj.append(img_proj) + + return batch_image_features_proj + + def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop): + """ + image_features: (num_images*num_crops, 24*24, 1024) + output: (num_images, h_crop*12, w_crop*12, 4096) + where h_crop*w_crop == num_crops + """ + N, L, C = image_features.shape + assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0 + num_images = N // (h_crop * w_crop) + H = int(L**0.5) + image_features_hd = ( + image_features.reshape(N, H, H, C) # N, 24, 24, 1024 + .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024 + .permute(0, 1, 3, 2, 4, 5) # N, 12, 
12, 2, 2, 1024 + .reshape(N, -1, 4 * C) # N, 144, 4096 + .reshape(num_images, h_crop, w_crop, H // 2, H // 2, + -1) # n_img, h_crop, w_crop, 12, 12, 4096 + .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096 + .reshape(num_images, h_crop * H // 2, w_crop * H // 2, + 4 * C) # n_img, h_crop*12, w_crop*12, 4096 + ) + return image_features_hd + + def add_image_newline(self, image_features_hd): + """ + image_features_hd: (num_images, h_crop*12, w_crop*12, 4096) + output: (num_images, (h_crop*12) * (w_crop*12+1), 4096) + """ + num_images, h, w, hid_dim = image_features_hd.shape + # add the newline token to the HD image feature patches + newline_embeddings = self.sub_GN.expand(num_images, h, -1, + -1) # (n_img, h, 1, hid_dim) + image_features_hd_newline = torch.cat( + [image_features_hd, newline_embeddings], + dim=2).reshape(num_images, -1, hid_dim) + return image_features_hd_newline + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self) + loader.load_weights(weights) + + +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 +def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): + target_height = int(np.ceil(height / padding_unit) * padding_unit) + top_padding = int((target_height - height) / 2) + bottom_padding = target_height - height - top_padding + padded_width = width + padded_height = height + top_padding + bottom_padding + return padded_width, padded_height + + +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90 +def _calc_hd_transform_size(*, width: int, height: int, hd_num: int): + transposed = False + if width < height: + width, height = height, width + transposed = True + + ratio = width / height + scale = 1 + while scale * np.ceil(scale / ratio) <= hd_num: + scale += 1 + scale -= 1 + + new_width = int(scale * 336) + new_height = int(new_width / ratio) + + padded_width, padded_height = _calc_padded_size(width=new_width, + height=new_height) + + if transposed: + padded_width, padded_height = padded_height, padded_width + + return padded_width, padded_height + + +# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181 +def get_phi3v_image_feature_size( + hf_config: Dict[str, Any], + *, + input_height: int, + input_width: int, + num_crops: int, +) -> int: + if num_crops is None: + num_crops = hf_config.get("num_crops", 16) + new_width, new_height = _calc_hd_transform_size(width=input_width, + height=input_height, + hd_num=num_crops) + + return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \ + + (new_height // 336 + 1) * 12 + + +def get_max_phi3v_image_tokens(ctx: InputContext, + *, + num_crops: Optional[int] = None): + + return get_phi3v_image_feature_size( + ctx.get_hf_image_processor_config(), + input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, + num_crops=num_crops, + ) + + +def dummy_data_for_phi3v(ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], + *, + num_crops: Optional[int] = None): + num_images = mm_counts["image"] + + image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops) + + seq_data = dummy_seq_data_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + seq_len, + num_images, + image_token_id=_IMAGE_TOKEN_ID, + image_feature_size_override=image_feature_size, + ) + mm_data = dummy_image_for_clip( + CLIP_VIT_LARGE_PATCH14_336_CONFIG, + 
num_images, + image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH, + image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT, + ) + + return seq_data, mm_data + + +# Reserve this function to also handle placeholders for additional images +# [ref: PR #5820] +@lru_cache +def _get_image_placeholder_token_ids(model_config: ModelConfig, + idx: int) -> List[int]: + assert idx > 0 + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + # We need to get the token for "<", not "▁<" + # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json + a_token_id, = tokenizer.encode("a", add_special_tokens=False) + a_token_id_, *image_placeholder_token_ids = tokenizer.encode( + f"a<|image_{idx}|>", add_special_tokens=False) + assert a_token_id == a_token_id_ + + return image_placeholder_token_ids + + +def input_processor_for_phi3v(ctx: InputContext, + llm_inputs: LLMInputs, + *, + num_crops: Optional[int] = None): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + hf_config = ctx.get_hf_image_processor_config() + + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + w, h = image_data.size + image_feature_size = [ + get_phi3v_image_feature_size(hf_config, + input_width=w, + input_height=h, + num_crops=num_crops) + ] + image_data = [image_data] + elif is_list_of(image_data, Image.Image): + image_feature_size = [] + for image in image_data: + w, h = image.size + image_feature_size.append( + get_phi3v_image_feature_size(hf_config, + input_width=w, + input_height=h, + num_crops=num_crops)) + elif isinstance(image_data, torch.Tensor): + image_feature_size = [image_data.shape[0]] + image_data = [image_data] + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[0] for item in image_data] + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + prompt = llm_inputs.get("prompt") + if prompt is None: + # for async server request, we assume prompt and its token_ids is always + # in correct format. And num_image_tags == len(image_data) always True. + image_idx = range(1, len(image_data) + 1) + new_prompt = None + else: + image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) + if prompt.count("<|image|>") > 0: + logger.warning("Please follow the prompt format that is " + "documented on HuggingFace which does not involve " + "repeating <|image|> tokens.") + elif (num_image_tags := len(image_idx)) > 1: + assert num_image_tags == len( + image_data), "The count of image_placeholder not match image's" + new_prompt = prompt + + prompt_token_ids = llm_inputs["prompt_token_ids"].copy() + + # masked place_holder with image token id + for idx in image_idx: + image_token_ids = _get_image_placeholder_token_ids(model_config, + idx=idx) + for i in range(len(prompt_token_ids) - len(image_token_ids) + 1): + if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids: + prompt_token_ids[i:i + len(image_token_ids)] = [ + _IMAGE_TOKEN_ID + ] * len(image_token_ids) + break + + # merge consecutive tag ids + merged_token_ids: List[int] = [] + for is_placeholder, token_ids in itertools.groupby( + prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID): + if is_placeholder: + merged_token_ids.append(_IMAGE_TOKEN_ID) + else: + merged_token_ids.extend(list(token_ids)) + + # TODO: Move this to utils or integrate with clip. 
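+    # Each merged placeholder id is expanded into
+    # image_feature_size[placeholder_idx] copies of _IMAGE_TOKEN_ID so the
+    # prompt length matches the number of vision embeddings merged in later
+    # by merge_multimodal_embeddings.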
+ new_token_ids: List[int] = [] + placeholder_idx = 0 + while merged_token_ids: + token_id = merged_token_ids.pop(0) + if token_id == _IMAGE_TOKEN_ID: + new_token_ids.extend( + repeat_and_pad_token( + _IMAGE_TOKEN_ID, + repeat_count=image_feature_size[placeholder_idx], + )) + placeholder_idx += 1 + else: + new_token_ids.append(token_id) + + # NOTE: Create a defensive copy of the original inputs + llm_inputs = LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + return llm_inputs + + +@MULTIMODAL_REGISTRY.register_image_input_mapper() +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) +@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + self.image_token_id = _IMAGE_TOKEN_ID + + # TODO: Optionally initializes this for supporting embeddings. + self.vision_embed_tokens = Phi3HDImageEmbedding(config) + + self.language_model = LlamaForCausalLM(config, cache_config, + quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _validate_pixel_values( + self, data: Union[torch.Tensor, List[torch.Tensor]] + ) -> Union[torch.Tensor, List[torch.Tensor]]: + + h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size + expected_dims = (3, h, w) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape[1:]) + + if actual_dims != expected_dims: + expected_expr = ("num_patches", *map(str, expected_dims)) + raise ValueError( + "The expected shape of pixel values per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) + + return data + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Phi3VImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_sizes = kwargs.pop("image_sizes", None) + image_embeds = kwargs.pop("image_embeds", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + if not isinstance(image_sizes, (torch.Tensor, list)): + raise ValueError("Incorrect type of image sizes. 
" + f"Got type: {type(image_sizes)}") + + return Phi3VImagePixelInputs( + type="pixel_values", + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True))) + + if image_embeds is not None: + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + + return Phi3VImageEmbeddingInputs( + type="image_embeds", + data=flatten_bn(image_embeds), + ) + + raise AssertionError("This line should be unreachable.") + + def _process_image_input( + self, + image_input: Phi3VImageInputs, + ) -> torch.Tensor: + + if image_input["type"] == "image_embeds": + image_data = image_input["data"] + if is_list_of(image_data, torch.Tensor): + # it's already a list of tensors + return image_data + if len(image_data.shape) == 3: + # 3D tensor + return list(torch.unbind(image_data, dim=0)) + raise ValueError( + "We expect batched 2D tensors;" + "this can be either a list of 2D tensors or a single 3D tensor." + ) + + assert self.vision_embed_tokens is not None + image_embeds = self.vision_embed_tokens(image_input["data"], + image_input["image_sizes"]) + + return image_embeds + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object): + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.vision_embed_tokens.": "vision_embed_tokens.", + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py new file mode 100644 index 0000000..4310655 --- /dev/null +++ b/vllm/model_executor/models/phimoe.py @@ -0,0 +1,663 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only PhiMoE model.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class PhiMoEConfig(PretrainedConfig): + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta 
= rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class mp(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + scores: torch.Tensor, + multiplier: torch.Tensor, + selected_experts: torch.Tensor, + masked_gates: torch.Tensor, + mask_for_one: torch.Tensor, + ): + ctx.save_for_backward(multiplier, selected_experts, masked_gates) + return multiplier * mask_for_one + + @staticmethod + def backward( + ctx, + grad_at_output: torch.Tensor, + ): + multiplier, selected_experts, masked_gates = ctx.saved_tensors + + grad_at_output = grad_at_output * multiplier + + grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1) + grad_at_scores_expaned.scatter_add_( + dim=-1, + index=selected_experts, + src=grad_at_output, + ) + + return ( + grad_at_scores_expaned, + None, + None, + None, + None, + ) + + +def sparsemixer(scores, jitter_eps=0.01): + ################ first expert ################ + + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts = max_ind + + # compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + multiplier = multiplier_o + + # masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, + keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, + float("-inf")) + selected_experts_top2 = max_ind + # compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2 = masked_gates_top2.gather(dim=-1, + index=selected_experts_top2) + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), + dim=-1) + multiplier = multiplier.to(torch.float32) + selected_experts = selected_experts.to(torch.int32) + return ( + multiplier, + selected_experts, + ) + + +def phimoe_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert topk == 2, "Only top-2 routing is supported" + assert renormalize is False, "Renormalization is not supported" + + topk_weights, topk_ids = sparsemixer(gating_output) + return topk_weights, topk_ids + + +class PhiMoE(nn.Module): + """A tensor-parallel MoE implementation for PhiMoE that shards each expert + across all ranks. 
+ + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + custom_routing_function=phimoe_routing_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class PhiMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[dict] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
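+            # (e.g. 8 KV heads on 16 TP ranks -> each KV head is replicated on
+            # two ranks, so every rank still holds exactly one KV head)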
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class PhiMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = PhiMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=config.rope_scaling, + ) + self.block_sparse_moe = PhiMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + + # Self Attention + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states, residual + + +class PhiMoEModel(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: 
Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiMoEDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + "w1", + "w2", + "w3", + "gate", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = PhiMoEModel(config, + cache_config, + quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=None, + bias=True, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: 
Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. 
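+                    # maybe_remap_kv_scale_name returns None when the scale has
+                    # no matching parameter, in which case the weight is skipped.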
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py new file mode 100644 index 0000000..c8957dc --- /dev/null +++ b/vllm/model_executor/models/pixtral.py @@ -0,0 +1,578 @@ +from dataclasses import dataclass, fields +from functools import cached_property +from itertools import tee +from typing import Iterable, List, Mapping, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mistral_common.protocol.instruct.messages import ImageChunk +from PIL import Image +from transformers import PretrainedConfig +from xformers.ops.fmha import memory_efficient_attention +from xformers.ops.fmha.attn_bias import BlockDiagonalMask + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import merge_multimodal_embeddings +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import IntermediateTensors, SequenceData + +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import init_vllm_registered_model + + +def get_max_pixtral_image_tokens(ctx: InputContext): + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) + mm_encoder = tokenizer.instruct.mm_encoder + + max_image_size = mm_encoder.mm_config.max_image_size + image_patch_size = mm_encoder.mm_config.image_patch_size + + return ((max_image_size // image_patch_size)**2) + + +def dummy_data_for_pixtral(ctx: InputContext, seq_len: int, + mm_counts: Mapping[str, int]): + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) + + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + patch_size = mm_encoder.mm_config.image_patch_size + image_token_id = mm_encoder.special_ids.img + + mm_config = ctx.model_config.multimodal_config + num_images = mm_config.limit_per_prompt.get("image", 1) + + # dummy size + size = 256 + image = Image.new("RGB", (size, size), color=0) + + image_feature_size = (size**2) // (patch_size**2) + + num_image_tokens = image_feature_size * num_images + seq_data = SequenceData.from_token_counts( + (image_token_id, num_image_tokens), + (0, seq_len - num_image_tokens), + ) + + mm_data = {"image": num_images * [image]} + return seq_data, mm_data + + +def input_mapper_for_pixtral(ctx: InputContext, + data: object) -> MultiModalInputs: + """Maps the input data to its MultiModalInputs (if any). + + Args: + ctx: Context of the loaded model. + data: data potentially containing image/image embeddings to be mapped + to pixel_values in .forward() for a visual QWenLMHeadModel model. + + Returns: + MultiModalInputs containing the stacked normalized images tensor or + image embeddings. 
+ """ + # Early exit if we have provided an image to a language only Qwen model + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, tokenizer_mode=model_config.tokenizer_mode) + + data_list = data if isinstance(data, list) else [data] + + images = [] + for image_data in data_list: + image = ImageChunk(image=image_data) + encoding = tokenizer.instruct.mm_encoder(image) + image = torch.from_numpy(encoding.image).to(device="cuda", + dtype=torch.float16) + images.append(image) + + return MultiModalInputs({"images": images}) + + +def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is not None and "image" in multi_modal_data: + tokenizer = cached_get_tokenizer( + ctx.model_config.tokenizer, + tokenizer_mode=ctx.model_config.tokenizer_mode) + + mm_encoder = tokenizer.mistral.instruct_tokenizer.mm_encoder + image_token_id = mm_encoder.special_ids.img + + if image_token_id not in llm_inputs['prompt_token_ids']: + raise ValueError( + (f"You've passed {llm_inputs=} without {image_token_id=}" + " Make sure to process your input via mistral_common's" + " tokenizer or pass a chat completion request. For more" + " For more info, see: " + "https://github.com/vllm-project/vllm/issues/8411.")) + + return llm_inputs + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_pixtral) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) +@INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + self.config = config + self.multimodal_config = multimodal_config + + dataclass_fields = {field.name for field in fields(VisionEncoderArgs)} + vision_args = { + key: value + for key, value in self.config.vision_config.to_dict().items() + if key in dataclass_fields + } + + self.vision_args = VisionEncoderArgs(**vision_args) + + # init MistralForCausalLM + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + self.vision_encoder = VisionTransformer(self.vision_args) + self.vision_language_adapter = VisionLanguageAdapter( + self.vision_args, dim=config.text_config.hidden_size) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for pixtral. 
+ + TODO + + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.vision_args.image_token_id) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) + + return hidden_states + + def _parse_and_validate_image_input( + self, + images: Optional[Union[List[List[torch.Tensor]], List[torch.Tensor], + torch.Tensor]] = None + ) -> Optional[List[torch.Tensor]]: + if images is None: + return None + + if isinstance(images, torch.Tensor): + # if passed as batch take all images + N, B, C, W, H = images.shape + images = images.reshape(N * B, C, W, H) + images = [images[i] for i in range(images.size(0))] + elif isinstance(images, list): + # if passed as list flatten lists of tensors + flatten_images = [] + for imgs_per_req in images: + imgs_per_req = [ + imgs_per_req[i] for i in range(imgs_per_req.size(0)) + ] if isinstance(imgs_per_req, torch.Tensor) else imgs_per_req + + flatten_images.extend(imgs_per_req) + + images = flatten_images + + return images + + def _process_image_input(self, + image_input: List[torch.Tensor]) -> torch.Tensor: + return self.vision_language_adapter(self.vision_encoder(image_input)) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + def is_vision_encoder_weights(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("vision_encoder") + + def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]): + return weight[0].startswith("vision_language_adapter") + + def is_vision_weights(weight: Tuple[str, torch.Tensor]): + return is_vision_encoder_weights( + weight) or is_vision_lang_adapter_weights(weight) + + llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee( + weights, 3) + + # llm + llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights) + self.language_model.load_weights(llm_weights) + + # vision encoder + vision_encoder_weights = filter(is_vision_encoder_weights, + vision_encoder_weights) + vision_encoder_dict = dict(self.vision_encoder.named_parameters()) + for name, loaded_weight in vision_encoder_weights: + # cut 'vision_encoder.' + name = '.'.join(name.split(".")[1:]) + param = vision_encoder_dict[name] + + default_weight_loader(param, loaded_weight) + + # adapter + vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights, + vision_lang_adapter_weights) + vision_lang_adpter_dict = dict( + self.vision_language_adapter.named_parameters()) + for name, loaded_weight in vision_lang_adapter_weights: + # cut 'vision_language_adapter.' 
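+            # (drop the module prefix so the remaining name matches the
+            # adapter's own parameter names)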
+ name = '.'.join(name.split(".")[1:]) + param = vision_lang_adpter_dict[name] + default_weight_loader(param, loaded_weight) + + +# Vision encoder +@dataclass +class VisionEncoderArgs: + hidden_size: int + num_channels: int + image_size: int + patch_size: int + intermediate_size: int + num_hidden_layers: int + num_attention_heads: int + rope_theta: float # for rope-2D + image_token_id: int + + +def _reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor) -> torch.Tensor: + """ + freqs_cis: complex - (seq_len, head_dim / 2) + x: complex - (bsz, seq_len, head_dim / 2) + """ + ndim = x.ndim + assert ndim > 1 + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), ( + freqs_cis.shape, + (x.shape[1], x.shape[-1]), + ) + shape = [ + d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape) + ] + return freqs_cis.view(*shape) + + +def precompute_freqs_cis_2d( + dim: int, + height: int, + width: int, + theta: float, +) -> torch.Tensor: + """ + freqs_cis: 2D complex tensor of shape (height, width, dim // 2) + to be indexed by (height, width) position tuples + """ + # (dim / 2) frequency bases + freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim)) + + h = torch.arange(height, device=freqs.device) + w = torch.arange(width, device=freqs.device) + + freqs_h = torch.outer(h, freqs[::2]).float() + freqs_w = torch.outer(w, freqs[1::2]).float() + freqs_2d = torch.cat( + [ + freqs_h[:, None, :].repeat(1, width, 1), + freqs_w[None, :, :].repeat(height, 1, 1), + ], + dim=-1, + ) + return torch.polar(torch.ones_like(freqs_2d), freqs_2d) + + +def apply_rotary_emb_vit( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + assert freqs_cis.dtype == torch.complex64 + freqs_cis = _reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +class FeedForward(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + assert args.intermediate_size is not None + self.w1 = nn.Linear(args.hidden_size, + args.intermediate_size, + bias=False) + self.w2 = nn.Linear(args.intermediate_size, + args.hidden_size, + bias=False) + self.w3 = nn.Linear(args.hidden_size, + args.intermediate_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class Attention(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + assert not args.hidden_size % args.num_attention_heads + self.n_heads = args.num_attention_heads + self.head_dim = args.hidden_size // args.num_attention_heads + + self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + batch, patches, _ = x.shape + + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = q.reshape(batch, patches, self.n_heads, self.head_dim) + k = k.reshape(batch, patches, self.n_heads, self.head_dim) + v = v.reshape(batch, patches, self.n_heads, 
self.head_dim) + + q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis) + out = memory_efficient_attention(q, k, v, attn_bias=mask) + out = out.reshape(batch, patches, self.n_heads * self.head_dim) + return self.wo(out) + + +class TransformerBlock(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.attention = Attention(args) + self.feed_forward = FeedForward(args) + self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) + self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: torch.Tensor, + ) -> torch.Tensor: + r = self.attention.forward(self.attention_norm(x), + mask=mask, + freqs_cis=freqs_cis) + h = x + r + r = self.feed_forward.forward(self.ffn_norm(h)) + out = h + r + return out + + +class Transformer(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.layers = torch.nn.ModuleList() + for _ in range(args.num_hidden_layers): + self.layers.append(TransformerBlock(args)) + + def forward( + self, + x: torch.Tensor, + mask: BlockDiagonalMask, + freqs_cis: Optional[torch.Tensor], + ) -> torch.Tensor: + for layer in self.layers: + x = layer(x, mask=mask, freqs_cis=freqs_cis) + return x + + +def position_meshgrid(patch_embeds_list: List[torch.Tensor], ) -> torch.Tensor: + positions = torch.cat([ + torch.stack( + torch.meshgrid( + torch.arange(p.shape[-2]), + torch.arange(p.shape[-1]), + indexing="ij", + ), + dim=-1, + ).reshape(-1, 2) for p in patch_embeds_list + ]) + return positions + + +class VisionTransformer(nn.Module): + + def __init__(self, args: VisionEncoderArgs): + super().__init__() + self.args = args + self.patch_conv = nn.Conv2d( + in_channels=args.num_channels, + out_channels=args.hidden_size, + kernel_size=args.patch_size, + stride=args.patch_size, + bias=False, + ) + self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) + self.transformer = Transformer(args) + + head_dim = self.args.hidden_size // self.args.num_attention_heads + assert head_dim % 2 == 0, "ROPE requires even head_dim" + self._freqs_cis: Optional[torch.Tensor] = None + + @property + def max_patches_per_side(self) -> int: + return self.args.image_size // self.args.patch_size + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.device: + return next(self.parameters()).dtype + + @property + def freqs_cis(self) -> torch.Tensor: + if self._freqs_cis is None: + self._freqs_cis = precompute_freqs_cis_2d( + dim=self.args.hidden_size // self.args.num_attention_heads, + height=self.max_patches_per_side, + width=self.max_patches_per_side, + theta=self.args.rope_theta, + ) + + if self._freqs_cis.device != self.device: + self._freqs_cis = self._freqs_cis.to(device=self.device) + + return self._freqs_cis + + def forward( + self, + images: List[torch.Tensor], + ) -> torch.Tensor: + """ + Args: + images: list of N_img images of variable sizes, + each of shape (C, H, W) + Returns: + image_features: tensor of token features for + all tokens of all images of shape (N_toks, D) + """ + # pass images through initial convolution independently + patch_embeds_list = [ + self.patch_conv(img.unsqueeze(0).to(self.dtype)) for img in images + ] + + # flatten to a single sequence + patch_embeds = torch.cat( + [p.flatten(2).permute(0, 2, 1) for p in patch_embeds_list], dim=1) + patch_embeds = self.ln_pre(patch_embeds) + + # positional embeddings + positions = position_meshgrid(patch_embeds_list).to(self.device) 
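+        # look up each patch's (row, col) entry in the precomputed 2D RoPE
+        # table so spatial positions survive the flattening into one sequence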
+ freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]] + + # pass through Transformer with a block diagonal mask delimiting images + mask = BlockDiagonalMask.from_seqlens( + [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], ) + out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis) + + # remove batch dimension of the single sequence + return out.squeeze(0) + + +class VisionLanguageAdapter(nn.Module): + + def __init__(self, args: VisionEncoderArgs, dim: int): + super().__init__() + assert isinstance(args, VisionEncoderArgs) + self.w_in = nn.Linear( + args.hidden_size, + dim, + bias=True, + ) + self.gelu = nn.GELU() + self.w_out = nn.Linear(dim, dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w_out(self.gelu(self.w_in(x))) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py new file mode 100644 index 0000000..fd8a27e --- /dev/null +++ b/vllm/model_executor/models/qwen.py @@ -0,0 +1,991 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE +"""Inference-only QWen model compatible with HuggingFace weights.""" + +import math +import re +from functools import partial +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, TypedDict, Union) + +import numpy as np +import torch +from PIL import Image +from torch import nn +from torchvision import transforms +from torchvision.transforms import InterpolationMode +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import IntermediateTensors, SequenceData +from vllm.utils import is_list_of + +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + +logger = init_logger(__name__) + +# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; +# for the time being, these tags are not considered as special at encoding +# time. This may change as VLLMs multimodal API changes in the future. 
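+# The placeholder strings below bracket each image in the prompt: IMG_START
+# and IMG_END delimit it, with IMG_PAD repeated MAX_QWEN_IMG_TOKENS times in
+# between, e.g. "Picture 1: <img><imgpad>...<imgpad></img>\n".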
+IMG_START = "<img>"
+IMG_END = "</img>"
+IMG_PAD = "<imgpad>"
+# Image context is fixed at 256 for all images
+MAX_QWEN_IMG_TOKENS = 256
+# Image normalization params
+CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
+CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
+
+
+class QwenImagePixelInputs(TypedDict):
+    type: Literal["pixel_values"]
+    data: torch.Tensor
+    """
+    Shape: `(batch_size * num_images, 3, image_size, image_size)`
+
+    Note that image_size is the value in the vision config to which the image
+    is resized in the normalization transform. Currently multi-image support
+    can only be leveraged by passing image embeddings directly.
+    """
+
+
+class QwenImageEmbeddingInputs(TypedDict):
+    type: Literal["image_embeds"]
+    data: torch.Tensor
+    """Shape: `(batch_size * num_images, 256, hidden_size)`
+
+    `hidden_size` must match the hidden size of the language model backbone
+    and is stored in the visual config of the model if we have one.
+    """
+
+
+QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs]
+
+
+class VisualAttention(nn.Module):
+    """Self-attention layer class.
+    Self-attention layer takes input with size [s, b, h]
+    and returns output of the same size.
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        bias: bool = True,
+        kdim: Optional[int] = None,
+        vdim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.embed_dim = embed_dim
+        self.kdim = kdim if kdim is not None else embed_dim
+        self.vdim = vdim if vdim is not None else embed_dim
+        self._qkv_same_embed_dim = self.kdim == embed_dim \
+            and self.vdim == embed_dim
+
+        self.num_heads = num_heads
+
+        # Per attention head and per partition values.
+        assert embed_dim % num_heads == 0
+        self.hidden_size_per_attention_head = embed_dim // num_heads
+        self.num_attention_heads_per_partition = num_heads
+        self.hidden_size_per_partition = embed_dim
+
+        # Strided linear layer.
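+        # in_proj computes Q, K and V with a single matmul; forward() splits
+        # the packed projection into per-head query/key/value chunks.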
+ assert self._qkv_same_embed_dim, \ + 'Visual Attention implementation only supports self-attention' + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # query/key/value: [sq, b, h] + sq, b, _ = x.size() + mixed_x_layer = self.in_proj(x) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, + key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, + self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + + return output + + +class QwenVMLP(nn.Module): + """MLP for the visual component of the Qwen model.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.c_fc = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config) + self.act_fn = get_act_fn("gelu", quant_config, intermediate_size) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + x, _ = self.c_fc(x) + x = self.act_fn(x) + x, _ = self.c_proj(x) + return x + + +class VisualAttentionBlock(nn.Module): + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = QwenVMLP( + hidden_size=d_model, + intermediate_size=mlp_width, + quant_config=quant_config, + ) + + def attention( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_mask = 
attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, attn_mask=attn_mask) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock(width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, + image_width // patch_width) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * + torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock(width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + + self.attn_pool = Resampler2( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + adaptive=False, + do_post_projection=False, + ).to( + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, + ) + + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter( + (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + self.image_start_id = image_start_id + self.image_end_id = image_start_id + 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( + x.size(1)))) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + def get_image_positions(self, + input_ids: torch.Tensor) -> Optional[torch.Tensor]: + """Given the input IDs, extracts start/stop points corresponding 
to + images. + + args: + Returns: + Optional torch tensor corresponding to start/stop pairs of images. + """ + if torch.any(input_ids == self.image_start_id): + bos_pos = torch.where(input_ids == self.image_start_id) + eos_pos = torch.where(input_ids == self.image_end_id) + return torch.stack((bos_pos[0], eos_pos[0]), dim=1) + return None + + +class QWenMLP(nn.Module): + """MLP for the language component of the Qwen model, which contains a + MergedColumnParallelLinear merging 2 outputs via silu activation.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.c_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class QWenAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.c_attn = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + self.scaling = self.head_dim**-0.5 + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.c_proj(attn_output) + return output + + +class QWenBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + self.attn = QWenAttention(config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + cache_config=cache_config, + 
quant_config=quant_config) + + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = QWenMLP(config.hidden_size, + config.intermediate_size // 2, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class QWenModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: QWenBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h") + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + self.visual = VisionTransformer(**config.visual, + quant_config=quant_config) if hasattr( + config, "visual") else None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + pixel_values: Optional[QwenImageInputs], + ) -> Union[torch.Tensor, IntermediateTensors]: + img_pos = None + # If pixel / visual embeddings are provided, this is a visual model + if pixel_values is not None and self.visual is not None: + if pixel_values["type"] != "image_embeds": + image_embeds = self.visual(pixel_values["data"]) + else: + image_embeds = pixel_values["data"] + + # features should be of shape (# images, 256, hidden_dim) + img_pos = self.visual.get_image_positions(input_ids) + if isinstance( + img_pos, + np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]: + raise ValueError( + f"Number of placeholders: {img_pos.shape[0]} " + f"does not match number of images {image_embeds.shape[0]}." 
+ ) + + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + # Merge the image embeddings into the hidden states if actually have + # visual features and the corresponding image tokens + if img_pos is not None: + for idx, (img_bos, img_eos) in enumerate(img_pos): + hidden_states[img_bos + 1:img_eos] = image_embeds[idx] + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +def get_image_text(image_num: int, padding: bool) -> str: + """Retrieves a placeholder text that when tokenized, will be expanded with + image pads. + + Args: + image_num: The number of the image that we want a text prompt for. + Images should be indexed starting at 1. + padding: Whether or not padding should be manually added. + + Returns: + Text placeholder prompt for the image being considered. + """ + image_start = f"Picture {image_num}: {IMG_START}" + image_end = f"{IMG_END}\n" + if not padding: + return f"{image_start}{image_end}" + return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}" + + +def input_processor_for_qwen(ctx: InputContext, + llm_inputs: LLMInputs) -> LLMInputs: + """Processes the inputs, which may or may not be multimodal. + Multimodal inputs will only be processed if the model has a "visual" + component in its model config, otherwise they'll be ignored. + + Args: + ctx: Context of the loaded model. + llm_inputs: LLM inputs which may have a multi_modal_data attribute. + + Returns: + If the model is language only or not multimodal inputs were provided, + returns llm_inputs unmodified. Otherwise, processes the multimodal + images / image embeddings and adds the fixed-length image placeholders. + """ + multi_modal_data = llm_inputs.get("multi_modal_data") + + # Only process images if we have multimodal data and a visual config + hf_config = ctx.get_hf_config() + if (multi_modal_data is None or "image" not in multi_modal_data + or not hasattr(hf_config, "visual")): + return llm_inputs + + prompt = llm_inputs.get("prompt") + prompt_token_ids = llm_inputs["prompt_token_ids"] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + image_data = multi_modal_data["image"] + if isinstance(image_data, torch.Tensor): + num_dims = len(image_data.shape) + if num_dims < 2 or num_dims > 3: + raise ValueError( + f"Expected img embeds to be have 3 dimensions, got {num_dims}") + num_images = 1 if num_dims == 2 else image_data.shape[0] + elif isinstance(image_data, Image.Image): + num_images = 1 + elif is_list_of(image_data, Image.Image): + num_images = len(image_data) + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + # Drops anything between / tags; encoding with the tokenizer + # will automatically add the image pads for the context. 
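+    # (anything the user placed between the image markers is discarded, and
+    # re-encoding the prompt re-inserts the fixed 256 image-pad tokens)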
+    new_prompt, num_matched_images = re.subn(
+        r"(Picture \d*: <img>).*?(<\/img>\n)",
+        r"\1\2",
+        prompt,
+    )
+
+    if num_matched_images != num_images:
+        logger.warning(
+            "Number of matched image placeholders %s doesn't match the number "
+            "of expected images %s; check your placeholder formatting.",
+            num_matched_images, num_images)
+
+    new_prompt_token_ids = tokenizer.encode(new_prompt)
+
+    return LLMInputs(prompt=new_prompt,
+                     prompt_token_ids=new_prompt_token_ids,
+                     multi_modal_data=multi_modal_data)
+
+
+def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
+    """Maps the input data to its MultiModalInputs (if any).
+
+    Args:
+        ctx: Context of the loaded model.
+        data: data potentially containing image/image embeddings to be mapped
+            to pixel_values in .forward() for a visual QWenLMHeadModel.
+
+    Returns:
+        MultiModalInputs containing the stacked normalized images tensor or
+        image embeddings.
+    """
+    # Early exit if we have provided an image to a language-only Qwen model
+    hf_config = ctx.get_hf_config()
+    if not hasattr(hf_config, "visual"):
+        logger.warning(
+            "Images were provided but this model has no visual config; "
+            "multimodal inputs will not be forwarded to the model.")
+        return MultiModalInputs()
+
+    model_config = ctx.model_config
+    tokenizer = cached_get_tokenizer(
+        model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code)
+
+    image_pair_tok = tokenizer.encode(IMG_START + IMG_END,
+                                      add_special_tokens=False,
+                                      return_tensors="pt").squeeze()
+    image_start_id = image_pair_tok[0]
+    image_end_id = image_pair_tok[-1]
+    if (image_start_id + 1) != image_end_id:
+        raise ValueError(
+            f"Found image end ID {image_end_id}, but expected "
+            f"{image_start_id} + 1")
+    if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2):
+        raise ValueError(
+            f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, "
+            f"but got {len(image_pair_tok) - 2}")
+
+    hf_config = ctx.get_hf_config()
+    image_size = hf_config.visual["image_size"]
+    img_emb_size = hf_config.visual["output_dim"]
+
+    if isinstance(data, torch.Tensor):
+        # It's expected that our values have already been processed
+        # by the visual transformer; shape is expected to be:
+        # (# images, 256, hidden_size)
+        if len(data.shape) == 2:
+            # Assume only one image embed was provided; unsqueeze the extra dim
+            data = data.unsqueeze(0)
+        if len(data.shape) != 3 or data.shape[
+                1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size:
+            raise ValueError(
+                "Expected image embeds to be a tensor of shape "
+                f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but "
+                f"received shape [{data.shape}]")
+        pixel_values = data
+    else:
+        transform = build_normalization_transform(image_size)
+        if not isinstance(data, (list, tuple)):
+            data = [data]
+        transformed_images = [transform(datum) for datum in data]
+        pixel_values = torch.stack(transformed_images, dim=0)
+    return MultiModalInputs({"pixel_values": pixel_values})
+
+
+def build_normalization_transform(image_size: int) -> transforms.Compose:
+    """Builds a normalization transform which can be applied to one or
+    more input images from which we want to extract visual features.
+
+    Args:
+        image_size: size of the image to be processed for visual embeddings.
+
+    Returns:
+        Callable transform for normalizing and resizing one RGB image.
+ """ + return transforms.Compose([ + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + + +def dummy_data_for_qwen( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +) -> Tuple[SequenceData, Optional[Dict]]: + """Build dummy data for warming up Qwen models; this will only contain text + matching the defaults for VLLM unless the model has a visual config. + + Args: + ctx: Context of the loaded model. + seq_len: Number of tokens in the text sequence. + mm_counts: multimodal data counts. + + Returns: + Tuple containing sequential and multimodal data. + """ + hf_config = ctx.get_hf_config() + + # The presence of a visual config indicates this is a multimodal model. + # If we don't have it, the model is considered an LLM for warmup purposes. + if not hasattr(hf_config, "visual"): + seq_data = SequenceData.from_token_counts((0, seq_len)) + mm_data = None + return seq_data, mm_data + + # We have a visual component - use images to warm up + num_images = mm_counts["image"] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer( + model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code) + + # Build the image prompts with no imgpads; the tokenizer will add img pads + image_prompt = ''.join( + [get_image_text(idx, False) for idx in range(1, num_images + 1)]) + toks = tokenizer.encode(image_prompt, add_special_tokens=False) + + # Make sure we actually get the fixed context size per tok padding + num_pads = toks.count(tokenizer.encode(IMG_PAD)[0]) + if num_pads != (num_images * MAX_QWEN_IMG_TOKENS): + raise ValueError( + f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads" + f" per image, but got {num_pads} pads for {num_images} image(s)" + " in total. 
Are you using a qwen tokenizer?") + + # Ensure the number of tokens is at minimum the sequence length provided + if len(toks) < seq_len: + toks += [0] * (seq_len - len(toks)) + + seq_data = SequenceData.from_seqs(toks) + + # Build the input images; width/height doesn't actually matter here since + # the data will get resized and the # of tokens per image is constant + image = Image.new("RGB", (224, 224), color=0) + mm_data = {"image": image if num_images == 1 else [image] * num_images} + return seq_data, mm_data + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) +@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) +class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__( + self, + config: PretrainedConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.multimodal_config = multimodal_config + self.quant_config = quant_config + self.transformer = QWenModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + + def _get_image_input_type( + self, + pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: + """Determines if the provided pixel_values are normalized pixel values + or image embeddings. + + Args: + pixel_values: Optional data to processed into visual embeddings. + + Returns: + None of the QwenImageInputs type used to determine whether or not + the visual transformer needs to process the pixel_values. 
+ """ + if pixel_values is not None and self.transformer.visual is not None: + pixel_values = flatten_bn(pixel_values) + if len(pixel_values.shape) == 3 and pixel_values.shape[ + 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[ + 2] == self.config.visual["output_dim"]: + return QwenImageEmbeddingInputs( + type="image_embeds", + data=pixel_values, + ) + else: + # If we have the wrong shape, assume we still need to process + return QwenImagePixelInputs( + type="pixel_values", + data=pixel_values, + ) + return None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + pixel_values = None + else: + pixel_values = self._get_image_input_type(pixel_values) + + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + pixel_values) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w2", 0), + ("gate_up_proj", "w1", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py new file mode 100644 index 0000000..08ec53d --- /dev/null +++ b/vllm/model_executor/models/qwen2.py @@ -0,0 +1,457 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class Qwen2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
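+            # Illustrative arithmetic, not part of the original code (example
+            # values only): with total_num_kv_heads = 8 and tp_size = 2 each
+            # rank keeps 8 // 2 = 4 KV heads, while with total_num_kv_heads = 2
+            # and tp_size = 8 the replicate branch below shares each KV head
+            # across 8 // 2 = 4 ranks and num_kv_heads = max(1, 2 // 8) = 1.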
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2Model(nn.Module): + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + decoder_layer_type: type[nn.Module] = Qwen2DecoderLayer + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer_type(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + + # from ixformer.inference.overlap import create_vllm_llama_decoder_layer + from vllm.distributed import get_tp_group + + # for layer_idx in range(self.start_layer, self.end_layer): + # self.layers[layer_idx] = create_vllm_llama_decoder_layer( + # self.layers[layer_idx], + # model_id = id(self), + # layer_idx = layer_idx, + # enable_overlap = True, + # group = get_tp_group(), + # quant_config=quant_config, + # activation_dtype=torch.get_default_dtype() + # ) + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
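+                # Illustrative note (example layer index, not from the
+                # source): at this point a checkpoint name such as
+                # "model.layers.0.self_attn.q_proj.weight" has already been
+                # rewritten to "model.layers.0.self_attn.qkv_proj.weight",
+                # and weight_loader below writes it into the "q" shard of the
+                # fused QKV parameter; gate_proj/up_proj are routed into
+                # gate_up_proj in the same way.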
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py new file mode 100644 index 0000000..d4475b7 --- /dev/null +++ b/vllm/model_executor/models/qwen2_moe.py @@ -0,0 +1,523 
@@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import print_warning_once + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class Qwen2MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2MoeSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config) + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None) + if config.shared_expert_intermediate_size > 0: + self.shared_expert = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.shared_expert_intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + else: + self.shared_expert = None + self.shared_expert_gate = torch.nn.Linear(config.hidden_size, + 1, + bias=False) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + if self.shared_expert is not None: + shared_output = self.shared_expert(hidden_states) + if self.shared_expert_gate is not None: + shared_output = F.sigmoid( + self.shared_expert_gate(hidden_states)) * shared_output + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen2MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen2MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + ) + + # Note: Qwen/Qwen2-57B-A14B-Instruct does not have + # `mlp_only_layers` in the config. 
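+        # Illustrative example with hypothetical config values: with
+        # decoder_sparse_step = 1 every decoder layer (outside mlp_only_layers)
+        # gets a Qwen2MoeSparseMoeBlock, while decoder_sparse_step = 2 places
+        # MoE only on layers with odd layer_idx (1, 3, 5, ...), because the
+        # condition below requires (layer_idx + 1) % decoder_sparse_step == 0;
+        # all other layers fall back to the dense Qwen2MoeMLP.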
+ mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen2MoeSparseMoeBlock(config=config, + quant_config=quant_config) + else: + self.mlp = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2MoeModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen2MoeDecoderLayer(config=config, + layer_idx=int( + prefix.split(".")[-1]), + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen2MoeForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config 
+ self.model = Qwen2MoeModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. 
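+                        # Illustrative sketch (hypothetical layer/expert
+                        # indices): a checkpoint tensor named
+                        # "model.layers.0.mlp.experts.3.gate_proj.weight"
+                        # matches one of the expert mappings, its name is
+                        # rewritten to the corresponding fused FusedMoE
+                        # parameter, and the call below loads it into expert
+                        # slot expert_id = 3 with the gate-projection shard id
+                        # (the exact fused parameter names come from
+                        # FusedMoE.make_expert_params_mapping).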
+ if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + print_warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py new file mode 100644 index 0000000..2edcb94 --- /dev/null +++ b/vllm/model_executor/models/qwen2_rm.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +"""Inference-only Qwen2-RM model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .interfaces import SupportsPP +from .qwen2 import Qwen2Model +from .utils import AutoWeightsLoader + + +class ReLU(nn.Module): + + def __init__(self): + super().__init__() + self.activation = nn.ReLU() + + def forward(self, input): + input, _ = input + return self.activation(input) + + +class Qwen2ForRewardModel(nn.Module, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." 
% ( + config.max_window_layers, + config.num_hidden_layers, + )) + + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) + + self.score = nn.Sequential( + ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config), + ReLU(), + RowParallelLinear(config.hidden_size, 1, + quant_config=quant_config), + ) + self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False) + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + logits, _ = self.score(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self,ignore_unexpected_prefixes=["lm_head"]) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py new file mode 100644 index 0000000..4a39b3f --- /dev/null +++ b/vllm/model_executor/models/qwen2_vl.py @@ -0,0 +1,1174 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/19e6e80e10118f855137b90740936c0b11ac397f/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +from functools import lru_cache, partial +from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional, + Tuple, Type, TypedDict, Union) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat +from PIL import Image +from transformers.image_utils import (get_image_size, + infer_channel_dimension_format, + to_numpy_array) +from transformers.models.qwen2_vl.image_processing_qwen2_vl import ( + make_batched_images, make_batched_videos, smart_resize) + +import vllm.envs as envs +from vllm.attention import AttentionMetadata +from vllm.attention.selector import (_Backend, backend_name_to_enum, + get_global_forced_attn_backend) +from vllm.config import CacheConfig, MultiModalConfig +from vllm.distributed import get_pp_group, parallel_state +from vllm.distributed import utils as dist_utils +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import QuickGELU +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalInputs) +from vllm.multimodal.base import MultiModalData +from vllm.multimodal.image import cached_get_image_processor +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors, SequenceData +from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig, + Qwen2VLVisionConfig) +from vllm.transformers_utils.processor import get_processor +from vllm.utils import is_cpu + +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory) + +logger = init_logger(__name__) + +# === Vision Inputs === # + + +class Qwen2VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +class Qwen2VLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + `hidden_size` must match the hidden size of language model backbone. + """ + + +Qwen2VLImageInputs = Union[Qwen2VLImagePixelInputs, + Qwen2VLImageEmbeddingInputs] + + +class Qwen2VLVideoInputs(TypedDict): + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. 
+ """ + + +# === Vision Encoder === # + + +class Qwen2VisionMLP(nn.Module): + + def __init__( + self, + in_features: int, + hidden_features: int = None, + act_layer: Type[nn.Module] = QuickGELU, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.fc1 = ColumnParallelLinear(in_features, + hidden_features, + quant_config=quant_config) + self.act = act_layer() + self.fc2 = RowParallelLinear(hidden_features, + in_features, + quant_config=quant_config) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_parallel, _ = self.fc1(x) + x_parallel = self.act(x_parallel) + x, _ = self.fc2(x_parallel) + return x + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), + "... d two -> ... (d two)", + two=2) + + +def apply_rotary_emb_torch(x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat( + sin, + "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [ + x[..., :ro_dim] * cos + + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:] + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, + freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) + return output + + +class Qwen2VisionAttention(nn.Module): + + def __init__( + self, + embed_dim: Optional[int] = None, + num_heads: Optional[int] = None, + projection_size: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size) + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config) + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config) + + # Detect attention implementation. + selected_backend: Optional[_Backend] = get_global_forced_attn_backend() + if selected_backend is None: + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = backend_name_to_enum(backend_by_env_var) + if selected_backend is None: + # For Volta and Turing GPUs, use xformers instead. + device_available = current_platform.has_device_capability(80) + if device_available: + from transformers.utils import is_flash_attn_2_available + + if is_flash_attn_2_available(): + self._use_flash_attn = True + else: + logger.warning( + "Current Qwen2-VL implementation has a bug with " + "`vllm-flash-attn` inside vision module, so we use " + "xformers backend instead. 
You can run `pip install " + "flash-attn to use flash-attention backend.") + self._use_flash_attn = False + else: + self._use_flash_attn = False + else: + if selected_backend == _Backend.FLASH_ATTN: + self._use_flash_attn = True + elif selected_backend == _Backend.XFORMERS: + self._use_flash_attn = False + else: + raise RuntimeError( + f"Qwen2-VL does not support {selected_backend} backend now." + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + x = x.view(*new_x_shape) + + # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) + batch_size = q.shape[1] + + q, k, v = [ + rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) + ] + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self._use_flash_attn: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif is_cpu(): + seq_length = q.size(1) + q, k, v = [rearrange(x, "b s h d -> b h s d") for x in [q, k, v]] + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], + cu_seqlens[i - 1]:cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) + context_layer = rearrange(output, "b h s d -> b s h d ") + else: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Qwen2VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float, + act_layer: Type[nn.Module] = QuickGELU, + norm_layer: Type[nn.Module] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.attn = Qwen2VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config) + self.mlp = Qwen2VisionMLP(dim, + mlp_hidden_dim, + act_layer=act_layer, + quant_config=quant_config) + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + x = x + 
self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_chans: int = 3, + embed_dim: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d(in_chans, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.embed_dim) + return x + + +class Qwen2VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Type[nn.Module] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) + / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + patch_size: int = vision_config.patch_size + temporal_patch_size: int = vision_config.temporal_patch_size + spatial_merge_size: int = vision_config.spatial_merge_size + in_chans: int = vision_config.in_chans + hidden_size: int = vision_config.hidden_size + embed_dim: int = vision_config.embed_dim + depth: int = vision_config.depth + num_heads: int = vision_config.num_heads + mlp_ratio: float = vision_config.mlp_ratio + + self.spatial_merge_size = spatial_merge_size + + self.patch_embed = Qwen2VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + 
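+        # Illustrative shape walk-through with example numbers (not from any
+        # real config): a 224x224 image with patch_size = 14 gives a
+        # 16 x 16 = 256 patch grid; each flattened patch holds
+        # in_chans * temporal_patch_size * patch_size**2 = 3 * 2 * 14 * 14
+        # = 1176 values and is projected to embed_dim by patch_embed, and the
+        # merger later fuses spatial_merge_size**2 = 4 neighbouring patches,
+        # so this image would contribute 256 // 4 = 64 tokens to the LLM.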
norm_layer = partial(nn.LayerNorm, eps=norm_eps) + head_dim = embed_dim // num_heads + self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen2VisionBlock( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config, + ) for _ in range(depth) + ]) + self.merger = Qwen2VisionPatchMerger( + d_model=hidden_size, + context_dim=embed_dim, + norm_layer=norm_layer, + quant_config=quant_config, + ) + + @property + def dtype(self) -> torch.dtype: + return self.blocks[0].mlp.fc2.weight.dtype + + @property + def device(self) -> torch.device: + return self.blocks[0].mlp.fc2.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + x = x.unsqueeze(1) + for blk in self.blocks: + x = blk(x, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + # adapter + x = self.merger(x) + return x + + +# === Vision input helpers === # + +cached_get_processor = lru_cache(get_processor) + + +def mm_input_mapper_for_qwen2_vl( + ctx: InputContext, + data: MultiModalData[object], + data_type_key: str, +) -> MultiModalInputs: + """Input mapper for Qwen2-VL.""" + if data_type_key == "image" and isinstance(data, dict): + return MultiModalInputs({ + "image_embeds": data.get("image_embeds"), + "image_grid_thw": data.get("image_grid_thw"), + }) + model_config = ctx.model_config + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + + images = None + videos = None + if data_type_key == "image": + images = data + else: + assert data_type_key == "video" + videos = data + + try: + batch_data = image_processor \ + .preprocess(images=images, videos=videos, return_tensors="pt") \ + .data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + + return MultiModalInputs(batch_data) + + +image_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, + data_type_key="image") +video_input_mapper_for_qwen2_vl = partial(mm_input_mapper_for_qwen2_vl, + data_type_key="video") + + +def _get_vision_info( + 
image_processor, + height: int, + width: int, + min_pixels: int, + max_pixels: int, + do_resize: bool = True, + data_type_key: str = "image", + mm_count: int = 1, +): + """Get information (resized height / width and number of vision tokens) + of input image / video frame.""" + + if do_resize: + resized_height, resized_width = smart_resize( + height=height, + width=width, + factor=image_processor.patch_size * image_processor.merge_size, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + else: + resized_height, resized_width = height, width + + if data_type_key == "image": + grid_t = mm_count + else: + assert data_type_key == "video" + grid_t = max(mm_count // image_processor.temporal_patch_size, 1) + + grid_h = resized_height // image_processor.patch_size + grid_w = resized_width // image_processor.patch_size + vision_tokens = grid_t * grid_h * grid_w + llm_num_vision_tokens = (vision_tokens // image_processor.merge_size // + image_processor.merge_size) + + return resized_height, resized_width, llm_num_vision_tokens + + +def _get_max_image_info( + image_processor, + data_type_key: str = "image", + mm_count: int = 1, +): + return _get_vision_info( + image_processor, + height=9999999, + width=9999999, + + # Limit min / max pixels. + min_pixels=max(image_processor.min_pixels, 28 * 28), + max_pixels=min(image_processor.max_pixels, 1280 * 28 * 28), + data_type_key=data_type_key, + mm_count=mm_count, + ) + + +def get_max_qwen2_vl_mm_tokens(ctx: InputContext, data_type_key: str) -> int: + image_processor = cached_get_image_processor(ctx.model_config.model) + max_resized_height, max_resized_width, max_llm_image_tokens = \ + _get_max_image_info(image_processor, data_type_key=data_type_key, + mm_count=1) + return max_llm_image_tokens + + +get_max_qwen2_vl_image_tokens = partial(get_max_qwen2_vl_mm_tokens, + data_type_key="image") +get_max_qwen2_vl_video_tokens = partial(get_max_qwen2_vl_mm_tokens, + data_type_key="video") + + +def dummy_data_for_qwen2_vl( + ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int] +) -> Tuple[SequenceData, Optional[MultiModalDataDict]]: + image_processor = cached_get_image_processor(ctx.model_config.model) + + num_images = mm_counts["image"] + max_resized_height, max_resized_width, max_llm_image_tokens = \ + _get_max_image_info(image_processor, data_type_key="image", + mm_count=num_images) + if seq_len - max_llm_image_tokens - 2 < 0: + raise RuntimeError( + f"Qwen2-VL cannot process {num_images} images in a prompt, " + "please increase max_model_len or reduce image limit by " + "--limit-mm-per-prompt.") + + # Check video counts. 
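+    # The "- 2" in these budget checks reserves space for the vision start
+    # and end tokens that wrap the image/video token block. Illustrative
+    # arithmetic with example numbers: seq_len = 4096 and
+    # max_llm_image_tokens = 1024 leave 4096 - 1024 - 2 = 3070 positions,
+    # which are filled with token id 0 in the dummy sequence built below.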
+ num_videos = mm_counts["video"] + max_resized_height, max_resized_width, max_llm_video_tokens = \ + _get_max_image_info(image_processor, data_type_key="video", + mm_count=num_videos) + if seq_len - max_llm_video_tokens - 2 < 0: + raise RuntimeError( + f"Qwen2-VL cannot process {num_images} videos in a prompt, " + "please increase max_model_len or reduce video limit by " + "--limit-mm-per-prompt.") + + hf_config = ctx.get_hf_config(Qwen2VLConfig) + + dummy_seqdata = SequenceData.from_token_counts( + (hf_config.vision_start_token_id, 1), + (hf_config.image_token_id, max_llm_image_tokens), + (hf_config.vision_end_token_id, 1), + (0, seq_len - max_llm_image_tokens - 2), + ) + + dummy_image = Image.new("RGB", (max_resized_width, max_resized_height), + color=0) + + return dummy_seqdata, { + "image": dummy_image if num_images == 1 else [dummy_image] * num_images + } + + +def _get_llm_num_vision_tokens( + mm_inputs: list, + data_type_key: str, + image_processor, +): + """Get number of vision tokens of multimodal inputs. + + This method is derived from `transformers.models.qwen2_vl. + image_processing_qwen2_vl.Qwen2VLImageProcessor._preprocess`. + """ + image = to_numpy_array(mm_inputs[0]) + input_data_format = infer_channel_dimension_format(image) + height, width = get_image_size(image, channel_dim=input_data_format) + _, _, llm_num_vision_tokens = _get_vision_info( + image_processor, + height=height, + width=width, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + do_resize=image_processor.do_resize, + data_type_key=data_type_key, + mm_count=len(mm_inputs), + ) + return llm_num_vision_tokens + + +def _expand_pad_tokens(inputs: list, token_id: int, make_batched_fn: Callable, + data_type_key: str, image_processor: Any, + prompt_token_ids: List[int]) -> List[int]: + """ + Expand pad tokens for multi-modal inputs (e.g., images or videos). + + Args: + inputs (list): The multi-modal inputs (e.g., images or videos). + token_id (int): The token ID used to represent the multi-modal input. + make_batched_fn (Callable): A function to batch the inputs. + data_type_key (str): The type of the multi-modal input. + image_processor (Any): The image processor used to process the inputs. + prompt_token_ids (List[int]): The list of token IDs in the prompt. + + Returns: + List[int]: The list of token IDs for the multi-modal inputs. 
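+
+    Raises:
+        AssertionError: If the number of placeholder tokens found in
+            `prompt_token_ids` does not match the number of inputs.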
+ """ + indices = [ + idx for idx, token in enumerate(prompt_token_ids) if token == token_id + ] + inputs = make_batched_fn(inputs) + assert len(indices) == len(inputs) + + prompt_token_ids_with_data = [] + for cnt, data in enumerate(inputs): + num_tokens = _get_llm_num_vision_tokens( + [data] if data_type_key == "image" else data, + data_type_key=data_type_key, + image_processor=image_processor, + ) + if cnt == 0: + end_idx = indices[cnt] + non_data_tokens = prompt_token_ids[:end_idx] + else: + non_data_tokens = prompt_token_ids[indices[cnt - 1] + + 1:indices[cnt]] + prompt_token_ids_with_data.extend(non_data_tokens) + prompt_token_ids_with_data.extend(token_id for _ in range(num_tokens)) + prompt_token_ids_with_data.extend(prompt_token_ids[indices[-1] + 1:]) + return prompt_token_ids_with_data + + +def input_processor_for_qwen2_vl(ctx: InputContext, + llm_inputs: LLMInputs) -> LLMInputs: + multi_modal_data = llm_inputs.get("multi_modal_data", None) + if multi_modal_data is None: + return llm_inputs + + image_inputs = multi_modal_data.get("image", None) + video_inputs = multi_modal_data.get("video", None) + + processor = cached_get_processor(ctx.model_config.model) + image_processor = processor.image_processor + hf_config = ctx.get_hf_config(Qwen2VLConfig) + + # To avoid redundant processing of vision objects (resize, rescale, etc.), + # we extract code of calculating number of vision tokens from + # `transformers.models.qwen2_vl.processing_qwen2_vl.Qwen2VLProcessor`. + # + # The following code is equivalent to: + # prompt = llm_inputs["prompt"] + # inputs = processor(text=[prompt], + # images=image_inputs, + # videos=video_inputs, + # padding=True, + # return_tensors="pt") + # prompt_token_ids = inputs["input_ids"][0].tolist() + + prompt_token_ids = llm_inputs.get("prompt_token_ids", None) + if prompt_token_ids is None: + prompt = llm_inputs["prompt"] + prompt_token_ids = processor.tokenizer( + prompt, + padding=True, + return_tensors=None, + )["input_ids"] + + # Expand image pad tokens. 
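+    # Two cases below: precomputed embeddings passed as a dict, where each
+    # image placeholder is repeated (embedding rows // number of images)
+    # times (e.g. 3072 rows for 2 images -> 1536 tokens per placeholder),
+    # or raw images, where `_expand_pad_tokens` repeats each placeholder by
+    # the number of vision tokens computed from the processed image size.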
+ + if image_inputs is not None: + if isinstance(image_inputs, dict): + prompt_token_ids_with_image = [] + image_indices = [ + idx for idx, token in enumerate(prompt_token_ids) + if token == hf_config.image_token_id + ] + image_cnt = len(image_indices) + embed_dim = image_inputs.get('image_embeds').size(0) + assert embed_dim % image_cnt == 0 + num_pad_tokens = embed_dim // image_cnt + for idx, token in enumerate(prompt_token_ids): + if idx in image_indices: + prompt_token_ids_with_image.extend([token] * + num_pad_tokens) + else: + prompt_token_ids_with_image.append(token) + prompt_token_ids = prompt_token_ids_with_image + else: + prompt_token_ids = _expand_pad_tokens(image_inputs, + hf_config.image_token_id, + make_batched_images, "image", + image_processor, + prompt_token_ids) + + if video_inputs is not None: + prompt_token_ids = _expand_pad_tokens(video_inputs, + hf_config.video_token_id, + make_batched_videos, "video", + image_processor, + prompt_token_ids) + + return LLMInputs( + prompt_token_ids=prompt_token_ids, + prompt=llm_inputs["prompt"], + multi_modal_data=multi_modal_data, + ) + + +@MULTIMODAL_REGISTRY.register_image_input_mapper( + image_input_mapper_for_qwen2_vl) +@MULTIMODAL_REGISTRY.register_input_mapper("video", + video_input_mapper_for_qwen2_vl) +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_qwen2_vl_image_tokens) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "video", get_max_qwen2_vl_video_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): + + def __init__(self, + config: Qwen2VLConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + + assert not cache_config.enable_prefix_caching, \ + "Qwen2-VL currently does not support prefix caching" + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + + # NOTE: Qwen2-VL vision encoder does not support any + # quantization method now. + quant_config=None, + ) + + self.model = Qwen2Model(config, cache_config, quant_config) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def _validate_and_reshape_mm_tensor(self, + mm_input: Union[torch.Tensor, + List[torch.Tensor]], + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. 
" + f"Got ndim: {mm_input.ndim}") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2VLImagePixelInputs(type="pixel_values", + data=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2VLImageEmbeddingInputs(type="image_embeds", + data=image_embeds) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + + if pixel_values_videos is None: + return None + + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2VLVideoInputs( + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + def _process_image_input(self, + image_input: Qwen2VLImageInputs) -> torch.Tensor: + if image_input["type"] == "image_embeds": + return image_input["data"].type(self.visual.dtype) + + pixel_values = image_input["data"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, + grid_thw=image_input["image_grid_thw"]) + return image_embeds + + def _process_video_input(self, + video_input: Qwen2VLVideoInputs) -> torch.Tensor: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, + grid_thw=video_input["video_grid_thw"]) + return video_embeds + + def _merge_multimodal_embeddings( + self, + input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: torch.Tensor, + placeholder_token_id: int, + ) -> torch.Tensor: + mask = (input_ids == placeholder_token_id) + inputs_embeds[mask, :] = multimodal_embeddings + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Qwen2-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. 
+ image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. + """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + rope_scaling = getattr(self.config, "rope_scaling", {}) + if rope_scaling.get("type", None) == "mrope": + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + + inputs_embeds = self.model.embed_tokens(input_ids) + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + + input_ids = None + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "up_proj", 1), + ("gate_up_proj", "gate_proj", 0), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if "visual" in name and "qkv.weight" in name: + visual_num_heads = self.config.vision_config.num_heads + visual_embed_dim = self.config.vision_config.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size, + visual_embed_dim) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) + elif "visual" in name and "qkv.bias" in name: + visual_num_heads = self.config.vision_config.num_heads + visual_embed_dim = self.config.vision_config.embed_dim + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1) + try: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + except KeyError: + raise ValueError(f"Unexpected weight: {name}") from None + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen3.py b/vllm/model_executor/models/qwen3.py new file mode 100644 index 0000000..87a08b2 --- /dev/null +++ b/vllm/model_executor/models/qwen3.py @@ -0,0 +1,357 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Qwen3 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import Qwen3Config + +from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, LoRAConfig#VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .qwen2 import Qwen2MLP as Qwen3MLP +from .qwen2 import Qwen2Model +from .utils import AutoWeightsLoader, PPMissingLayer, maybe_prefix + +logger = init_logger(__name__) + + +class Qwen3Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None, + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + )#attn_type=attn_type) + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen3 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. 
Alibaba-NLP/gte-Qwen3-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + + self.self_attn = Qwen3Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling, + prefix=f"{prefix}.self_attn", + attn_type=attn_type, + ) + self.mlp = Qwen3MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + #prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Qwen3DecoderLayer, +} + + +@support_torch_compile( + dynamic_arg_dims={ + "input_ids": 0, + # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl, + # otherwise (seq_len, ). 
+ "positions": -1, + "intermediate_tensors": 0, + "inputs_embeds": 0, + }) +class Qwen3Model(Qwen2Model): + + def __init__(self, *, + config: Qwen3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + decoder_layer_type=Qwen3DecoderLayer) + + +class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + def __init__(self, *, + config: Qwen3Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = ""): + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen3Model(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "model")) + + if get_pp_group().is_last_rank: + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + return loader.load_weights(weights) \ No newline at end of file diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py new file mode 100644 index 0000000..26038f3 --- /dev/null +++ b/vllm/model_executor/models/qwen3_moe.py @@ -0,0 +1,538 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig#, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (extract_layer_index, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + +logger = init_logger(__name__) + + +class Qwen3MoeMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen3MoeSparseMoeBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.num_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.num_experts}.") + + self.experts = FusedMoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config) + + self.gate = ReplicatedLinear(config.hidden_size, + config.num_experts, + bias=False, + quant_config=None) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_dim = hidden_states.shape[-1] + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) + final_hidden_states = final_hidden_states + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(orig_shape) + + +class Qwen3MoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + head_dim: Optional[int] = None, + rms_norm_eps: float = 1e-06, + qkv_bias: bool = False, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim or (hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear(hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=qkv_bias, + quant_config=quant_config) + + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # Add qk-norm + q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, + self.head_dim) + q_by_head = self.q_norm.forward_native(q_by_head) + q = q_by_head.view(q.shape) + + k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, + self.head_dim) + k_by_head = self.k_norm.forward_native(k_by_head) + k = k_by_head.view(k.shape) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen3MoeDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = Qwen3MoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + rms_norm_eps=config.rms_norm_eps, + qkv_bias=getattr(config, 'attention_bias', False), + head_dim=getattr(config, 'head_dim', None), + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + # `mlp_only_layers` in the config. 
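+        # Layers listed in `mlp_only_layers` keep a dense MLP; any other
+        # layer whose (index + 1) is a multiple of `decoder_sparse_step`
+        # (and num_experts > 0) gets the sparse MoE block instead.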
+ layer_idx = extract_layer_index(prefix) + mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else + config.mlp_only_layers) + if (layer_idx not in mlp_only_layers) and ( + config.num_experts > 0 and + (layer_idx + 1) % config.decoder_sparse_step == 0): + self.mlp = Qwen3MoeSparseMoeBlock(config=config, + quant_config=quant_config) + else: + self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + +class Qwen3MoeModel(nn.Module): + + def __init__(self, *, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen3MoeDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen3MoeForCausalLM(nn.Module, SupportsPP): + + fall_back_to_pt_during_load = False + + 
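+    # `fall_back_to_pt_during_load = False` above keeps the weight loader
+    # from falling back to legacy `*.pt` checkpoint shards.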
def __init__(self, *, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + + self.config = config + self.quant_config = quant_config + + self.model = Qwen3MoeModel(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. 
+ if is_pp_missing_parameter(name, self): + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + if name.endswith("kv_scale"): + remapped_kv_scale_name = name.replace( + ".kv_scale", ".attn.kv_scale") + if remapped_kv_scale_name not in params_dict: + logger.warning_once( + "Found kv scale in the checkpoint " + f"(e.g. {name}), but not found the expected " + f"name in the model " + f"(e.g. {remapped_kv_scale_name}). " + "kv-scale is not loaded.") + continue + else: + name = remapped_kv_scale_name + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params \ No newline at end of file diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py new file mode 100644 index 0000000..73ae906 --- /dev/null +++ b/vllm/model_executor/models/registry.py @@ -0,0 +1,453 @@ +import importlib +import pickle +import subprocess +import sys +import tempfile +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from functools import lru_cache +from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union + +import cloudpickle +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip + +from .interfaces import (has_inner_state, is_attention_free, + supports_multimodal, supports_pp) +from .interfaces_base import is_embedding_model, is_text_generation_model + +logger = init_logger(__name__) + +# yapf: disable +_TEXT_GENERATION_MODELS = { + # [Decoder-only] + "AquilaModel": ("llama", "LlamaForCausalLM"), + "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + # ChatGLMModel supports multimodal + "CohereForCausalLM": ("commandr", "CohereForCausalLM"), + "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "Glm4ForCausalLM": ("glm4", 
"Glm4ForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), + "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), + "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MambaForCausalLM": ("mamba", "MambaForCausalLM"), + "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), + "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + # QWenLMHeadModel supports multimodal + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), + "Qwen3ForCausalLM": ("qwen3", "Qwen3ForCausalLM"), + "Qwen3MoeForCausalLM": ("qwen3_moe", "Qwen3MoeForCausalLM"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), + "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + # [Encoder-decoder] + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"), +} + +_MULTIMODAL_MODELS = { + # [Decoder-only] + "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), + "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "InternVLChatModel": ("internvl", "InternVLChatModel"), + "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 + "LlavaNextVideoForConditionalGeneration": ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), # noqa: E501 
+ "LlavaOnevisionForConditionalGeneration": ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), # noqa: E501 + "MiniCPMV": ("minicpmv", "MiniCPMV"), + "MolmoForCausalLM": ("molmo", "MolmoForCausalLM"), + "NVLM_D": ("nvlm_d", "NVLM_D_Model"), + "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), # noqa: E501 + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + "UltravoxModel": ("ultravox", "UltravoxModel"), + # [Encoder-decoder] + "MllamaForConditionalGeneration": ("mllama", "MllamaForConditionalGeneration"), # noqa: E501 +} + +_SPECULATIVE_DECODING_MODELS = { + "EAGLEModel": ("eagle", "EAGLE"), + "MedusaModel": ("medusa", "Medusa"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), +} +# yapf: enable + +_VLLM_MODELS = { + **_TEXT_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_MULTIMODAL_MODELS, + **_SPECULATIVE_DECODING_MODELS, +} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS: List[str] = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { + "Qwen2ForCausalLM": + _ROCM_SWA_REASON, + "MistralForCausalLM": + _ROCM_SWA_REASON, + "MixtralForCausalLM": + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") +} + + +@dataclass(frozen=True) +class _ModelInfo: + is_text_generation_model: bool + is_embedding_model: bool + supports_multimodal: bool + supports_pp: bool + has_inner_state: bool + is_attention_free: bool + + @staticmethod + def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": + return _ModelInfo( + is_text_generation_model=is_text_generation_model(model), + is_embedding_model=is_embedding_model(model), + supports_multimodal=supports_multimodal(model), + supports_pp=supports_pp(model), + has_inner_state=has_inner_state(model), + is_attention_free=is_attention_free(model), + ) + + +class _BaseRegisteredModel(ABC): + + @abstractmethod + def inspect_model_cls(self) -> _ModelInfo: + raise NotImplementedError + + @abstractmethod + def load_model_cls(self) -> Type[nn.Module]: + raise NotImplementedError + + +@dataclass(frozen=True) +class _RegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has already been imported in the main process. 
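+
+    Unlike `_LazyRegisteredModel` below, the class object is held directly,
+    so it can be inspected without spawning a subprocess.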
+ """ + + interfaces: _ModelInfo + model_cls: Type[nn.Module] + + @staticmethod + def from_model_cls(model_cls: Type[nn.Module]): + return _RegisteredModel( + interfaces=_ModelInfo.from_model_cls(model_cls), + model_cls=model_cls, + ) + + def inspect_model_cls(self) -> _ModelInfo: + return self.interfaces + + def load_model_cls(self) -> Type[nn.Module]: + return self.model_cls + + +@dataclass(frozen=True) +class _LazyRegisteredModel(_BaseRegisteredModel): + """ + Represents a model that has not been imported in the main process. + """ + module_name: str + class_name: str + + # Performed in another process to avoid initializing CUDA + def inspect_model_cls(self) -> _ModelInfo: + return _run_in_subprocess( + lambda: _ModelInfo.from_model_cls(self.load_model_cls())) + + def load_model_cls(self) -> Type[nn.Module]: + mod = importlib.import_module(self.module_name) + return getattr(mod, self.class_name) + + +@lru_cache(maxsize=128) +def _try_load_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> Optional[Type[nn.Module]]: + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError(f"Model architecture '{model_arch}' is not " + "supported by ROCm for now.") + + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + msg = _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch] + logger.warning( + "Model architecture '%s' is partially " + "supported by ROCm: %s", model_arch, msg) + + try: + return model.load_model_cls() + except Exception: + logger.exception("Error in loading model architecture '%s'", + model_arch) + return None + + +@lru_cache(maxsize=128) +def _try_inspect_model_cls( + model_arch: str, + model: _BaseRegisteredModel, +) -> Optional[_ModelInfo]: + try: + return model.inspect_model_cls() + except Exception: + logger.exception("Error in inspecting model architecture '%s'", + model_arch) + return None + + +@dataclass +class _ModelRegistry: + # Keyed by model_arch + models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict) + + def get_supported_archs(self) -> List[str]: + return list(self.models.keys()) + + def register_model( + self, + model_arch: str, + model_cls: Union[Type[nn.Module], str], + ) -> None: + """ + Register an external model to be used in vLLM. + + :code:`model_cls` can be either: + + - A :class:`torch.nn.Module` class directly referencing the model. + - A string in the format :code:`:` which can be used to + lazily import the model. This is useful to avoid initializing CUDA + when importing the model and thus the related error + :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`. + """ + if model_arch in self.models: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls) + + if isinstance(model_cls, str): + split_str = model_cls.split(":") + if len(split_str) != 2: + msg = "Expected a string in the format `:`" + raise ValueError(msg) + + model = _LazyRegisteredModel(*split_str) + else: + model = _RegisteredModel.from_model_cls(model_cls) + + self.models[model_arch] = model + + def _raise_for_unsupported(self, architectures: List[str]): + all_supported_archs = self.get_supported_archs() + + raise ValueError( + f"Model architectures {architectures} are not supported for now. 
" + f"Supported architectures: {all_supported_archs}") + + def _try_load_model_cls(self, + model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in self.models: + return None + + return _try_load_model_cls(model_arch, self.models[model_arch]) + + def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]: + if model_arch not in self.models: + return None + + return _try_inspect_model_cls(model_arch, self.models[model_arch]) + + def _normalize_archs( + self, + architectures: Union[str, List[str]], + ) -> List[str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return architectures + + def inspect_model_cls( + self, + architectures: Union[str, List[str]], + ) -> _ModelInfo: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_info = self._try_inspect_model_cls(arch) + if model_info is not None: + return model_info + + return self._raise_for_unsupported(architectures) + + def resolve_model_cls( + self, + architectures: Union[str, List[str]], + ) -> Tuple[Type[nn.Module], str]: + architectures = self._normalize_archs(architectures) + + for arch in architectures: + model_cls = self._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + return self._raise_for_unsupported(architectures) + + def is_text_generation_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).is_text_generation_model + + def is_embedding_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).is_embedding_model + + def is_multimodal_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).supports_multimodal + + def is_pp_supported_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).supports_pp + + def model_has_inner_state(self, architectures: Union[str, + List[str]]) -> bool: + return self.inspect_model_cls(architectures).has_inner_state + + def is_attention_free_model(self, architectures: Union[str, + List[str]]) -> bool: + return self.inspect_model_cls(architectures).is_attention_free + + +ModelRegistry = _ModelRegistry({ + model_arch: _LazyRegisteredModel( + module_name=f"vllm.model_executor.models.{mod_relname}", + class_name=cls_name, + ) + for model_arch, (mod_relname, cls_name) in _VLLM_MODELS.items() +}) + +_T = TypeVar("_T") + + +def _run_in_subprocess(fn: Callable[[], _T]) -> _T: + with tempfile.NamedTemporaryFile() as output_file: + # `cloudpickle` allows pickling lambda functions directly + input_bytes = cloudpickle.dumps((fn, output_file.name)) + + # cannot use `sys.executable __file__` here because the script + # contains relative imports + returned = subprocess.run( + [sys.executable, "-m", "vllm.model_executor.models.registry"], + input=input_bytes, + capture_output=True) + + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError(f"Error raised in subprocess:\n" + f"{returned.stderr.decode()}") from e + + with open(output_file.name, "rb") as f: + return pickle.load(f) + + +def _run() -> None: + # Setup plugins + from vllm.plugins import load_general_plugins + load_general_plugins() + + fn, output_file = pickle.loads(sys.stdin.buffer.read()) + + 
result = fn() + + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) + + +if __name__ == "__main__": + _run() \ No newline at end of file diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py new file mode 100644 index 0000000..b2d081c --- /dev/null +++ b/vllm/model_executor/models/siglip.py @@ -0,0 +1,590 @@ +"""Implementation of SiglipVisionModel intended to be only used +within a vision language model.""" + +import math +from typing import Iterable, List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL import Image +from torch import nn +from transformers import SiglipVisionConfig +from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention + +from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.inputs import LLMInputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import SequenceData + +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + +USE_XFORMERS_OPS = True + + +def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: + # Since interpolation is applied, the image size need not be divisible + # assert image_size % patch_size == 0 + return image_size // patch_size + + +def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int: + grid_length = get_siglip_patch_grid_length(image_size=image_size, + patch_size=patch_size) + return grid_length * grid_length + + +def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int: + return get_siglip_num_patches(image_size=hf_config.image_size, + patch_size=hf_config.patch_size) + + +def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int: + return get_siglip_image_feature_size(hf_config) + + +def dummy_seq_data_for_siglip( + hf_config: SiglipVisionConfig, + seq_len: int, + num_images: int, + *, + image_token_id: int, + image_feature_size_override: Optional[int] = None, +): + if image_feature_size_override is None: + image_feature_size = get_siglip_image_feature_size(hf_config) + else: + image_feature_size = image_feature_size_override + + return SequenceData.from_token_counts( + (image_token_id, image_feature_size * num_images), + (0, seq_len - image_feature_size * num_images), + ) + + +def dummy_image_for_siglip( + hf_config: SiglipVisionConfig, + num_images: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + width = height = hf_config.image_size + if image_width_override is not None: + width = image_width_override + if image_height_override is not None: + height = image_height_override + + image = Image.new("RGB", (width, height), color=0) + return {"image": image if num_images == 1 else [image] * num_images} + + +def dummy_video_for_siglip( + hf_config: SiglipVisionConfig, + num_frames: int, + *, + image_width_override: Optional[int] = None, + image_height_override: Optional[int] = None, +): + pil_frame = dummy_image_for_siglip( + 
hf_config, + num_images=1, + image_width_override=image_width_override, + image_height_override=image_height_override) + np_frame = np.array(pil_frame["image"]) + mm_data_per_video = np.repeat([np_frame], num_frames, axis=0) + mm_data = {"video": mm_data_per_video} + return mm_data + + +def input_processor_for_siglip( + model_config: ModelConfig, + hf_config: SiglipVisionConfig, + llm_inputs: LLMInputs, + *, + image_token_id: int, + image_feature_size_override: Optional[Union[int, List[int]]] = None, +): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + + if image_feature_size_override is None: + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_siglip_image_feature_size(hf_config) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + else: + raise TypeError(f"Invalid image type: {type(image_data)}") + else: + image_feature_size = image_feature_size_override + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=image_token_id, + repeat_count=image_feature_size, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data, + ) + + +# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa +class SiglipVisionEmbeddings(nn.Module): + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches = (self.image_size // self.patch_size)**2 + self.num_positions = self.num_patches + self.position_embedding = VocabParallelEmbedding( + self.num_positions, self.embed_dim) + self.register_buffer( + "position_ids", + torch.arange(self.num_positions, dtype=torch.int64).expand( + (1, -1)), + persistent=False, + ) + + def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, + width: int) -> torch.Tensor: + """ + This method is an adapted method for SigLIP (due to SigLIP not having + class embedding unlike other ViTs) that allows the model to interpolate + the pre-trained position encodings such that it can be usable on higher + resolution images. 
+ + Source: + https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 + """ + position_embeddings = self.position_embedding.weight.unsqueeze(0) + num_patches = embeddings.shape[1] + num_positions = position_embeddings.shape[1] + if num_patches == num_positions and height == width: + return position_embeddings + + dim = embeddings.shape[-1] + height = height // self.patch_size + width = width // self.patch_size + # we add a small number to avoid floating point error + # in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + height, width = height + 0.1, width + 0.1 + + patch_pos_embed = position_embeddings.reshape( + 1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), + dim) + patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed, + scale_factor=( + height / math.sqrt(num_positions), + width / math.sqrt(num_positions), + ), + mode="bicubic", + align_corners=False, + ) + if (int(height) != patch_pos_embed.shape[-2] + or int(width) != patch_pos_embed.shape[-1]): + raise ValueError("Width or height does not match with " + "the interpolated position embeddings") + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed + + def forward(self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = False) -> torch.Tensor: + _, _, height, width = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(pixel_values.to( + dtype=target_dtype)) # shape = [*, width, grid, grid] + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if interpolate_pos_encoding: + embeddings = embeddings + self.interpolate_pos_encoding( + embeddings, height, width) + else: + embeddings = embeddings + self.position_embedding( + self.position_ids) + return embeddings + + +class SiglipParallelAttention(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError(f"embed_dim must be divisible by num_heads (got " + "`embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + """Input shape: Batch x Time x Channel""" + batch_size, q_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + 
self.num_heads_per_partition, + self.head_dim) + query_states = query_states.transpose(1,2) + key_states = key_states.transpose(1,2) + value_states = value_states.transpose(1,2) + attn = (query_states @ (key_states * self.scale).permute(0,1,3,2)).float() + attn = torch.softmax(attn, dim=-1).to(value_states.dtype) + out = attn @ value_states + out = out.transpose(1,2).reshape(batch_size, q_len, -1) + + # out = xops.memory_efficient_attention_forward(query_states, + # key_states, + # value_states, + # p=self.dropout, + # scale=self.scale) + # out = out.contiguous().view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) + + return attn_output, None + + +class SiglipMLP(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + + # For quantization, we require the hidden size to be a multiple of 64 + quantizable = (config.hidden_size % 64 == 0 + and config.intermediate_size % 64 == 0) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + quant_config=quant_config if quantizable else None, + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + quant_config=quant_config if quantizable else None, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class SiglipEncoderLayer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.embed_dim = config.hidden_size + + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.self_attn = SiglipParallelAttention(config, + quant_config=quant_config) + else: + self.self_attn = SiglipSdpaAttention(config) + + self.layer_norm1 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP( + config, + quant_config=quant_config, + ) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, + eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> Tuple[torch.Tensor, None]: + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, _ = self.self_attn(hidden_states=hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, None + + +class SiglipEncoder(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + ): + super().__init__() + self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + + self.layers = nn.ModuleList([ + SiglipEncoderLayer(config, quant_config=quant_config) + for _ in range(num_hidden_layers) + ]) + + def forward( + self, + inputs_embeds: torch.Tensor, + ) -> torch.Tensor: + hidden_states = inputs_embeds + for encoder_layer in self.layers: + hidden_states, _ = encoder_layer(hidden_states) + + return hidden_states + + +class 
SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention + self.attention = torch.nn.MultiheadAttention( + config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config=config, quant_config=quant_config) + + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention(probe, hidden_state, hidden_state)[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +class SiglipVisionTransformer(nn.Module): + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + ): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( + config, + quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + ) + + if len(self.encoder.layers) > config.num_hidden_layers: + raise ValueError( + f"The original encoder only has {config.num_hidden_layers} " + f"layers, but you requested {len(self.encoder.layers)} layers." + ) + elif len(self.encoder.layers) == config.num_hidden_layers: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + # post_layernorm is unused when we extract intermediate features + # In this case, we can skip it to conserve memory + self.post_layernorm = None + + self.use_head = (True if not hasattr(config, "vision_use_head") else + config.vision_use_head) + if self.use_head: + self.head = SiglipMultiheadAttentionPoolingHead( + config=config, quant_config=quant_config) + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: bool = True, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + encoder_outputs = self.encoder(inputs_embeds=hidden_states) + + if self.post_layernorm is None: + return encoder_outputs + + last_hidden_state = self.post_layernorm(encoder_outputs) + # TODO: add this back when pooled_output is used in inference + # if self.use_head: + # pooled_output = self.head(last_hidden_state) + + return last_hidden_state + + +class SiglipVisionModel(nn.Module): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__( + self, + config: SiglipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + num_hidden_layers_override: Optional[int] = None, + ): + super().__init__() + + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + + self.vision_model = SiglipVisionTransformer( + config, + quant_config, + num_hidden_layers_override=num_hidden_layers_override, + ) + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + def forward( + self, + pixel_values: torch.Tensor, + interpolate_pos_encoding: 
bool = False, + ) -> torch.Tensor: + return self.vision_model( + pixel_values=pixel_values, + interpolate_pos_encoding=interpolate_pos_encoding, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] + params_dict = dict(self.named_parameters()) + layer_count = len(self.vision_model.encoder.layers) + + for name, loaded_weight in weights: + # post_layernorm is optional in SiglipVisionModel + if (name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py new file mode 100644 index 0000000..b9298ed --- /dev/null +++ b/vllm/model_executor/models/solar.py @@ -0,0 +1,569 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Solar model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class SolarMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class SolarAttention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class SolarDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] \ + = config.original_max_position_embeddings + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = SolarAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = SolarMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = 
RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class SolarModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: SolarDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + bskcn_h_1 = None + bskcn_h_2 = None + bskcn_r_1 = None + bskcn_r_2 = None + bskcn_tv = (self.config.bskcn_tv[0] + if self.training else self.config.bskcn_tv[1]) + + for i in range(self.start_layer, self.end_layer): + if i in self.config.bskcn_1: + bskcn_h_1 = hidden_states.clone() + bskcn_r_1 = residual.clone() + if i in self.config.bskcn_2: + bskcn_h_2 = hidden_states.clone() + bskcn_r_2 = residual.clone() + if i in self.config.bskcn_3: + hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv) + if i in self.config.bskcn_4: + 
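+                # Blend the hidden states saved at a bskcn_2 layer back in,
+                # weighted by bskcn_tv (same mixing as the bskcn_3 branch above).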
hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * ( + 1 - bskcn_tv) + residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv) + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = SolarModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", 
"k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py new file mode 100644 index 0000000..083a485 --- /dev/null +++ b/vllm/model_executor/models/stablelm.py @@ -0,0 +1,338 @@ +# coding=utf-8 +# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# This code is based off the following work: +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json +"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) +model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class StablelmMLP(nn.Module): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, [config.intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=False) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class StablelmAttention(nn.Module): + + def __init__(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + + self.total_num_key_value_heads = config.num_key_value_heads + if self.total_num_key_value_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_key_value_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
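+            # e.g. 2 KV heads with tp_size=8 -> each KV head is replicated on 8 // 2 = 4 ranks.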
+ assert tp_size % self.total_num_key_value_heads == 0 + self.num_key_value_heads = max( + 1, self.total_num_key_value_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + rope_pct = getattr(config, "rope_pct", + getattr(config, "partial_rotary_factor", 1)) + self.rotary_ndims = int(self.head_dim * rope_pct) + self.scaling = self.head_dim**-0.5 + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_key_value_heads * self.head_dim + self.qkv_bias = getattr(config, "use_qkv_bias", False) + if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads " + f"(got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_key_value_heads, + self.qkv_bias, + quant_config=quant_config) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_ndims, + max_position=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class StablelmDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.self_attn = StablelmAttention(config, cache_config, quant_config) + self.mlp = StablelmMLP(config, quant_config) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +class StableLMEpochModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '') -> None: + super().__init__() + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + 
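+        # make_layers builds only this pipeline-parallel rank's slice of the
+        # decoder stack and returns the (start, end) layer indices it covers.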
self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: StablelmDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers", + ) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class StablelmForCausalLM(nn.Module, SupportsPP): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = StableLMEpochModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. 
+ continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py new file mode 100644 index 0000000..81dd7c4 --- /dev/null +++ b/vllm/model_executor/models/starcoder2.py @@ -0,0 +1,336 @@ +# coding=utf-8 +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Starcoder2 model.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import Starcoder2Config + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class Starcoder2Attention(nn.Module): + + def __init__(self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + self.max_position_embeddings = config.max_position_embeddings + self.use_bias = config.use_bias + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.use_bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=self.use_bias, + quant_config=quant_config, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Starcoder2MLP(nn.Module): + + def __init__(self, + config: Starcoder2Config, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.c_fc = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=config.use_bias, + quant_config=quant_config, + ) + self.c_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=config.use_bias, + quant_config=quant_config, + ) + self.act = get_act_fn(config.hidden_act, quant_config, + config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class Starcoder2DecoderLayer(nn.Module): + + def __init__(self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Starcoder2Attention(config, + cache_config, + quant_config=quant_config) + self.mlp = Starcoder2MLP(config, quant_config=quant_config) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + 
+ return hidden_states + + +class Starcoder2Model(nn.Module): + + def __init__(self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # TODO: consider padding_idx (currently removed) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Starcoder2DecoderLayer( + config, cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Starcoder2ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, + config: Starcoder2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.model = Starcoder2Model(config, + cache_config, + quant_config=quant_config) + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + 
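+        # Checkpoint q/k/v projection weights are loaded into the fused qkv_proj
+        # parameter; shard_id tells the weight loader which slice to fill.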
stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py new file mode 100644 index 0000000..e162e3a --- /dev/null +++ b/vllm/model_executor/models/ultravox.py @@ -0,0 +1,503 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py +"""PyTorch Ultravox model.""" + +import math +from array import array +from functools import cached_property, lru_cache +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union, cast) + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from transformers.models.whisper import WhisperFeatureExtractor +from transformers.models.whisper.modeling_whisper import WhisperEncoder + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.inputs.data import LLMInputs +from vllm.inputs.registry import InputContext +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs, NestedTensors +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) +from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from vllm.utils import is_list_of + +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, + init_vllm_registered_model, merge_multimodal_embeddings) + +_AUDIO_PLACEHOLDER_TOKEN = 128002 +_AUDIO_TOKENS_PER_SECOND = 6.25 + + +class UltravoxAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: NestedTensors + """Shape: `(batch_size, num_audios, 80, M)""" + + +class UltravoxAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" + + +UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, + UltravoxAudioEmbeddingInputs] + + +@lru_cache +def cached_feature_extractor(model_id: str) 
-> WhisperFeatureExtractor: + return WhisperFeatureExtractor.from_pretrained(model_id) + + +def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor: + return cached_feature_extractor( + ctx.get_hf_config(UltravoxConfig).audio_model_id) + + +def get_ultravox_max_audio_tokens(ctx: InputContext): + feature_extractor = whisper_feature_extractor(ctx) + return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) + + +def dummy_seq_data_for_ultravox( + ctx: InputContext, + seq_len: int, + audio_count: int, +): + audio_placeholder = array( + VLLM_TOKEN_ID_ARRAY_TYPE, + [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) + + # Add a separator between each chunk. + audio_token_ids = (audio_placeholder + + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count + other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - len(audio_token_ids)) + + return SequenceData(audio_token_ids + other_token_ids) + + +def dummy_audio_for_ultravox( + ctx: InputContext, + audio_count: int, +): + feature_extractor = whisper_feature_extractor(ctx) + audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) + return {"audio": [audio_and_sr] * audio_count} + + +def dummy_data_for_ultravox( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +): + audio_count = mm_counts["audio"] + seq_data = dummy_seq_data_for_ultravox(ctx, seq_len, audio_count) + mm_dict = dummy_audio_for_ultravox(ctx, audio_count) + + return (seq_data, mm_dict) + + +def input_mapper_for_ultravox(ctx: InputContext, data: object): + if not isinstance(data, list): + data = [data] + + # If the audio inputs are embeddings, no need for preprocessing + if is_list_of(data, torch.Tensor, check="all"): + return MultiModalInputs({"audio_embeds": data}) + + audio_features = [] + for audio_input in data: + if not isinstance(audio_input, tuple): + raise NotImplementedError( + f"Unsupported data type: {type(audio_input)}") + + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input) + feature_extractor = whisper_feature_extractor(ctx) + + if sr != feature_extractor.sampling_rate: + try: + import librosa + except ImportError: + raise ImportError( + "Please install vllm[audio] for audio support.") from None + audio = librosa.resample(audio, + orig_sr=sr, + target_sr=feature_extractor.sampling_rate) + sr = feature_extractor.sampling_rate + + minimum_audio_length = feature_extractor.n_fft // 2 + 1 + if len(audio) < minimum_audio_length: + # Not enough audio; pad it. + audio = np.pad(audio, (0, minimum_audio_length - len(audio))) + + single_audio_features = feature_extractor( + audio, sampling_rate=sr, padding="longest", + return_tensors="pt")["input_features"] + + # Remove the batch dimension because we're wrapping it in a list. 
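+        # (WhisperFeatureExtractor returns a (1, 80, num_frames) tensor;
+        # squeezing keeps one (80, num_frames) tensor per audio clip.)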
+ audio_features.append(single_audio_features.squeeze(0)) + + return MultiModalInputs({"audio_features": audio_features}) + + +def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "audio" not in multi_modal_data: + return llm_inputs + + feature_extractor = whisper_feature_extractor(ctx) + audios = multi_modal_data["audio"] + if not isinstance(audios, list): + audios = [audios] + + audio_token_counts = [] + for audio in audios: + if isinstance(audio, torch.Tensor): + audio_num_tokens = audio.shape[1] + audio_token_counts.append(audio_num_tokens) + else: + audio_data, sample_rate = audio + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. + adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - (feature_extractor.hop_length - 1)) / + feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + audio_token_counts.append(audio_num_tokens) + + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, + repeat_count=audio_token_counts, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class StackAudioFrames(nn.Module): + """ + Stack the audio embedding frames to reduce the sequence length by a factor + of `stack_factor`. 
+ """ + + def __init__(self, stack_factor: int = 8): + super().__init__() + self.stack_factor = stack_factor + + def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: + B, T, C = audio_embeds.shape + T_pad = (T + self.stack_factor - + 1) // self.stack_factor * self.stack_factor + audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) + B, T, C = audio_embeds.shape + audio_embeds = audio_embeds.view(B, T // self.stack_factor, + C * self.stack_factor) + return audio_embeds + + +class FlippedSiluAndMul(SiluAndMul): + """Ultravox is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + + +class UltravoxProjector(nn.Module): + + def __init__(self, config: UltravoxConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self._pad_and_stack = StackAudioFrames(config.stack_factor) + dim = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim) + self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) + dim = self.hidden_dim + + if config.projector_act == "swiglu": + self.act = FlippedSiluAndMul() + dim = dim // 2 + else: + self.act = get_act_fn(config.projector_act) + + self.linear_2 = nn.Linear(dim, + config.text_config.hidden_size, + bias=False) + self.ln_post = RMSNorm(config.text_config.hidden_size) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + audio_features = self._pad_and_stack(audio_features) + audio_features = self.ln_pre(audio_features) + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.ln_post(hidden_states) + return hidden_states + + +class ModifiedWhisperEncoder(WhisperEncoder): + """ + Encoder portion of OpenAI's Whisper model. + + This implementation is a slightly modified version of HF Transformers' + Whisper Encoder, with only a few fixes: + 1. base_model_prefix updated to allow for doing `.from_pretrained` + directly on the encoder + 2. allow less than 30 second of audio padding to be passed in: + - relaxed ValueError check for `input_features` length to be less + than or equal to `expected_seq_length` instead of strictly equal + - embed_pos is now sliced to match the length of `inputs_embeds` + + Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py + See commentary: https://github.com/huggingface/transformers/issues/25744 + """ + + base_model_prefix = "model.encoder" + + def forward( + self, + input_features, + ): + expected_seq_length = (self.config.max_source_positions * + self.conv1.stride[0] * self.conv2.stride[0]) + if input_features.shape[-1] > expected_seq_length: + raise ValueError( + f"Whisper expects the mel input features to be of length " + f"{expected_seq_length} or less, but found " + f"{input_features.shape[-1]}. 
Make sure to pad the input mel " + f"features to {expected_seq_length}.") + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=None, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_ultravox_max_audio_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) +@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) +class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): + + def __init__(self, + config: UltravoxConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional["QuantizationConfig"] = None): + super().__init__() + self.config = config + self.multi_modal_config = multimodal_config + assert self.multi_modal_config + + self.secondary_weights = [] + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + if config.audio_model_id is not None: + self.secondary_weights.append( + DefaultModelLoader.Source( + model_or_path=config.audio_model_id, + revision=None, + prefix="audio_tower.", + )) + self.multi_modal_projector = UltravoxProjector(config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + if config.text_model_id is not None: + self.secondary_weights.append( + DefaultModelLoader.Source(model_or_path=config.text_model_id, + revision=None, + prefix="language_model.")) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + + def _audio_features_to_embeddings( + self, input_features: torch.Tensor) -> torch.Tensor: + audio_input = input_features.to(self.audio_tower.dtype) + audio_features = self.audio_tower(audio_input) + audio_features = audio_features.to(self.audio_tower.dtype) + audio_embeddings = self.multi_modal_projector(audio_features) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[UltravoxAudioInputs]: + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. " + f"Got type: {type(audio_features)}") + + return UltravoxAudioFeatureInputs(type="audio_features", + data=audio_features) + + if audio_embeds is not None: + if not isinstance(audio_embeds, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio embeds. 
" + f"Got type: {type(audio_embeds)}") + + return UltravoxAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: UltravoxAudioInputs) -> NestedTensors: + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + if isinstance(audio_features, torch.Tensor): + # Combine the B and N dimensions for the encoder/projector + flattened = flatten_bn(audio_features) + flattened_embeddings = self._audio_features_to_embeddings( + flattened) + + # Restore the original dimensions + embeddings = flattened_embeddings.unflatten( + 0, audio_features.shape[:2]) + return embeddings + + result = [] + # TODO: Batch heterogeneous tensors through the encoder/projector + for audio_features_item in audio_features: + if isinstance(audio_features_item, torch.Tensor): + result.append( + self._audio_features_to_embeddings(audio_features_item)) + else: + embeddings = [ + # Add a batch dimension to embed it, then remove it. + self._audio_features_to_embeddings(tensor.unsqueeze(0) + ).squeeze(0) + for tensor in audio_features_item + ] + result.append(embeddings) + + return result + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[torch.Tensor], + **kwargs) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Ultravox + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted audio embeddings. The to-be-inserted + audio has a size that is essentially 6.25 tokens per second of audio. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + audio_features: A batch of audio inputs [B, N, 80, M]. 
+ """ + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is not None: + audio_embeddings = self._process_audio_input(audio_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, audio_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."}) + + loader = AutoWeightsLoader(self, + ignore_unexpected_prefixes=["audio_tower."]) + loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py new file mode 100644 index 0000000..810da7a --- /dev/null +++ b/vllm/model_executor/models/utils.py @@ -0,0 +1,515 @@ +import itertools +from dataclasses import dataclass, field +from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, + Protocol, Tuple, Union, overload) + +import torch +import torch.nn as nn +from torch.func import functional_call +from transformers import PretrainedConfig + +from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig, + SchedulerConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.model_loader.loader import build_model +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models import ModelRegistry +from vllm.multimodal.base import NestedTensors +from vllm.sequence import IntermediateTensors +from vllm.utils import is_pin_memory_available + +WeightsMapping = Mapping[str, Optional[str]] +"""If a key maps to a value of `None`, the corresponding weight is ignored.""" + + +@dataclass +class WeightsMapper: + """Maps the name of each weight if they match the following patterns.""" + + orig_to_new_substr: WeightsMapping = field(default_factory=dict) + orig_to_new_prefix: WeightsMapping = field(default_factory=dict) + orig_to_new_suffix: WeightsMapping = field(default_factory=dict) + + def _map_name(self, key: str) -> Optional[str]: + for substr, new_key in self.orig_to_new_substr.items(): + if substr in key: + if new_key is None: + return None + + key = key.replace(substr, new_key, 1) + + for prefix, new_key in self.orig_to_new_prefix.items(): + if key.startswith(prefix): + if new_key is None: + return None + + key = key.replace(prefix, new_key, 1) + + for suffix, new_key in self.orig_to_new_suffix.items(): + if key.endswith(suffix): + if new_key is None: + return None + + key = new_key.join(key.rsplit(suffix, 1)) + + return key + + def apply( + self, weights: Iterable[Tuple[str, torch.Tensor]] + ) -> Iterable[Tuple[str, 
torch.Tensor]]: + return ((out_name, data) for name, data in weights + if (out_name := self._map_name(name)) is not None) + + +class AutoWeightsLoader: + """ + Helper class to load weights into a :class:`torch.nn.Module`. It is able + to automatically detect child modules and parameters while iterating over + the weights only once. + + The weight loading logic for individual modules can be overridden + by defining a ``load_weights`` method. + + Similarly, the weight loading logic for individual parameters can be + overridden by defining a ``weight_loader`` method. + """ + + def __init__( + self, + module: nn.Module, + *, + skip_prefixes: Optional[List[str]] = None, + ignore_unexpected_prefixes: Optional[List[str]] = None, + ) -> None: + super().__init__() + + self.module = module + self.skip_prefixes = skip_prefixes or [] + self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or [] + + def _groupby_prefix( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]: + weights_by_parts = ((weight_name.split(".", 1), weight_data) + for weight_name, weight_data in weights) + + for prefix, group in itertools.groupby(weights_by_parts, + key=lambda x: x[0][0]): + yield ( + prefix, + # Because maxsplit=1 in weight_name.split(...), + # the length of `parts` must either be 1 or 2 + (("" if len(parts) == 1 else parts[1], weights_data) + for parts, weights_data in group), + ) + + def _get_qualname(self, prefix: str, rest: str) -> str: + if prefix == "": + return rest + if rest == "": + return prefix + + return ".".join((prefix, rest)) + + def _can_skip(self, qualname: str) -> bool: + return any(qualname.startswith(p) for p in self.skip_prefixes) + + def _can_ignore_unexpected(self, qualname: str) -> bool: + return any( + qualname.startswith(p) for p in self.ignore_unexpected_prefixes) + + def _load_param( + self, + base_prefix: str, + param: nn.Parameter, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> None: + for weight_name, weight_data in weights: + weight_qualname = self._get_qualname(base_prefix, weight_name) + + if self._can_skip(weight_qualname): + continue + + if weight_name != "": + if not self._can_ignore_unexpected(weight_qualname): + raise ValueError( + f"Attempted to load nested weight '{weight_qualname}' " + f"into a single parameter '{base_prefix}'") + + continue + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, weight_data) + + def _load_module( + self, + base_prefix: str, + module: nn.Module, + weights: Iterable[Tuple[str, torch.Tensor]], + ) -> None: + if isinstance(module, PPMissingLayer): + return + + # Avoid infinite recursion since this function is typically + # called inside load_weights of the module itself + if module != self.module: + module_load_weights = getattr(module, "load_weights", None) + if callable(module_load_weights): + module_load_weights(weights) + return + + child_modules = dict(module.named_children()) + child_params = dict(module.named_parameters(recurse=False)) + + for child_prefix, child_weights in self._groupby_prefix(weights): + prefix = self._get_qualname(base_prefix, child_prefix) + + if self._can_skip(prefix): + continue + + if child_prefix in child_modules: + self._load_module(prefix, child_modules[child_prefix], + child_weights) + elif child_prefix in child_params: + self._load_param(prefix, child_params[child_prefix], + child_weights) + else: + if not self._can_ignore_unexpected(prefix): + msg = f"There is no module or parameter 
named '{prefix}'" + raise ValueError(msg) + + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + *, + mapper: Optional[WeightsMapper] = None, + ) -> None: + if mapper is not None: + weights = mapper.apply(weights) + + self._load_module("", self.module, weights) + + +def init_vllm_registered_model( + hf_config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + *, + lora_config: Optional[LoRAConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + scheduler_config: Optional[SchedulerConfig] = None, +) -> nn.Module: + """ + Helper function to initialize an inner model registered to vLLM, + based on the arguments passed to the outer vLLM model. + """ + model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures) + + return build_model( + model_class, + hf_config, + cache_config, + quant_config, + lora_config=lora_config, + multimodal_config=multimodal_config, + scheduler_config=scheduler_config, + ) + + +@overload +def flatten_bn(x: torch.Tensor) -> torch.Tensor: + ... + + +@overload +def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: + ... + + +@overload +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: Literal[True], +) -> torch.Tensor: + ... + + +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + """ + Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. + + The input tensor should have shape ``(B, N, ...)```. + """ + if isinstance(x, torch.Tensor): + return x.flatten(0, 1) + + if concat: + return torch.cat(x) + + return [x_n for x_b in x for x_n in x_b] + + +def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: + """ + Recursively flattens and concatenates NestedTensors on all but the last + dimension. + """ + + if isinstance(embeddings, torch.Tensor): + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) + + return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) + + +def _embedding_count_expression(embeddings: NestedTensors) -> str: + """ + Constructs a debugging representation of the number of embeddings in the + NestedTensors. + """ + + if isinstance(embeddings, torch.Tensor): + return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) + + return " + ".join( + _embedding_count_expression(inner) for inner in embeddings) + + +def merge_multimodal_embeddings(input_ids: torch.Tensor, + inputs_embeds: torch.Tensor, + multimodal_embeddings: NestedTensors, + placeholder_token_id: int) -> torch.Tensor: + """ + Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the + positions in ``inputs_embeds`` corresponding to placeholder tokens in + ``input_ids``. + + Note: + This updates ``inputs_embeds`` in place. + """ + mask = (input_ids == placeholder_token_id) + num_expected_tokens = mask.sum().item() + assert isinstance(num_expected_tokens, int) + + flattened = _flatten_embeddings(multimodal_embeddings) + if flattened.shape[0] != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {flattened.shape[0]} " + f"multimodal tokens to {num_expected_tokens} placeholders") + + inputs_embeds[mask] = flattened + return inputs_embeds + + +class LayerFn(Protocol): + + def __call__(self, prefix: str) -> torch.nn.Module: + ... 
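+# Illustrative sketch of how merge_multimodal_embeddings is meant to be used;
+# the placeholder id, shapes, and helper names below are hypothetical.
+#
+#   input_ids     = torch.tensor([1, 9000, 9000, 5])  # 9000 = placeholder id
+#   inputs_embeds = embed_tokens(input_ids)           # (4, hidden_size)
+#   mm_embeds     = projector(features)               # 2 embedding rows total
+#   merged = merge_multimodal_embeddings(
+#       input_ids, inputs_embeds, mm_embeds, placeholder_token_id=9000)
+#   # rows 1 and 2 of `merged` now hold mm_embeds (inputs_embeds is updated
+#   # in place); a count mismatch raises ValueError.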
+ + +class PPMissingLayer(torch.nn.Identity): + """ + A placeholder layer for missing layers in a pipeline parallel model. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + +_CPU_OFFLOAD_BYTES = 0 +_CPU_OFFLOAD_MAX_BYTES = 0 + + +def set_cpu_offload_max_bytes(max_bytes: int) -> None: + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + _CPU_OFFLOAD_BYTES = 0 + _CPU_OFFLOAD_MAX_BYTES = max_bytes + + +def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: + device = next(module.parameters()).device + + if device == torch.device("cpu"): + return module + + global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + return module + + pin_memory = is_pin_memory_available() + + # offload parameters to CPU + # use pin_memory if possible, which helps cudagraph capture speed + offloaded_parameters = False + for p in module.parameters(): + if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: + # we use per-parameter offloading + # one module might have some parameters offloaded and some not + break + + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device='cpu', + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() + offloaded_parameters = True + + if offloaded_parameters: + original_forward = module.forward + + def forward(*args, **kwargs): + module.forward = original_forward + device_state = { + # here we blindly call `to(device)` + # if the parameter is already on the device, it will be a no-op + k: v.to(device, non_blocking=True) + for k, v in module.state_dict().items() + } + output = functional_call(module, + device_state, + args=args, + kwargs=kwargs) + module.forward = forward + return output + + module.forward = forward + + return module + + +def make_layers( + num_hidden_layers: int, + layer_fn: LayerFn, + prefix: str, +) -> Tuple[int, int, torch.nn.ModuleList]: + """Make a list of layers with the given layer function, taking + pipeline parallelism into account. + """ + from vllm.distributed.parallel_state import get_pp_group + from vllm.distributed.utils import get_pp_indices + start_layer, end_layer = get_pp_indices(num_hidden_layers, + get_pp_group().rank_in_group, + get_pp_group().world_size) + modules = torch.nn.ModuleList( + [PPMissingLayer() for _ in range(start_layer)] + [ + maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}")) + for idx in range(start_layer, end_layer) + ] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]) + return start_layer, end_layer, modules + + +# NOTE: don't use lru_cache here because it can prevent garbage collection +_model_to_pp_missing_layer_names: Dict[int, List[str]] = {} + + +def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]: + """Get the names of the missing layers in a pipeline parallel model.""" + model_id = id(model) + if model_id in _model_to_pp_missing_layer_names: + return _model_to_pp_missing_layer_names[model_id] + + missing_layer_names = [] + for name, module in model.named_modules(): + if isinstance(module, PPMissingLayer): + # NOTE: the trailing dot is used to match the prefix of the layer. 
+ # without the dot, we could match a layer that is not missing, + # e.g., 'encoder.layer.1' would match 'encoder.layer.11' + missing_layer_names.append(name + '.') + _model_to_pp_missing_layer_names[model_id] = missing_layer_names + + return missing_layer_names + + +def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: + """Check if a parameter is missing in a pipeline parallel model.""" + if isinstance(model, PPMissingLayer): + return True + + return any( + name.startswith(missing_layer_name) + for missing_layer_name in get_pp_missing_layer_names(model)) + + +def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): + + def make_empty_intermediate_tensors( + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: + return IntermediateTensors({ + key: torch.zeros((batch_size, hidden_size), + dtype=dtype, + device=device) + for key in keys + }) + + return make_empty_intermediate_tensors + +def maybe_prefix(prefix: str, name: str) -> str: + """Add a prefix to a name if the prefix is non-empty. + + Args: + prefix: The prefix to add. If empty, no prefix will be added. + name: The name to potentially prefix. + + Returns: + The string "prefix.name" if prefix was non-empty, otherwise just "name". + """ + return name if not prefix else f"{prefix}.{name}" + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. + Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> ValueError + """ + subnames = layer_name.split(".") + int_vals: List[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, (f"layer name {layer_name} should" + " only contain one integer") + return int_vals[0] + +class LLMWrapper(nn.Module): + """ + To align with the key names of LoRA trained with PEFT, we need to add an + additional layer to the llm's implementation. + """ + + def __init__(self, llm: nn.Module, name: str) -> None: + super().__init__() + self.model_name = name + setattr(self, name, llm) + + def __getattr__(self, key: str): + llm = super().__getattr__(self.model_name) + if key == self.model_name: + return llm + + return getattr(llm, key) + + # We need to explicitly override this + def __call__(self, *args: Any, **kwargs: Any) -> Any: + llm = super().__getattr__(self.model_name) + return llm(*args, **kwargs) \ No newline at end of file diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py new file mode 100644 index 0000000..3bded82 --- /dev/null +++ b/vllm/model_executor/models/xverse.py @@ -0,0 +1,404 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/xverse/XVERSE-7B/blob/main/modeling_xverse.py +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Xverse model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + + +class XverseMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate, _ = self.gate_up_proj(x) + x = self.act_fn(gate) + x, _ = self.down_proj(x) + return x + + +class XverseAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + # partition the KV heads across multiple tensor parallel GPUs. 
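+        # In this simplified variant each tensor-parallel rank must own a
+        # whole number of KV heads (hence the divisibility assert below);
+        # replicating KV heads when tp_size exceeds total_num_kv_heads is
+        # not handled here.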
+ assert self.total_num_kv_heads % tp_size == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=bias, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class XverseDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = XverseAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=getattr(config, "bias", False), + cache_config=cache_config, + ) + self.mlp = XverseMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class XverseModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + 
quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: XverseDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PretrainedConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = XverseModel(config, cache_config, quant_config) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> 
Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if ("rotary_emb.inv_freq" in name + or "rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py new file mode 100644 index 0000000..7a6d7c9 --- /dev/null +++ b/vllm/model_executor/parameter.py @@ -0,0 +1,403 @@ +from fractions import Fraction +from typing import Callable, Optional, Union + +import torch +from torch.nn import Parameter + +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.logger import init_logger + +__all__ = [ + "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter", + "ModelWeightParameter", "ChannelQuantScaleParameter", + "GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter" +] + +logger = init_logger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. 
+ """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _assert_and_load(self, loaded_weight: torch.Tensor): + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. + """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow(self.output_dim, + tp_rank * shard_size, shard_size) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + if isinstance( + self, + (PackedColumnParameter, + PackedvLLMParameter)) and self.packed_dim == self.output_dim: + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size) + + param_data = self.data + + tp_rank = get_tensor_model_parallel_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, + shard_size) + loaded_weight = loaded_weight.narrow(self.output_dim, + tp_rank * shard_size, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + if isinstance( + self, + (PackedColumnParameter, + PackedvLLMParameter)) and self.output_dim == self.packed_dim: + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size) + + param_data = self.data + tp_rank = get_tensor_model_parallel_rank() + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, + shard_size) + loaded_weight = loaded_weight.narrow(self.output_dim, + shard_id * shard_size, shard_size) + + assert 
param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowvLLMParameter(BasevLLMParameter): + """ + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow(self.input_dim, + tp_rank * shard_size, shard_size) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Uses both column and row parallelism. + """ + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scalers to a shard during + weight loading. + + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + # For row parallel layers, no sharding needed + # load weight into parameter as is + def load_row_parallel_weight(self, *args, **kwargs): + super().load_row_parallel_weight(*args, **kwargs) + + def load_merged_column_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs): + super().load_row_parallel_weight(*args, **kwargs) + + def _load_into_shard_id(self, loaded_weight: torch.Tensor, + shard_id: Union[str, int], **kwargs): + """ + Slice the parameter data based on the shard id for + loading. 
+ """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__(self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. 
+ """ + + def __init__(self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size) + + +def permute_param_layout_(param: BasevLLMParameter, input_dim: int, + output_dim: int, **kwargs) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2,\ + "permute_param_layout_ only supports 2D parameters when either "\ + "input_dim or output_dim is not set" + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None,\ + "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None,\ + "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) + if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert hasattr(param, "packed_dim") and\ + param.packed_dim == perm[kwargs["packed_dim"]],\ + "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, + marlin_tile_size): + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor, + marlin_tile_size): + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + if marlin_tile_size is not None: + return _adjust_shard_indexes_for_marlin( + shard_size=shard_size, + shard_offset=shard_offset, + marlin_tile_size=marlin_tile_size) + return shard_size, shard_offset diff --git 
a/vllm/model_executor/pooling_metadata.py b/vllm/model_executor/pooling_metadata.py new file mode 100644 index 0000000..b86cafc --- /dev/null +++ b/vllm/model_executor/pooling_metadata.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple + +import torch + +from vllm.pooling_params import PoolingParams +from vllm.utils import is_pin_memory_available + + +class PoolingMetadata: + """Metadata for pooling operations in the Pooler layer. + + This class holds the necessary information for pooling operations, + providing context for how to perform pooling and other related operations. + + Attributes: + seq_groups: List of (seq_ids, pooling_params). + seq_data: A mapping of sequence ID to additional sequence data. + prompt_lens: List of the lengths of each prompt. + """ + + def __init__( + self, + seq_groups: List[Tuple[List[int], PoolingParams]], + seq_data: Dict[int, Any], # Specific data related to sequences + prompt_lens: List[int], + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + + def __repr__(self) -> str: + return ("PoolingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens})") + + +@dataclass +class PoolingTensors: + """Tensors for pooling.""" + + prompt_lens: torch.Tensor + + @classmethod + def from_pooling_metadata( + cls, + pooling_metadata: "PoolingMetadata", + device: torch.device, + ) -> "PoolingTensors": + """ + Create PoolingTensors from PoolingMetadata. + + Args: + pooling_metadata: PoolingMetadata instance to convert. + device: Device to store the tensors. + """ + # Convert prompt lengths to tensor + pin_memory = is_pin_memory_available() + + prompt_lens_t = torch.tensor( + pooling_metadata.prompt_lens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + + return cls(prompt_lens=prompt_lens_t.to(device=device, + non_blocking=True), ) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py new file mode 100644 index 0000000..ee02368 --- /dev/null +++ b/vllm/model_executor/sampling_metadata.py @@ -0,0 +1,587 @@ +from array import array +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData, + SequenceGroupMetadata) +from vllm.utils import (PyObjectCache, async_tensor_h2d, + is_pin_memory_available, make_tensor_with_pad) + +_SAMPLING_EPS = 1e-5 + + +@dataclass +class SequenceGroupToSample: + # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ----------------------| + # |-- query_len ---| + + # Sequence ids for the sequence group in a previous step. + seq_ids: List[int] + sampling_params: SamplingParams + # seq_id -> sequence data. + seq_data: Dict[int, SequenceData] + # The length of the sequence (all tokens seen in the past + new token to + # compute attention) of the sequence group. None if it is in a decode + # stage. + seq_len: Optional[int] + # The length of new query tokens to compute in the current step. None if it + # is in a decode stage. The length of query_len <= seq_len if chunked + # prefill is enabled. + query_len: Optional[int] + # A random number generator for sampling. 
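A minimal usage sketch for the pooling metadata added above, assuming this patch's module layout, that PoolingParams() can be constructed with defaults, and that a CUDA device is available:

import torch

from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                  PoolingTensors)
from vllm.pooling_params import PoolingParams

# Two sequence groups, one sequence each, with their prompt lengths.
metadata = PoolingMetadata(
    seq_groups=[([0], PoolingParams()), ([1], PoolingParams())],
    seq_data={},            # not needed for this illustration
    prompt_lens=[5, 7],
)

# Prompt lengths are staged on the CPU (pinned if possible) and copied
# asynchronously to the target device.
tensors = PoolingTensors.from_pooling_metadata(metadata,
                                               device=torch.device("cuda:0"))
print(tensors.prompt_lens)  # tensor([5, 7], device='cuda:0')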
+ generator: Optional[torch.Generator] + # True if the sequence group is in prefill stage. False if it is in a + # decode stage. + is_prompt: bool + # Query token indices from logits. to compute prompt logprob. Empty if + # prompt logprob is not required. + prompt_logprob_indices: List[int] + # Sample token indices from logits. Empty if sampling is not required. + sample_indices: List[int] + + @property + def do_sample(self): + return len(self.sample_indices) > 0 + + def __post_init__(self): + if len(self.prompt_logprob_indices) > 0: + assert self.sampling_params.prompt_logprobs is not None + if self.is_prompt: + assert self.seq_len is not None + assert self.query_len is not None + + +def gen_seq_group_to_sample_builder(num_seqs: int): + return lambda: SequenceGroupToSample( + seq_ids=[0] * num_seqs, + sampling_params=None, + seq_data=None, # type: ignore + seq_len=0, + query_len=0, + generator=None, + is_prompt=True, + prompt_logprob_indices=[], + sample_indices=[], + ) + + +class SamplingMetadataCache: + """Used to cache SamplingMetadata objects between scheduler iterations""" + + def __init__(self): + self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {} + + def get_cached_seq_group_to_sample(self, num_seqs): + if num_seqs not in self._seq_group_to_sample_cache: + self._seq_group_to_sample_cache[num_seqs] = PyObjectCache( + gen_seq_group_to_sample_builder(num_seqs)) + + obj = self._seq_group_to_sample_cache[num_seqs].get_object() + return obj + + def reset(self): + for cache in self._seq_group_to_sample_cache.values(): + cache.reset() + + +class SamplingMetadata: + """Metadata for input sequences. Used in sampler. + + The usage is as follow; + ``` + hidden_states = execute_model(...) + logits = hidden_states[sampling_metadata.selected_token_indices] + sample(logits) + + def sample(logits): + # Use categorized_sample_indices for sampling.... + ``` + + Args: + seq_groups: List of batched sequence groups. + selected_token_indices: (num_query_tokens_to_logprob). Indices to find + logits from the initial model output hidden states. + categorized_sample_indices: SamplingType -> token indices to sample. + Each token indices is 2D tensor of (num_indices, num_indices) where + the first item means the sample index within the returned logit + (before pruning padding), and the second item means the sample + index after pruning using selected_token_indices. + For example, if the returned logit is [1, 2, 3], and we select + [1, 2] for sampling, the pruned logit will be [2, 3]. In this case, + The first tuple is [1, 2] (sampled index within original logit), + and the second tuple is [0, 1] (sampled index within pruned logit). + num_prompts: Number of prompt sequence groups in seq_groups. + skip_sampler_cpu_output: Indicates if we want to skip the GPU=>CPU + serialization of token outputs. + reuse_sampling_tensors: Indicates if we want to reuse sampling + tensors that are part of the sampler forward pass. Currently, + it is mainly used for multi-step decode. 
+ + """ + + def __init__( + self, + seq_groups: List[SequenceGroupToSample], + selected_token_indices: torch.Tensor, + categorized_sample_indices: Dict[SamplingType, torch.Tensor], + num_prompts: int, + skip_sampler_cpu_output: bool = False, + reuse_sampling_tensors: bool = False, + ) -> None: + self.seq_groups = seq_groups + self.selected_token_indices = selected_token_indices + self.categorized_sample_indices = categorized_sample_indices + self.num_prompts = num_prompts + self.skip_sampler_cpu_output = skip_sampler_cpu_output + self.reuse_sampling_tensors = reuse_sampling_tensors + + @staticmethod + def prepare( + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: List[int], + device: str, + pin_memory: bool, + generators: Optional[Dict[str, torch.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, + ) -> "SamplingMetadata": + ( + seq_groups, + selected_token_indices, + categorized_sample_indices, + num_prompts, + ) = _prepare_seq_groups(seq_group_metadata_list, seq_lens, query_lens, + device, generators, cache) + selected_token_indices = async_tensor_h2d( + selected_token_indices, + dtype=torch.long, + target_device=device, + pin_memory=pin_memory, + ) + categorized_sample_indices = { + t: async_tensor_h2d( + seq_ids, + dtype=torch.int, + target_device=device, + pin_memory=pin_memory, + ) + for t, seq_ids in categorized_sample_indices.items() + } + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + num_prompts=num_prompts, + ) + return sampling_metadata + + def __repr__(self) -> str: + return ( + "SamplingMetadata(" + f"seq_groups={self.seq_groups}, " + f"selected_token_indices={self.selected_token_indices}, " + f"categorized_sample_indices={self.categorized_sample_indices}), ") + + +def _prepare_seq_groups( + seq_group_metadata_list: List[SequenceGroupMetadata], + seq_lens: List[int], + query_lens: List[int], + device: str, + generators: Optional[Dict[str, torch.Generator]] = None, + cache: Optional[SamplingMetadataCache] = None, +) -> Tuple[List[SequenceGroupToSample], List[int], Dict[SamplingType, + List[int]], int, ]: + """Prepare sequence groups and indices for sampling. + + Args: + seq_group_metadata_list: A list of sequence group to batch. + seq_lens: A list of sequence lens per sequence group. + Index of prompt len should match with seq_group_metadata_list. + query_lens: A list of query lengths. Prompt lens include the length + of entire prompt tokens, and it could be shorter. + device: A device to use for random number generators, + `SequenceGroupToSample.generator`. + generators: A store of per-request random number generators used + for seeded requests. + + Returns: + seq_groups: A list of sequence group to sample. + selected_token_indices: See the definition from `SamplingMetadata`. + categorized_sample_indices: See the definition from `SamplingMetadata`. + num_prompts: Total number of prompts from `seq_group_metadata_list`. + """ + # Batched sequence groups for the current model forward stsep. + seq_groups: List[SequenceGroupToSample] = [] + # A list of token indices to sample/compute logprob. It is used to + # prune the outcome logits from the model for the performance. + selected_token_indices: List[int] = [] + # Used for selected_token_indices. 
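The index bookkeeping is easier to see with concrete numbers. A pure-PyTorch sketch of the pruning pattern described above (an illustration, not vLLM's sampler itself):

import torch

# Toy batch: a 4-token prompt (no prompt logprobs requested) followed by one
# decoding sequence, so the model produced 5 hidden states in total.
hidden_states = torch.randn(5, 8)               # 5 tokens, hidden size 8

# Only the last prompt position (index 3) and the decode position (index 4)
# are needed for sampling, so the logits are pruned to those rows.
selected_token_indices = torch.tensor([3, 4])
logits = hidden_states[selected_token_indices]  # shape (2, 8)

# Within the pruned logits, both rows are sampled; had prompt logprobs been
# requested, their rows would appear in prompt_logprob_indices instead.
sample_indices = [0, 1]
print(logits[sample_indices].shape)             # torch.Size([2, 8])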
+ model_output_idx = 0 + + # Sampling type -> ( + # indices to sample/prompt logprob within pruned output logits, + # indices to sample within pruned logits) + categorized_sample_indices: Dict[SamplingType, List[int]] = { + t: [] + for t in SamplingType + } + # Index of logits to compute logprob. Logits include both prompt logprob + # and sample logprob indices. + logit_idx = 0 + # Total number of prompts from given sequence groups. + num_prompts = 0 + + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = seq_group_metadata.seq_data.keys() + + if cache is not None: + sample_obj = cache.get_cached_seq_group_to_sample(len(seq_ids)) + + for j, seq_id in enumerate(seq_ids): + sample_obj.seq_ids[j] = seq_id + + sample_obj.prompt_logprob_indices.clear() + sample_obj.sample_indices.clear() + + sampling_params = seq_group_metadata.sampling_params + is_prompt = seq_group_metadata.is_prompt + generator: Optional[torch.Generator] = None + # If the current seq group is in decode stage, it is None. + seq_len: Optional[int] = None + query_len: Optional[int] = None + prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices + if cache is not None else []) + sample_indices: List[int] = (sample_obj.sample_indices + if cache is not None else []) + do_sample = seq_group_metadata.do_sample + + if seq_group_metadata.is_prompt: + if sampling_params.seed is not None: + generator = torch.Generator(device=device).manual_seed( + sampling_params.seed) + if generators is not None: + generators[seq_group_metadata.request_id] = generator + + num_prompts += 1 + num_prefill_sample = len(seq_ids) + assert num_prefill_sample == 1 + assert query_lens is not None and seq_lens is not None + query_len, seq_len = query_lens[i], seq_lens[i] + # If we need sampling, exclude num_prefill_sample tokens from + # prompt logprob. + prompt_logprob_len = (query_len - num_prefill_sample + if do_sample else query_len) + sample_len = num_prefill_sample if do_sample else 0 + else: + # Decode + prompt_logprob_len = 0 + query_len = query_lens[i] if query_lens is not None else 1 + sample_len = len(seq_ids) * query_len if do_sample else 0 + + if sampling_params.seed is not None and generators is not None: + generator = generators.get(seq_group_metadata.request_id) + + # Update indices to select from the model output. + """ + This blocks computes selected_token_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + """ + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + prompt_logprob_len)) + model_output_idx += prompt_logprob_len + if do_sample: + selected_token_indices.extend( + range(model_output_idx, model_output_idx + sample_len)) + model_output_idx += sample_len + + # We now find indices for logprob computation and sampling. + """ + This block computes categorized_sample_indices which is used in the + following way. + + hidden_states = model(...) + logits = hidden_states[selected_token_indices] + def sample(logits): + # Use categorized_sample_indices for sampling. + # prompt_logprob_indices to find prompt logprob indices. + # sample_indices to find sample indices. 
+ """ + + if sampling_params.prompt_logprobs is not None: + prompt_logprob_indices.extend( + range(logit_idx, logit_idx + prompt_logprob_len)) + logit_idx += prompt_logprob_len + if do_sample: + sample_indices.extend(range(logit_idx, logit_idx + sample_len)) + categorized_sample_indices[sampling_params.sampling_type].extend( + list(range(logit_idx, logit_idx + sample_len))) + logit_idx += sample_len + + if cache is not None: + sample_obj.sampling_params = sampling_params + sample_obj.seq_data = seq_group_metadata.seq_data + sample_obj.seq_len = seq_len + sample_obj.query_len = query_len + sample_obj.generator = generator + sample_obj.is_prompt = is_prompt + else: + sample_obj = SequenceGroupToSample( + seq_ids=list(seq_ids), + sampling_params=sampling_params, + seq_data=seq_group_metadata.seq_data, + seq_len=seq_len, + query_len=query_len, + generator=generator, + is_prompt=is_prompt, + prompt_logprob_indices=list(prompt_logprob_indices), + sample_indices=list(sample_indices), + ) + + seq_groups.append(sample_obj) + + if cache is not None: + cache.reset() + + return (seq_groups, selected_token_indices, categorized_sample_indices, + num_prompts) + + +@dataclass +class SamplingTensors: + """Tensors for sampling.""" + + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + presence_penalties: torch.Tensor + frequency_penalties: torch.Tensor + repetition_penalties: torch.Tensor + prompt_tokens: torch.Tensor + output_tokens: torch.Tensor + + @classmethod + def from_sampling_metadata( + cls, + sampling_metadata: "SamplingMetadata", + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> Tuple["SamplingTensors", bool, bool, bool]: + prompt_tokens: List[array] = [] + output_tokens: List[array] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + do_penalties = False + do_top_p_top_k = False + do_min_p = False + + assert sampling_metadata.seq_groups is not None + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + sampling_params = seq_group.sampling_params + temperature = sampling_params.temperature + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + min_p = sampling_params.min_p + + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
+ temperature = 1.0 + if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_min_p and min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + + is_prompt = seq_group.is_prompt + if is_prompt and sampling_params.prompt_logprobs is not None: + # For tokens in the prompt that we only need to get + # their logprobs + query_len = seq_group.query_len + assert query_len is not None + prefill_len = len(seq_group.prompt_logprob_indices) + temperatures += [temperature] * prefill_len + top_ps += [top_p] * prefill_len + top_ks += [top_k] * prefill_len + min_ps += [min_p] * prefill_len + presence_penalties += [0] * prefill_len + frequency_penalties += [0] * prefill_len + repetition_penalties += [1] * prefill_len + + if seq_group.do_sample: + sample_lens = len(seq_group.sample_indices) + assert sample_lens >= len(seq_ids) + temperatures += [temperature] * sample_lens + top_ps += [top_p] * sample_lens + top_ks += [top_k] * sample_lens + min_ps += [min_p] * sample_lens + presence_penalties += [p] * sample_lens + frequency_penalties += [f] * sample_lens + repetition_penalties += [r] * sample_lens + + if do_penalties: + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): + prefill_len = len(seq_group.prompt_logprob_indices) + prompt_tokens.extend( + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) + output_tokens.extend( + array(VLLM_TOKEN_ID_ARRAY_TYPE) + for _ in range(prefill_len)) + if seq_group.do_sample: + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids_array) + output_tokens.append(seq_data.output_token_ids_array) + + sampling_tensors = SamplingTensors.from_lists( + temperatures, + top_ps, + top_ks, + min_ps, + presence_penalties, + frequency_penalties, + repetition_penalties, + prompt_tokens, + output_tokens, + vocab_size, + device, + dtype, + ) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + + @classmethod + def from_lists( + cls, + temperatures: List[float], + top_ps: List[float], + top_ks: List[int], + min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[array], + output_tokens: List[array], + vocab_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> "SamplingTensors": + # Note that the performance will be very bad without + # pinned memory. 
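The pinned-memory remark refers to the usual host-to-device staging pattern: build the CPU tensor in page-locked memory, then copy with non_blocking=True. A generic sketch, independent of vLLM (assumes a CUDA device is present):

import torch


def h2d(values, dtype, device="cuda:0"):
    # Pinned (page-locked) host memory lets the copy engine overlap the
    # transfer with compute; with pageable memory the copy degrades to a
    # synchronous one even when non_blocking=True is passed.
    pin = torch.cuda.is_available()
    cpu_t = torch.tensor(values, dtype=dtype, device="cpu", pin_memory=pin)
    return cpu_t.to(device=device, non_blocking=True)


temperatures = h2d([1.0, 0.7, 0.7], torch.float32)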
+ pin_memory = is_pin_memory_available() + + do_penalties = prompt_tokens or output_tokens + + if do_penalties: + prompt_t = make_tensor_with_pad( + prompt_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + output_t = make_tensor_with_pad( + output_tokens, + vocab_size, + device="cpu", + dtype=torch.int64, + pin_memory=pin_memory, + ) + else: + empty_tensor = torch.empty(0, device=device, dtype=torch.long) + prompt_t = empty_tensor + output_t = empty_tensor + + temperatures_t = torch.tensor( + temperatures, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ps_t = torch.tensor( + top_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + min_ps_t = torch.tensor( + min_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + presence_penalties_t = torch.tensor( + presence_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + frequency_penalties_t = torch.tensor( + frequency_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + repetition_penalties_t = torch.tensor( + repetition_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ks_t = torch.tensor( + top_ks, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + # Because the memory is pinned, we can do non-blocking + # transfer to device. + + return cls( + temperatures=temperatures_t.to(device=device, non_blocking=True), + top_ps=top_ps_t.to(device=device, non_blocking=True), + top_ks=top_ks_t.to(device=device, non_blocking=True), + min_ps=min_ps_t.to(device=device, non_blocking=True), + presence_penalties=presence_penalties_t.to(device=device, + non_blocking=True), + frequency_penalties=frequency_penalties_t.to(device=device, + non_blocking=True), + repetition_penalties=repetition_penalties_t.to(device=device, + non_blocking=True), + prompt_tokens=prompt_t.to(device=device, non_blocking=True), + output_tokens=output_t.to(device=device, non_blocking=True), + ) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py new file mode 100644 index 0000000..d7eec81 --- /dev/null +++ b/vllm/model_executor/utils.py @@ -0,0 +1,31 @@ +"""Utils for model executor.""" +from typing import Any, Dict, Optional + +import torch + +from vllm.utils import seed_everything + + +def set_random_seed(seed: int) -> None: + seed_everything(seed) + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Optional[Dict[str, Any]], +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. + """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key), (f"Overwriting existing tensor attribute: {key}") + setattr(weight, key, value) diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py new file mode 100644 index 0000000..489e1e5 --- /dev/null +++ b/vllm/multimodal/__init__.py @@ -0,0 +1,24 @@ +from .base import (BatchedTensorInputs, MultiModalDataBuiltins, + MultiModalDataDict, MultiModalInputs, MultiModalPlugin, + NestedTensors) +from .registry import MultiModalRegistry + +MULTIMODAL_REGISTRY = MultiModalRegistry() +""" +The global :class:`~MultiModalRegistry` is used by model runners to +dispatch data processing according to its modality and the target model. 
+ +See also: + :ref:`input_processing_pipeline` +""" + +__all__ = [ + "BatchedTensorInputs", + "MultiModalDataBuiltins", + "MultiModalDataDict", + "MultiModalInputs", + "MultiModalPlugin", + "NestedTensors", + "MULTIMODAL_REGISTRY", + "MultiModalRegistry", +] diff --git a/vllm/multimodal/__pycache__/__init__.cpython-310.pyc b/vllm/multimodal/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..6add128 Binary files /dev/null and b/vllm/multimodal/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/multimodal/__pycache__/audio.cpython-310.pyc b/vllm/multimodal/__pycache__/audio.cpython-310.pyc new file mode 100644 index 0000000..8b05e28 Binary files /dev/null and b/vllm/multimodal/__pycache__/audio.cpython-310.pyc differ diff --git a/vllm/multimodal/__pycache__/base.cpython-310.pyc b/vllm/multimodal/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..db52c36 Binary files /dev/null and b/vllm/multimodal/__pycache__/base.cpython-310.pyc differ diff --git a/vllm/multimodal/__pycache__/image.cpython-310.pyc b/vllm/multimodal/__pycache__/image.cpython-310.pyc new file mode 100644 index 0000000..0790982 Binary files /dev/null and b/vllm/multimodal/__pycache__/image.cpython-310.pyc differ diff --git a/vllm/multimodal/__pycache__/registry.cpython-310.pyc b/vllm/multimodal/__pycache__/registry.cpython-310.pyc new file mode 100644 index 0000000..98fe724 Binary files /dev/null and b/vllm/multimodal/__pycache__/registry.cpython-310.pyc differ diff --git a/vllm/multimodal/__pycache__/utils.cpython-310.pyc b/vllm/multimodal/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..9e3b55c Binary files /dev/null and b/vllm/multimodal/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/multimodal/__pycache__/video.cpython-310.pyc b/vllm/multimodal/__pycache__/video.cpython-310.pyc new file mode 100644 index 0000000..6a318a2 Binary files /dev/null and b/vllm/multimodal/__pycache__/video.cpython-310.pyc differ diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py new file mode 100644 index 0000000..04d7182 --- /dev/null +++ b/vllm/multimodal/audio.py @@ -0,0 +1,17 @@ +from vllm.inputs.registry import InputContext +from vllm.multimodal.base import MultiModalInputs, MultiModalPlugin + + +class AudioPlugin(MultiModalPlugin): + """Plugin for audio data.""" + + def get_data_key(self) -> str: + return "audio" + + def _default_input_mapper(self, ctx: InputContext, data: object, + **mm_processor_kwargs) -> MultiModalInputs: + raise NotImplementedError("There is no default audio input mapper") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + raise NotImplementedError( + "There is no default maximum multimodal tokens") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py new file mode 100644 index 0000000..84e71cb --- /dev/null +++ b/vllm/multimodal/base.py @@ -0,0 +1,368 @@ +import sys +from abc import ABC, abstractmethod +from collections import UserDict, defaultdict +from typing import (Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, + TypedDict, TypeVar, Union, cast, final) + +import numpy as np +import torch +import torch.types +from PIL import Image +from torch import nn +from typing_extensions import TypeAlias + +from vllm.config import ModelConfig +from vllm.inputs import InputContext +from vllm.logger import init_logger +from vllm.utils import (JSONTree, get_allowed_kwarg_only_overrides, is_list_of, + json_map_leaves, resolve_mm_processor_kwargs) + +logger = init_logger(__name__) + 
+NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] +""" +Uses a list instead of a tensor if the dimensions of each element do not match. +""" + +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] +""" +A dictionary containing nested tensors which have been batched via +:meth:`MultiModalInputs.batch`. +""" + +if sys.version_info < (3, 9): + # UserDict cannot be subscripted + class _MultiModalInputsBase(UserDict): + pass +else: + + class _MultiModalInputsBase(UserDict[str, NestedTensors]): + pass + + +class MultiModalInputs(_MultiModalInputsBase): + """ + A dictionary that represents the keyword arguments to + :meth:`~torch.nn.Module.forward`. + """ + + @staticmethod + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: + """ + Recursively stacks lists of tensors when they all have the same shape. + """ + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors + + if isinstance(nested_tensors, np.ndarray): + return torch.from_numpy(nested_tensors) + + if isinstance(nested_tensors, (int, float)): + return torch.tensor(nested_tensors) + + stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. + return stacked + + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. + return tensors_ + + return torch.stack(tensors_) + + @staticmethod + def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: + """ + Batch multiple inputs together into a dictionary. + + The resulting dictionary has the same keys as the inputs. + If the corresponding value from each input is a tensor and they all + share the same shape, the output value is a single batched tensor; + otherwise, the output value is a list containing the original value + from each input. + """ + if len(inputs_list) == 0: + return {} + + item_lists: Dict[str, List[NestedTensors]] = defaultdict(list) + + for inputs in inputs_list: + # For models that supports multiple modalities (e.g. Qwen2-VL), + # different modalities will return different data keys, + # so batch() should skip the same key check. + + for k, v in inputs.items(): + item_lists[k].append(v) + + return { + k: MultiModalInputs._try_stack(item_list) + for k, item_list in item_lists.items() + } + + @staticmethod + def as_kwargs( + batched_inputs: BatchedTensorInputs, + *, + device: torch.types.Device, + ) -> BatchedTensorInputs: + json_inputs = cast(JSONTree[torch.Tensor], batched_inputs) + + json_mapped = json_map_leaves( + lambda x: x.to(device, non_blocking=True), + json_inputs, + ) + + return cast(BatchedTensorInputs, json_mapped) + + +_T = TypeVar("_T") + +MultiModalData: TypeAlias = Union[_T, List[_T]] +""" +Either a single data instance, or a list of data instances. + +The number of data instances allowed per modality is restricted by +`--limit-mm-per-prompt`. +""" + + +@final +class MultiModalDataBuiltins(TypedDict, total=False): + """Modality types that are predefined by vLLM.""" + + image: MultiModalData[Image.Image] + """The input image(s).""" + + audio: MultiModalData[Tuple[np.ndarray, Union[int, float]]] + """The input audio item(s) and corresponding sampling rate(s).""" + + +MultiModalDataDict = Union[MultiModalDataBuiltins, + Mapping[str, MultiModalData[object]]] +""" +A dictionary containing an item for each modality type to input. 
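A small sketch of the batching behaviour implemented above (import path as added by this patch):

import torch

from vllm.multimodal.base import MultiModalInputs

a = MultiModalInputs({"pixel_values": torch.zeros(3, 224, 224)})
b = MultiModalInputs({"pixel_values": torch.zeros(3, 224, 224)})
c = MultiModalInputs({"pixel_values": torch.zeros(3, 448, 448)})

# Matching shapes are stacked into one batched tensor ...
print(MultiModalInputs.batch([a, b])["pixel_values"].shape)  # (2, 3, 224, 224)
# ... while mismatched shapes are kept as a plain list of tensors.
print(type(MultiModalInputs.batch([a, c])["pixel_values"]))  # <class 'list'>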
+ +Note: + This dictionary also accepts modality keys defined outside + :class:`MultiModalDataBuiltins` as long as a customized plugin is registered + through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. + Read more on that :ref:`here `. +""" + +MultiModalInputMapper = Callable[[InputContext, MultiModalData[object]], + MultiModalInputs] +""" +Return a dictionary to be passed as keyword arguments to +:meth:`~torch.nn.Module.forward`. This is similar in concept to tokenizers +and processors in HuggingFace Transformers. + +If the data is not supported, throw :exc:`TypeError`. +""" + +MultiModalTokensCalc = Union[int, Callable[[InputContext], int]] +""" +Calculate the maximum number of multimodal tokens input to the language +model. This does not include tokens that correspond to the input text. +""" + +N = TypeVar("N", bound=Type[nn.Module]) + + +class MultiModalPlugin(ABC): + """ + Base class that defines data processing logic for a specific modality. + + In particular, we adopt a registry pattern to dispatch data processing + according to the model being used (considering that different models may + process the same data differently). This registry is in turn used by + :class:`~MultiModalRegistry` which acts at a higher level + (i.e., the modality of the data). + + See also: + :ref:`adding_multimodal_plugin` + """ + + def __init__(self) -> None: + self._input_mappers: Dict[Type[nn.Module], MultiModalInputMapper] = {} + self._max_mm_tokens: Dict[Type[nn.Module], MultiModalTokensCalc] = {} + + @abstractmethod + def get_data_key(self) -> str: + """ + Get the data key corresponding to the modality. + """ + raise NotImplementedError + + @abstractmethod + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[object], + **mm_processor_kwargs, + ) -> MultiModalInputs: + """ + Return a dictionary to be passed as keyword arguments to + :meth:`~torch.nn.Module.forward`. This is similar in concept to + tokenizers and processors in HuggingFace Transformers. + + If the data is not supported, throw :exc:`TypeError`. + """ + raise NotImplementedError + + def register_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper to a model class. + + When the model receives input data that matches the modality served by + this plugin (see :meth:`get_data_key`), the provided function is + invoked to transform the data into a dictionary of model inputs. + + If `None` is provided, then the default input mapper is used instead. + + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_mappers: + logger.warning( + "Model class %s already has an input mapper " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_mappers[model_cls] = mapper \ + or self._default_input_mapper + + return model_cls + + return wrapper + + def map_input(self, model_config: ModelConfig, + data: MultiModalData[object], + mm_processor_kwargs: Dict[str, Any]) -> MultiModalInputs: + """ + Transform the data into a dictionary of model inputs using the + input mapper registered for that model. + + The model is identified by ``model_config``. + + Raises: + TypeError: If the data type is not supported. 
+ + See also: + - :ref:`input_processing_pipeline` + - :ref:`enabling_multimodal_inputs` + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + mapper = self._input_mappers.get(model_cls) + + if mapper is None: + raise KeyError(f"No input mapper in {self} is registered for " + f"model class {model_cls.__name__}.") + + # In the case of the default mapper, we have to get resource + # processor through its HuggingFace autoclass; since this goes + # through **kwargs, we can't inspect it the same way, so we allow + # drop mm_processor_kwargs based on signature inspection + # if we're using the default mapper. + # + # This should be safe in general due to the sanitation, since the + # transformers resource should filter unused kwargs anyway. + uses_default_mapper = mapper == self._default_input_mapper + mm_processor_kwargs = resolve_mm_processor_kwargs( + model_config.mm_processor_kwargs, + mm_processor_kwargs, + callable=mapper, + allow_var_kwargs=uses_default_mapper, + ) + return mapper(InputContext(model_config), data, **mm_processor_kwargs) + + @abstractmethod + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + """ + Calculate the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model. + """ + raise NotImplementedError + + def _validate_max_multimodal_tokens(self, max_mm_tokens: int): + if max_mm_tokens < 1: + raise ValueError("You should set the number of tokens to a " + f"positive integer. Found: {max_mm_tokens}") + + def register_max_multimodal_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of tokens, corresponding to a single + instance of multimodal data, that are passed to the language model + for a model class. + + If `None` is provided, then the default calculation is used instead. + + See also: + :ref:`enabling_multimodal_inputs` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._max_mm_tokens: + logger.warning( + "Model class %s already calculates maximum number of " + "tokens in %s. It is overwritten by the new one.", + model_cls, self) + + if isinstance(max_mm_tokens, int): + self._validate_max_multimodal_tokens(max_mm_tokens) + + self._max_mm_tokens[model_cls] = max_mm_tokens \ + or self._default_max_multimodal_tokens + + return model_cls + + return wrapper + + def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + If this registry is not applicable to the model, `0` is returned. + + The model is identified by ``model_config``. 
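A sketch of how a new modality could hook into the plugin pattern above; the "point_cloud" key and the plugin class are hypothetical, and register_plugin comes from the MultiModalRegistry added later in this patch:

import torch

from vllm.inputs.registry import InputContext
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs,
                             MultiModalPlugin)


class PointCloudPlugin(MultiModalPlugin):
    """Hypothetical plugin for a "point_cloud" modality."""

    def get_data_key(self) -> str:
        return "point_cloud"

    def _default_input_mapper(self, ctx: InputContext, data: object,
                              **mm_processor_kwargs) -> MultiModalInputs:
        # Assume the data is already a tensor the model can consume.
        assert isinstance(data, torch.Tensor)
        return MultiModalInputs({"point_cloud": data})

    def _default_max_multimodal_tokens(self, ctx: InputContext) -> int:
        return 1024


MULTIMODAL_REGISTRY.register_plugin(PointCloudPlugin())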
+ + See also: + :ref:`enabling_multimodal_inputs` + """ + # Avoid circular import + from vllm.model_executor.model_loader import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + if model_cls not in self._input_mappers: + return 0 + + max_mm_tokens = self._max_mm_tokens.get(model_cls) + if max_mm_tokens is None: + raise KeyError(f"No maximum number of multi-modal tokens is given " + f"for model class {model_cls.__name__} in {self}.") + + if callable(max_mm_tokens): + mm_processor_kwargs = get_allowed_kwarg_only_overrides( + max_mm_tokens, overrides=model_config.mm_processor_kwargs) + max_mm_tokens = max_mm_tokens(InputContext(model_config), + **mm_processor_kwargs) + + self._validate_max_multimodal_tokens(max_mm_tokens) + + return max_mm_tokens diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py new file mode 100644 index 0000000..5f74bce --- /dev/null +++ b/vllm/multimodal/image.py @@ -0,0 +1,88 @@ +from functools import lru_cache +from typing import Any, Dict, Optional + +import torch +from PIL import Image +from transformers.image_processing_base import BatchFeature + +from vllm.config import ModelConfig +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.transformers_utils.processor import get_image_processor +from vllm.utils import is_list_of + +from .base import MultiModalData, MultiModalInputs, MultiModalPlugin + +logger = init_logger(__name__) + +cached_get_image_processor = lru_cache(get_image_processor) + + +class ImagePlugin(MultiModalPlugin): + """Plugin for image data.""" + + def get_data_key(self) -> str: + return "image" + + def _get_hf_image_processor( + self, + model_config: ModelConfig, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return cached_get_image_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) + + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[object], + **mm_processor_kwargs, + ) -> MultiModalInputs: + model_config = ctx.model_config + + # Processed by input processor + if isinstance(data, BatchFeature): + return MultiModalInputs(data.data) + + # PIL image + if isinstance(data, Image.Image) or is_list_of(data, Image.Image): + image_processor = self._get_hf_image_processor( + model_config, + mm_processor_kwargs, + ) + + if image_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + try: + # NOTE: It may make sense to forward the mm_processor_kwargs + # here too. For now, to keep it simple, we only allow it be + # used for the initialization call though, just in case the + # signatures of the preprocessor initializer don't match + # preprocess() + batch_data = image_processor \ + .preprocess(data, return_tensors="pt") \ + .data + except Exception: + logger.error( + "Failed to process image (%s) with the default mapper. 
" + "This is most likely an edge-case with this model's image " + "processor in transformers (type: %s), and not vLLM.", + data, + type(image_processor).__name__) + raise + + return MultiModalInputs(batch_data) + + # Image embedding + elif isinstance(data, torch.Tensor) or is_list_of(data, torch.Tensor): + return MultiModalInputs({"image_embeds": data}) + + raise TypeError(f"Invalid image type: {type(data)}") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + return 3000 diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py new file mode 100644 index 0000000..5e9b8bd --- /dev/null +++ b/vllm/multimodal/registry.py @@ -0,0 +1,243 @@ +import functools +from collections import UserDict +from typing import Any, Dict, Mapping, Optional, Sequence + +from vllm.config import ModelConfig +from vllm.logger import init_logger + +from .audio import AudioPlugin +from .base import (MultiModalDataDict, MultiModalInputMapper, MultiModalInputs, + MultiModalPlugin, MultiModalTokensCalc, NestedTensors) +from .image import ImagePlugin +from .video import VideoPlugin + +logger = init_logger(__name__) + + +class _MultiModalLimits(UserDict): + """ + Wraps `_limits_by_model` for a more informative error message + when attempting to access a model that does not exist. + """ + + def __getitem__(self, key: ModelConfig) -> Dict[str, int]: + try: + return super().__getitem__(key) + except KeyError as exc: + msg = (f"Cannot find `mm_limits` for model={key.model}. Did you " + "forget to call `init_mm_limits_per_prompt`?") + raise KeyError(msg) from exc + + +class MultiModalRegistry: + """ + A registry that dispatches data processing to the + :class:`~vllm.multimodal.MultiModalPlugin` for each modality. + """ + + DEFAULT_PLUGINS = (ImagePlugin(), AudioPlugin(), VideoPlugin()) + + def __init__( + self, + *, + plugins: Sequence[MultiModalPlugin] = DEFAULT_PLUGINS) -> None: + self._plugins = {p.get_data_key(): p for p in plugins} + + # This is used for non-multimodal models + self._disabled_limits_per_plugin = {k: 0 for k in self._plugins} + + self._limits_by_model = _MultiModalLimits() + + def register_plugin(self, plugin: MultiModalPlugin) -> None: + """ + Register a multi-modal plugin so it can be recognized by vLLM. + + See also: + :ref:`adding_multimodal_plugin` + """ + data_type_key = plugin.get_data_key() + + if data_type_key in self._plugins: + logger.warning( + "A plugin is already registered for data type %s, " + "and will be overwritten by the new plugin %s.", data_type_key, + plugin) + + self._plugins[data_type_key] = plugin + + def _get_plugin(self, data_type_key: str): + plugin = self._plugins.get(data_type_key) + if plugin is not None: + return plugin + + msg = f"Unknown multi-modal data type: {data_type_key}" + raise NotImplementedError(msg) + + def register_input_mapper( + self, + data_type_key: str, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper for a specific modality to a model class. + + See :meth:`MultiModalPlugin.register_input_mapper` for more details. + """ + return self._get_plugin(data_type_key).register_input_mapper(mapper) + + def register_image_input_mapper( + self, + mapper: Optional[MultiModalInputMapper] = None, + ): + """ + Register an input mapper for image data to a model class. + + See :meth:`MultiModalPlugin.register_input_mapper` for more details. 
+ """ + return self.register_input_mapper("image", mapper) + + def map_input( + self, + model_config: ModelConfig, + data: MultiModalDataDict, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ) -> MultiModalInputs: + """ + Apply an input mapper to the data passed to the model. + + The data belonging to each modality is passed to the corresponding + plugin which in turn converts the data into into keyword arguments + via the input mapper registered for that model. + + See :meth:`MultiModalPlugin.map_input` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + merged_dict: Dict[str, NestedTensors] = {} + + for data_key, data_value in data.items(): + plugin = self._get_plugin(data_key) + + num_items = len(data_value) if isinstance(data_value, list) else 1 + max_items = self._limits_by_model[model_config][data_key] + if num_items > max_items: + raise ValueError( + f"You set {data_key}={max_items} (or defaulted to 1) in " + f"`--limit-mm-per-prompt`, but found {num_items} items " + "in the same prompt.") + + input_dict = plugin.map_input(model_config, data_value, + mm_processor_kwargs) + for input_key, input_tensor in input_dict.items(): + if input_key in merged_dict: + raise ValueError(f"The input mappers (keys={set(data)}) " + f"resulted in a conflicting keyword " + f"argument to `forward()`: {input_key}") + + merged_dict[input_key] = input_tensor + + return MultiModalInputs(merged_dict) + + def create_input_mapper(self, model_config: ModelConfig): + """ + Create an input mapper (see :meth:`map_input`) for a specific model. + """ + # NOTE - we currently make the assumption that if a model has multiple + # supported modalities, they take the same kwargs. For the default, + # this could be an issue in the future if it falls back to two HF + # resources and we can't inspect the signature easily since it's + # getting initialized through the autoclass. + # + # If this is a problem in the future, we should revisit it, but since + # it potentially introduces a lot of complexity for a currently + # uncommon case, we do not for simplicity of both use & implementation + return functools.partial(self.map_input, model_config) + + def register_max_multimodal_tokens( + self, + data_type_key: str, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of tokens, corresponding to a single + instance of multimodal data belonging to a specific modality, that are + passed to the language model for a model class. + """ + return self._get_plugin(data_type_key) \ + .register_max_multimodal_tokens(max_mm_tokens) + + def register_max_image_tokens( + self, + max_mm_tokens: Optional[MultiModalTokensCalc] = None, + ): + """ + Register the maximum number of image tokens, corresponding to a single + image, that are passed to the language model for a model class. + """ + return self.register_max_multimodal_tokens("image", max_mm_tokens) + + def get_max_multimodal_tokens(self, model_config: ModelConfig) -> int: + """ + Get the maximum number of multi-modal tokens + for profiling the memory usage of a model. + + See :meth:`MultiModalPlugin.get_max_multimodal_tokens` for more details. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. 
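A sketch of the per-model registration hooks described above, applied to a hypothetical vision-language model class (the mapper simply passes a preprocessed tensor through, and the 576-token maximum is made up):

from torch import nn

from vllm.inputs.registry import InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs


def my_image_mapper(ctx: InputContext, data: object,
                    **mm_processor_kwargs) -> MultiModalInputs:
    # Hypothetical mapper: assume `data` is already a preprocessed tensor.
    return MultiModalInputs({"pixel_values": data})


@MULTIMODAL_REGISTRY.register_image_input_mapper(my_image_mapper)
@MULTIMODAL_REGISTRY.register_max_image_tokens(576)
class MyVisionLanguageModel(nn.Module):  # hypothetical model class
    ...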
+ """ + limits_per_plugin = self._limits_by_model[model_config] + + return sum((limits_per_plugin[key] * + plugin.get_max_multimodal_tokens(model_config)) + for key, plugin in self._plugins.items()) + + def init_mm_limits_per_prompt( + self, + model_config: ModelConfig, + ) -> None: + """ + Initialize the maximum number of multi-modal input instances for each + modality that are allowed per prompt for a model class. + """ + if model_config in self._limits_by_model: + logger.warning( + "`mm_limits` has already been set for model=%s, and will " + "be overwritten by the new values.", model_config.model) + + multimodal_config = model_config.multimodal_config + if multimodal_config is None: + limits_per_plugin = self._disabled_limits_per_plugin + else: + config_limits_per_plugin = multimodal_config.limit_per_prompt + + extra_keys = config_limits_per_plugin.keys() - self._plugins.keys() + if extra_keys: + logger.warning( + "Detected extra keys in `--limit-mm-per-prompt` which " + "are not registered as multi-modal plugins: %s. " + "They will be ignored.", extra_keys) + + # NOTE: Currently the default is set to 1 for each plugin + # TODO: Automatically determine the limits based on budget + # once more models support multi-image inputs + limits_per_plugin = { + key: config_limits_per_plugin.get(key, 1) + for key in self._plugins + } + + self._limits_by_model[model_config] = limits_per_plugin + + def get_mm_limits_per_prompt( + self, + model_config: ModelConfig, + ) -> Mapping[str, int]: + """ + Get the maximum number of multi-modal input instances for each modality + that are allowed per prompt for a model class. + + Note: + This should be called after :meth:`init_mm_limits_per_prompt`. + """ + return self._limits_by_model[model_config] diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py new file mode 100644 index 0000000..3c80146 --- /dev/null +++ b/vllm/multimodal/utils.py @@ -0,0 +1,323 @@ +import base64 +from functools import lru_cache +from io import BytesIO +from typing import Any, List, Optional, Tuple, TypeVar, Union + +import numpy as np +import numpy.typing as npt +from PIL import Image + +from vllm.connections import global_http_connection +from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT +from vllm.logger import init_logger +from vllm.multimodal.base import MultiModalDataDict +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +logger = init_logger(__name__) + +cached_get_tokenizer = lru_cache(get_tokenizer) + + +def _load_image_from_bytes(b: bytes): + image = Image.open(BytesIO(b)) + image.load() + return image + + +def _load_image_from_data_url(image_url: str): + # Only split once and assume the second part is the base64 encoded image + _, image_base64 = image_url.split(",", 1) + return load_image_from_base64(image_base64) + + +def fetch_image(image_url: str, *, image_mode: str = "RGB") -> Image.Image: + """ + Load a PIL image from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. 
+ """ + if image_url.startswith('http'): + image_raw = global_http_connection.get_bytes( + image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT) + image = _load_image_from_bytes(image_raw) + + elif image_url.startswith('data:image'): + image = _load_image_from_data_url(image_url) + else: + raise ValueError("Invalid 'image_url': A valid 'image_url' must start " + "with either 'data:image' or 'http'.") + + return image.convert(image_mode) + + +async def async_fetch_image(image_url: str, + *, + image_mode: str = "RGB") -> Image.Image: + """ + Asynchronously load a PIL image from a HTTP or base64 data URL. + + By default, the image is converted into RGB format. + """ + if image_url.startswith('http'): + image_raw = await global_http_connection.async_get_bytes( + image_url, timeout=VLLM_IMAGE_FETCH_TIMEOUT) + image = _load_image_from_bytes(image_raw) + + elif image_url.startswith('data:image'): + image = _load_image_from_data_url(image_url) + else: + raise ValueError("Invalid 'image_url': A valid 'image_url' must start " + "with either 'data:image' or 'http'.") + + return image.convert(image_mode) + + +def try_import_audio_packages() -> Tuple[Any, Any]: + try: + import librosa + import soundfile + except ImportError: + raise ImportError( + "Please install vllm[audio] for audio support.") from None + return librosa, soundfile + + +def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: + """ + Load audio from a URL. + """ + librosa, _ = try_import_audio_packages() + + if audio_url.startswith("http"): + audio_bytes = global_http_connection.get_bytes( + audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) + elif audio_url.startswith("data:audio"): + _, audio_base64 = audio_url.split(",", 1) + audio_bytes = base64.b64decode(audio_base64) + else: + raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " + "with either 'data:audio' or 'http'.") + + return librosa.load(BytesIO(audio_bytes), sr=None) + + +async def async_fetch_audio( + audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: + """ + Asynchronously fetch audio from a URL. 
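A small end-to-end sketch of the data-URL path in fetch_image above (the image is built in memory to avoid any network dependency):

import base64
from io import BytesIO

from PIL import Image

from vllm.multimodal.utils import fetch_image

# Build a base64 data URL from an in-memory image.
img = Image.new("RGB", (64, 64), color="red")
buf = BytesIO()
img.save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

loaded = fetch_image(data_url)   # decoded from the base64 payload
print(loaded.size, loaded.mode)  # (64, 64) RGB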
+ """ + librosa, _ = try_import_audio_packages() + + if audio_url.startswith("http"): + audio_bytes = await global_http_connection.async_get_bytes( + audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) + elif audio_url.startswith("data:audio"): + _, audio_base64 = audio_url.split(",", 1) + audio_bytes = base64.b64decode(audio_base64) + else: + raise ValueError("Invalid 'audio_url': A valid 'audio_url' must start " + "with either 'data:audio' or 'http'.") + + return librosa.load(BytesIO(audio_bytes), sr=None) + + +def get_and_parse_audio(audio_url: str) -> MultiModalDataDict: + audio, sr = fetch_audio(audio_url) + return {"audio": (audio, sr)} + + +def get_and_parse_image(image_url: str) -> MultiModalDataDict: + image = fetch_image(image_url) + return {"image": image} + + +async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: + audio, sr = await async_fetch_audio(audio_url) + return {"audio": (audio, sr)} + + +async def async_get_and_parse_image(image_url: str) -> MultiModalDataDict: + image = await async_fetch_image(image_url) + return {"image": image} + + +def encode_audio_base64( + audio: np.ndarray, + sampling_rate: int, +) -> str: + """Encode audio as base64.""" + _, soundfile = try_import_audio_packages() + + buffered = BytesIO() + soundfile.write(buffered, audio, sampling_rate, format="WAV") + + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +def encode_image_base64( + image: Image.Image, + *, + image_mode: str = "RGB", + format: str = "JPEG", +) -> str: + """ + Encode a pillow image to base64 format. + + By default, the image is converted into RGB format before being encoded. + """ + buffered = BytesIO() + image = image.convert(image_mode) + image.save(buffered, format) + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """Load image from base64 format.""" + return _load_image_from_bytes(base64.b64decode(image)) + + +def rescale_image_size(image: Image.Image, + size_factor: float, + transpose: int = -1) -> Image.Image: + """Rescale the dimensions of an image by a constant factor.""" + new_width = int(image.width * size_factor) + new_height = int(image.height * size_factor) + image = image.resize((new_width, new_height)) + if transpose >= 0: + image = image.transpose(Image.Transpose(transpose)) + return image + + +def try_import_video_packages() -> Any: + try: + import cv2 + except ImportError: + raise ImportError( + "Please install vllm[video] for video support.") from None + return cv2 + + +def resize_video(frames: npt.NDArray, size: Tuple[int, int]) -> npt.NDArray: + cv2 = try_import_video_packages() + + num_frames, _, _, channels = frames.shape + new_height, new_width = size + resized_frames = np.empty((num_frames, new_height, new_width, channels), + dtype=frames.dtype) + for i, frame in enumerate(frames): + resized_frame = cv2.resize(frame, (new_width, new_height)) + resized_frames[i] = resized_frame + return resized_frames + + +def rescale_video_size(frames: npt.NDArray, size_factor: float) -> npt.NDArray: + _, height, width, _ = frames.shape + new_height = int(height * size_factor) + new_width = int(width * size_factor) + + return resize_video(frames, (new_height, new_width)) + + +def sample_frames_from_video(frames: npt.NDArray, + num_frames: int) -> npt.NDArray: + total_frames = frames.shape[0] + if num_frames == -1: + return frames + else: + frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) + sampled_frames = frames[frame_indices, 
...] + return sampled_frames + + +# Utilities for input processors +_T = TypeVar("_T", str, int) + + +def repeat_and_pad_token( + token: _T, + *, + repeat_count: int = 1, + pad_token_left: Optional[_T] = None, + pad_token_right: Optional[_T] = None, +) -> List[_T]: + replacement = [token] * repeat_count + if pad_token_left is not None: + replacement = [pad_token_left] + replacement + if pad_token_right is not None: + replacement = replacement + [pad_token_right] + + return replacement + + +def repeat_and_pad_placeholder_tokens( + tokenizer: AnyTokenizer, + prompt: Optional[str], + prompt_token_ids: List[int], + *, + placeholder_token_id: int, + repeat_count: Union[int, List[int]], + pad_token_left: Optional[int] = None, + pad_token_right: Optional[int] = None, +) -> Tuple[Optional[str], List[int]]: + if isinstance(repeat_count, int): + repeat_count = [repeat_count] + + if prompt is None: + new_prompt = None + else: + placeholder_token_str = tokenizer.decode(placeholder_token_id) + pad_token_str_left = (None if pad_token_left is None else + tokenizer.decode(pad_token_left)) + pad_token_str_right = (None if pad_token_right is None else + tokenizer.decode(pad_token_right)) + + placeholder_token_count = prompt.count(placeholder_token_str) + # This is an arbitrary number to distinguish between the two cases + if placeholder_token_count > 16: + logger.warning( + "Please follow the prompt format that is " + "documented on HuggingFace which does not involve " + "repeating %s tokens.", placeholder_token_str) + if placeholder_token_count < len(repeat_count): + logger.warning( + "The number of multi-modal placeholder tokens in the prompt " + "is less than the number of multi-modal inputs. Extra " + "placeholder tokens will be treated as plain text") + repeat_count = repeat_count[:placeholder_token_count] + + prompt_parts = prompt.split(placeholder_token_str, + maxsplit=len(repeat_count)) + new_prompt = "" + for i, repeat_count_item in enumerate(repeat_count): + replacement_str = "".join( + repeat_and_pad_token( + placeholder_token_str, + repeat_count=repeat_count_item, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + # The image tokens are removed to be consistent with HuggingFace + new_prompt += prompt_parts[i] + replacement_str + new_prompt += prompt_parts[-1] + + new_token_ids: List[int] = [] + placeholder_token_idx = 0 + for i, token in enumerate(prompt_token_ids): + if token == placeholder_token_id: + replacement_ids = repeat_and_pad_token( + placeholder_token_id, + repeat_count=repeat_count[placeholder_token_idx], + pad_token_left=pad_token_left, + pad_token_right=pad_token_right, + ) + new_token_ids.extend(replacement_ids) + placeholder_token_idx += 1 + + # No need to further scan the list since we replaced all tokens + if placeholder_token_idx >= len(repeat_count): + new_token_ids.extend(prompt_token_ids[i + 1:]) + break + else: + new_token_ids.append(token) + + return new_prompt, new_token_ids diff --git a/vllm/multimodal/video.py b/vllm/multimodal/video.py new file mode 100644 index 0000000..4a9dbf2 --- /dev/null +++ b/vllm/multimodal/video.py @@ -0,0 +1,86 @@ +from functools import lru_cache +from typing import Any, Dict, List, Optional, Union + +import numpy as np + +from vllm.config import ModelConfig +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.transformers_utils.processor import get_video_processor +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.utils import is_list_of + +from 
.base import MultiModalData, MultiModalInputs +from .image import ImagePlugin + +logger = init_logger(__name__) + +cached_get_video_processor = lru_cache(get_video_processor) +cached_get_tokenizer = lru_cache(get_tokenizer) + +VideoInput = Union[ + "np.ndarray", # single video input + List["np.ndarray"], + # TODO: support more types + # List[Image.Image], List[List[Image.Image]], + # "torch.Tensor", + # List["torch.Tensor"], + # List[List["np.ndarrray"]], + # List[List["torch.Tensor"]], +] + + +class VideoPlugin(ImagePlugin): + """Plugin for video data.""" + + def get_data_key(self) -> str: + return "video" + + def _get_hf_video_processor( + self, + model_config: ModelConfig, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + ): + if mm_processor_kwargs is None: + mm_processor_kwargs = {} + return cached_get_video_processor( + model_config.model, + trust_remote_code=model_config.trust_remote_code, + **mm_processor_kwargs) + + def _default_input_mapper( + self, + ctx: InputContext, + data: MultiModalData[object], + **mm_processor_kwargs, + ) -> MultiModalInputs: + model_config = ctx.model_config + + # single video input as np.ndarray + if isinstance(data, np.ndarray): + video_processor = self._get_hf_video_processor( + model_config, + mm_processor_kwargs, + ) + if video_processor is None: + raise RuntimeError("No HuggingFace processor is available " + "to process the image object") + try: + # NOTE: Similar to image; it may be a good idea to filter and + # pass mm_processor_kwargs here too, but for now we don't to + # avoid extra complexity if the initializer and preprocess + # signatures of the processor don't align + batch_data = video_processor(data, return_tensors="pt").data + except Exception: + logger.error("Failed to process image (%s)", data) + raise + + return MultiModalInputs(batch_data) + elif is_list_of(data, np.ndarray): + raise NotImplementedError( + "Multi video for a prompt is not supported yet") + + raise TypeError(f"Invalid video type: {type(data)}") + + def _default_max_multimodal_tokens(self, ctx: InputContext) -> int: + return 4096 diff --git a/vllm/outputs.py b/vllm/outputs.py new file mode 100644 index 0000000..0765024 --- /dev/null +++ b/vllm/outputs.py @@ -0,0 +1,318 @@ +import time +from dataclasses import dataclass +from typing import List, Optional +from typing import Sequence as GenericSequence +from typing import Union + +from vllm.lora.request import LoRARequest +from vllm.sampling_params import RequestOutputKind +from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, + SequenceGroup, SequenceStatus) + + +@dataclass +class CompletionOutput: + """The output data of one completion output of a request. + + Args: + index: The index of the output in the request. + text: The generated output text. + token_ids: The token IDs of the generated output text. + cumulative_logprob: The cumulative log probability of the generated + output text. + logprobs: The log probabilities of the top probability words at each + position if the logprobs are requested. + finish_reason: The reason why the sequence is finished. + stop_reason: The stop string or token id that caused the completion + to stop, None if the completion finished for some other reason + including encountering the EOS token. + lora_request: The LoRA request that was used to generate the output. 
+ """ + + index: int + text: str + token_ids: GenericSequence[int] + cumulative_logprob: Optional[float] + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None + + def finished(self) -> bool: + return self.finish_reason is not None + + def __repr__(self) -> str: + return (f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"stop_reason={self.stop_reason})") + + +@dataclass +class EmbeddingOutput: + """The output data of one completion output of a request. + + Args: + embedding: The embedding vector, which is a list of floats. The + length of vector depends on the model as listed in the embedding guide. + """ + + embedding: List[float] + + def __repr__(self) -> str: + return (f"EmbeddingOutput(" + f"embedding={len(self.embedding)})") + + +class RequestOutput: + """The output data of a completion request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + For encoder/decoder models, this is the + decoder input prompt. + prompt_token_ids: The token IDs of the prompt. + For encoder/decoder models, this is the + decoder input prompt token ids. + prompt_logprobs: The log probabilities to return per prompt token. + outputs: The output sequences of the request. + finished: Whether the whole request is finished. + metrics: Metrics associated with the request. + lora_request: The LoRA request that was used to generate the output. + encoder_prompt: The encoder prompt string of the request; + None if decoder-only + encoder_prompt_token_ids: The token IDs of the encoder prompt; + None if decoder-only + """ + + def __init__( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]], + prompt_logprobs: Optional[PromptLogprobs], + outputs: List[CompletionOutput], + finished: bool, + metrics: Optional[RequestMetrics] = None, + lora_request: Optional[LoRARequest] = None, + encoder_prompt: Optional[str] = None, + encoder_prompt_token_ids: Optional[List[int]] = None, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.prompt_logprobs = prompt_logprobs + self.outputs = outputs + self.finished = finished + self.metrics = metrics + self.lora_request = lora_request + self.encoder_prompt = encoder_prompt + self.encoder_prompt_token_ids = encoder_prompt_token_ids + + @classmethod + def from_seq_group(cls, seq_group: SequenceGroup, + use_cache: bool) -> Optional["RequestOutput"]: + sampling_params = seq_group.sampling_params + if sampling_params is None: + raise ValueError( + "Sampling parameters are missing for a CompletionRequest.") + + finished = seq_group.is_finished() + if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( + not finished): + return None + + # Init cache (if needed) + if use_cache and seq_group.cached_request_output is None: + seq_group.cached_request_output = RequestOutput( # type: ignore + request_id="", + prompt=None, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[], + finished=False) + + seqs = seq_group.get_seqs() + if len(seqs) == 1: + top_n_seqs = seqs + else: + # Get the top-n sequences. 
+ n = sampling_params._real_n or sampling_params.n + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + + # Create the outputs. + # NOTE: We need omit logprobs here explicitly because the sequence + # always has the logprobs of the sampled tokens even if the + # logprobs are not requested. + include_logprobs = sampling_params.logprobs is not None + text_buffer_length = sampling_params.output_text_buffer_length + delta = sampling_params.output_kind == RequestOutputKind.DELTA + + outputs = [] + include_prompt = True + for i, seq in enumerate(top_n_seqs): + output_text = seq.get_output_text_to_return( + text_buffer_length, delta) + + output_token_ids = seq.get_output_token_ids_to_return(delta) + num_output_tokens = 1 if isinstance(output_token_ids, + int) else len(output_token_ids) + + output_logprobs = seq.output_logprobs if include_logprobs else None + + if delta: + # Slice logprobs delta if applicable + if output_logprobs: + output_logprobs = output_logprobs[-num_output_tokens:] + # Don't include prompt if this is after the first output + # containing decode token ids + if include_prompt and seq.get_output_len() > num_output_tokens: + include_prompt = False + + if use_cache: + # Get cached output object + cached_outputs = seq_group.cached_request_output.outputs # type: ignore + if i >= len(cached_outputs): + cached_outputs.append( + CompletionOutput(index=i, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + stop_reason=None)) + output = cached_outputs[i] + + # Init cached output object + assert output.index == i + output.text = output_text + + if isinstance(output_token_ids, int): + output.token_ids.clear() + output.token_ids.append(output_token_ids) + else: + output.token_ids = output_token_ids + + output.cumulative_logprob = seq.get_cumulative_logprob() \ + if include_logprobs else None + output.logprobs = output_logprobs + output.finish_reason = SequenceStatus.get_finished_reason( + seq.status) + output.stop_reason = seq.stop_reason + + else: + output = CompletionOutput( + seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, + seq.get_cumulative_logprob() if include_logprobs else None, + output_logprobs, + SequenceStatus.get_finished_reason(seq.status), + seq.stop_reason) + + outputs.append(output) + + # Every sequence in the sequence group should have the same prompt. 
+ if include_prompt: + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + encoder_prompt = seq_group.encoder_prompt + encoder_prompt_token_ids = seq_group.encoder_prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + else: + prompt = None + prompt_token_ids = None + encoder_prompt = None + encoder_prompt_token_ids = None + prompt_logprobs = None + finished_time = time.time() if finished else None + seq_group.set_finished_time(finished_time) + + init_args = (seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.lora_request, encoder_prompt, + encoder_prompt_token_ids) + + if use_cache: + request_output = seq_group.cached_request_output + request_output.__init__(*init_args) # type: ignore + + else: + request_output = cls(*init_args) + + return request_output + + def __repr__(self) -> str: + return (f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"encoder_prompt={self.encoder_prompt!r}, " + f"encoder_prompt_token_ids={self.encoder_prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished}, " + f"metrics={self.metrics}, " + f"lora_request={self.lora_request})") + + +class EmbeddingRequestOutput: + """ + The output data of an embedding request to the LLM. + + Args: + request_id (str): A unique identifier for the embedding request. + outputs (EmbeddingOutput): The embedding results for the given input. + prompt_token_ids (List[int]): A list of token IDs used in the prompt. + finished (bool): A flag indicating whether the embedding is completed. + """ + + def __init__(self, request_id: str, outputs: "EmbeddingOutput", + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + @classmethod + def from_seq_group(cls, + seq_group: 'SequenceGroup') -> "EmbeddingRequestOutput": + if seq_group.embeddings is None: + raise ValueError( + "Embeddings are missing in seq_group for EmbeddingRequest.") + output = EmbeddingOutput(seq_group.embeddings) + prompt_token_ids = seq_group.prompt_token_ids + finished = seq_group.is_finished() + + return cls(seq_group.request_id, output, prompt_token_ids, finished) + + def __repr__(self): + """ + Returns a string representation of an EmbeddingRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the EmbeddingRequestOutput instance. 
+ """ + return (f"EmbeddingRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") + + +class RequestOutputFactory: + + @staticmethod + def create(seq_group: SequenceGroup, use_cache: bool = False): + # Determine the type based on a condition, for example: + if hasattr(seq_group, + 'embeddings') and seq_group.embeddings is not None: + return EmbeddingRequestOutput.from_seq_group(seq_group) + else: + return RequestOutput.from_seq_group(seq_group, use_cache) diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py new file mode 100644 index 0000000..65f02ba --- /dev/null +++ b/vllm/platforms/__init__.py @@ -0,0 +1,81 @@ +from .interface import Platform, PlatformEnum, UnspecifiedPlatform + +current_platform: Platform + +# NOTE: we don't use `torch.version.cuda` / `torch.version.hip` because +# they only indicate the build configuration, not the runtime environment. +# For example, people can install a cuda build of pytorch but run on tpu. + +is_tpu = False +try: + # While it's technically possible to install libtpu on a non-TPU machine, + # this is a very uncommon scenario. Therefore, we assume that libtpu is + # installed if and only if the machine has TPUs. + import libtpu # noqa: F401 + is_tpu = True +except Exception: + pass + +is_cuda = False + +try: + import pynvml + pynvml.nvmlInit() + try: + if pynvml.nvmlDeviceGetCount() > 0: + is_cuda = True + finally: + pynvml.nvmlShutdown() +except Exception: + is_cuda = True + +is_rocm = False + +try: + import amdsmi + amdsmi.amdsmi_init() + try: + if len(amdsmi.amdsmi_get_processor_handles()) > 0: + is_rocm = True + finally: + amdsmi.amdsmi_shut_down() +except Exception: + pass + +is_xpu = False + +try: + import torch + if hasattr(torch, 'xpu') and torch.xpu.is_available(): + is_xpu = True +except Exception: + pass + +is_cpu = False +try: + from importlib.metadata import version + is_cpu = "cpu" in version("vllm") +except Exception: + pass + +if is_tpu: + # people might install pytorch built with cuda but run on tpu + # so we need to check tpu first + from .tpu import TpuPlatform + current_platform = TpuPlatform() +elif is_cuda: + from .cuda import CudaPlatform + current_platform = CudaPlatform() +elif is_rocm: + from .rocm import RocmPlatform + current_platform = RocmPlatform() +elif is_xpu: + from .xpu import XPUPlatform + current_platform = XPUPlatform() +elif is_cpu: + from .cpu import CpuPlatform + current_platform = CpuPlatform() +else: + current_platform = UnspecifiedPlatform() + +__all__ = ['Platform', 'PlatformEnum', 'current_platform'] diff --git a/vllm/platforms/__pycache__/__init__.cpython-310.pyc b/vllm/platforms/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..07fb96c Binary files /dev/null and b/vllm/platforms/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/platforms/__pycache__/cpu.cpython-310.pyc b/vllm/platforms/__pycache__/cpu.cpython-310.pyc new file mode 100644 index 0000000..0dcd79b Binary files /dev/null and b/vllm/platforms/__pycache__/cpu.cpython-310.pyc differ diff --git a/vllm/platforms/__pycache__/cuda.cpython-310.pyc b/vllm/platforms/__pycache__/cuda.cpython-310.pyc new file mode 100644 index 0000000..559a358 Binary files /dev/null and b/vllm/platforms/__pycache__/cuda.cpython-310.pyc differ diff --git a/vllm/platforms/__pycache__/interface.cpython-310.pyc b/vllm/platforms/__pycache__/interface.cpython-310.pyc new file mode 100644 index 0000000..9e971de Binary 
files /dev/null and b/vllm/platforms/__pycache__/interface.cpython-310.pyc differ diff --git a/vllm/platforms/__pycache__/rocm.cpython-310.pyc b/vllm/platforms/__pycache__/rocm.cpython-310.pyc new file mode 100644 index 0000000..33cda23 Binary files /dev/null and b/vllm/platforms/__pycache__/rocm.cpython-310.pyc differ diff --git a/vllm/platforms/__pycache__/tpu.cpython-310.pyc b/vllm/platforms/__pycache__/tpu.cpython-310.pyc new file mode 100644 index 0000000..84d31fe Binary files /dev/null and b/vllm/platforms/__pycache__/tpu.cpython-310.pyc differ diff --git a/vllm/platforms/__pycache__/xpu.cpython-310.pyc b/vllm/platforms/__pycache__/xpu.cpython-310.pyc new file mode 100644 index 0000000..87134d7 Binary files /dev/null and b/vllm/platforms/__pycache__/xpu.cpython-310.pyc differ diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py new file mode 100644 index 0000000..5243f59 --- /dev/null +++ b/vllm/platforms/cpu.py @@ -0,0 +1,20 @@ +import psutil +import torch + +from .interface import Platform, PlatformEnum + + +class CpuPlatform(Platform): + _enum = PlatformEnum.CPU + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return "cpu" + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + return psutil.virtual_memory().total + + @classmethod + def inference_mode(cls): + return torch.no_grad() diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py new file mode 100644 index 0000000..991ac38 --- /dev/null +++ b/vllm/platforms/cuda.py @@ -0,0 +1,145 @@ +"""Code inside this file can safely assume cuda platform, e.g. importing +pynvml. However, it should not initialize cuda context. +""" + +import os +from functools import lru_cache, wraps +from typing import Callable, List, Tuple, TypeVar + +import pynvml +from typing_extensions import ParamSpec + +from vllm.logger import init_logger + +from .interface import DeviceCapability, Platform, PlatformEnum + +logger = init_logger(__name__) + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +if pynvml.__file__.endswith("__init__.py"): + logger.warning( + "You are using a deprecated `pynvml` package. Please install" + " `nvidia-ml-py` instead, and make sure to uninstall `pynvml`." + " When both of them are installed, `pynvml` will take precedence" + " and cause errors. See https://pypi.org/project/pynvml " + "for more information.") + +# NVML utils +# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. 
+# the major benefit of using NVML is that it will not initialize CUDA + + +def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]: + + @wraps(fn) + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: + pynvml.nvmlInit() + try: + return fn(*args, **kwargs) + finally: + pynvml.nvmlShutdown() + + return wrapper + + +@lru_cache(maxsize=8) +def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]: + return [9, 0] + + +@lru_cache(maxsize=8) +def get_physical_device_name(device_id: int = 0) -> str: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + return pynvml.nvmlDeviceGetName(handle) + + +@lru_cache(maxsize=8) +# @with_nvml_context +def get_physical_device_total_memory(device_id: int = 0) -> int: + handle = pynvml.nvmlDeviceGetHandleByIndex(device_id) + return int(pynvml.nvmlDeviceGetMemoryInfo(handle).total) + + +# @with_nvml_context +def warn_if_different_devices(): + device_ids: int = pynvml.nvmlDeviceGetCount() + if device_ids > 1: + device_names = [get_physical_device_name(i) for i in range(device_ids)] + if len(set(device_names)) > 1 and os.environ.get( + "CUDA_DEVICE_ORDER") != "PCI_BUS_ID": + logger.warning( + "Detected different devices in the system: \n%s\nPlease" + " make sure to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` to " + "avoid unexpected behavior.", "\n".join(device_names)) + + +try: + from sphinx.ext.autodoc.mock import _MockModule + + if not isinstance(pynvml, _MockModule): + warn_if_different_devices() +except ModuleNotFoundError: + pass + # warn_if_different_devices() + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string," + " which means GPU support is disabled.") + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + + +class CudaPlatform(Platform): + _enum = PlatformEnum.CUDA + + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + physical_device_id = device_id_to_physical_device_id(device_id) + major, minor = get_physical_device_capability(physical_device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + physical_device_id = device_id_to_physical_device_id(device_id) + return get_physical_device_name(physical_device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + return 32 * 1024 * 1024 * 1024 + # physical_device_id = device_id_to_physical_device_id(device_id) + # return get_physical_device_total_memory(physical_device_id) + + @classmethod + # @with_nvml_context + def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: + """ + query if the set of gpus are fully connected by nvlink (1 hop) + """ + handles = [ + pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError as error: + logger.error( + "NVLink detection failed. 
This is normal if your" + " machine has no NVLink equipped.", + exc_info=error) + return False + return True diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py new file mode 100644 index 0000000..00742a2 --- /dev/null +++ b/vllm/platforms/interface.py @@ -0,0 +1,108 @@ +import enum +from typing import NamedTuple, Optional, Tuple, Union + +import torch + + +class PlatformEnum(enum.Enum): + CUDA = enum.auto() + ROCM = enum.auto() + TPU = enum.auto() + XPU = enum.auto() + CPU = enum.auto() + UNSPECIFIED = enum.auto() + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform: + _enum: PlatformEnum + + def is_cuda(self) -> bool: + return self._enum == PlatformEnum.CUDA + + def is_rocm(self) -> bool: + return self._enum == PlatformEnum.ROCM + + def is_tpu(self) -> bool: + return self._enum == PlatformEnum.TPU + + def is_xpu(self) -> bool: + return self._enum == PlatformEnum.XPU + + def is_cpu(self) -> bool: + return self._enum == PlatformEnum.CPU + + def is_cuda_alike(self) -> bool: + """Stateless version of :func:`torch.cuda.is_available`.""" + return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) + + @classmethod + def get_device_capability( + cls, + device_id: int = 0, + ) -> Optional[DeviceCapability]: + """Stateless version of :func:`torch.cuda.get_device_capability`.""" + return None + + @classmethod + def has_device_capability( + cls, + capability: Union[Tuple[int, int], int], + device_id: int = 0, + ) -> bool: + """ + Test whether this platform is compatible with a device capability. + + The ``capability`` argument can either be: + + - A tuple ``(major, minor)``. + - An integer ````. (See :meth:`DeviceCapability.to_int`) + """ + current_capability = cls.get_device_capability(device_id=device_id) + if current_capability is None: + return False + + if isinstance(capability, tuple): + return current_capability >= capability + + return current_capability.to_int() >= capability + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + """Get the name of a device.""" + raise NotImplementedError + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + """Get the total memory of a device in bytes.""" + raise NotImplementedError + + @classmethod + def inference_mode(cls): + """A device-specific wrapper of `torch.inference_mode`. + + This wrapper is recommended because some hardware backends such as TPU + do not support `torch.inference_mode`. In such a case, they will fall + back to `torch.no_grad` by overriding this method. + """ + return torch.inference_mode(mode=True) + + +class UnspecifiedPlatform(Platform): + _enum = PlatformEnum.UNSPECIFIED diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py new file mode 100644 index 0000000..fd8afc9 --- /dev/null +++ b/vllm/platforms/rocm.py @@ -0,0 +1,36 @@ +import os +from functools import lru_cache + +import torch + +from vllm.logger import init_logger + +from .interface import DeviceCapability, Platform, PlatformEnum + +logger = init_logger(__name__) + +if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: + logger.warning("`fork` method is not supported by ROCm. 
" + "VLLM_WORKER_MULTIPROC_METHOD is overridden to" + " `spawn` instead.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +class RocmPlatform(Platform): + _enum = PlatformEnum.ROCM + + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.cuda.get_device_properties(device_id) + return device_props.total_memory diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py new file mode 100644 index 0000000..8ba973b --- /dev/null +++ b/vllm/platforms/tpu.py @@ -0,0 +1,33 @@ +import os + +import torch + +import vllm.envs as envs +from vllm.compilation.levels import CompilationLevel +from vllm.plugins import set_torch_compile_backend + +from .interface import Platform, PlatformEnum + +if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ: + os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE) + +assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\ + "TPU does not support Inductor." + +set_torch_compile_backend("openxla") + + +class TpuPlatform(Platform): + _enum = PlatformEnum.TPU + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + raise NotImplementedError + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + raise NotImplementedError + + @classmethod + def inference_mode(cls): + return torch.no_grad() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py new file mode 100644 index 0000000..d00e0dc --- /dev/null +++ b/vllm/platforms/xpu.py @@ -0,0 +1,22 @@ +import torch + +from .interface import DeviceCapability, Platform, PlatformEnum + + +class XPUPlatform(Platform): + _enum = PlatformEnum.XPU + + @staticmethod + def get_device_capability(device_id: int = 0) -> DeviceCapability: + major, minor, *_ = torch.xpu.get_device_capability( + device_id)['version'].split('.') + return DeviceCapability(major=int(major), minor=int(minor)) + + @staticmethod + def get_device_name(device_id: int = 0) -> str: + return torch.xpu.get_device_name(device_id) + + @classmethod + def get_device_total_memory(cls, device_id: int = 0) -> int: + device_props = torch.xpu.get_device_properties(device_id) + return device_props.total_memory diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py new file mode 100644 index 0000000..211fedb --- /dev/null +++ b/vllm/plugins/__init__.py @@ -0,0 +1,56 @@ +import logging +from typing import Callable, Dict, Optional, Union + +import vllm.envs as envs + +logger = logging.getLogger(__name__) + + +def load_general_plugins(): + """WARNING: plugins can be loaded for multiple times in different + processes. They should be designed in a way that they can be loaded + multiple times without causing issues. 
+ """ + import sys + if sys.version_info < (3, 10): + from importlib_metadata import entry_points + else: + from importlib.metadata import entry_points + + allowed_plugins = envs.VLLM_PLUGINS + + discovered_plugins = entry_points(group='vllm.general_plugins') + for plugin in discovered_plugins: + logger.info("Found general plugin: %s", plugin.name) + if allowed_plugins is None or plugin.name in allowed_plugins: + try: + func = plugin.load() + func() + logger.info("Loaded general plugin: %s", plugin.name) + except Exception: + logger.exception("Failed to load general plugin: %s", + plugin.name) + + +_torch_compile_backend: Optional[Union[Callable, str]] = None + + +def set_torch_compile_backend(backend: Union[Callable, str]): + global _torch_compile_backend + _torch_compile_backend = backend + + +def get_torch_compile_backend() -> Optional[Union[Callable, str]]: + return _torch_compile_backend + + +_inductor_additional_configs: Dict = {} + + +def set_inductor_additional_configs(configs: Dict): + global _inductor_additional_configs + _inductor_additional_configs = configs + + +def get_inductor_additional_configs() -> Dict: + return _inductor_additional_configs diff --git a/vllm/plugins/__pycache__/__init__.cpython-310.pyc b/vllm/plugins/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..fb5aa1d Binary files /dev/null and b/vllm/plugins/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/pooling_params.py b/vllm/pooling_params.py new file mode 100644 index 0000000..7461fb5 --- /dev/null +++ b/vllm/pooling_params.py @@ -0,0 +1,23 @@ +from typing import Any, Optional + +import msgspec + + +class PoolingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """Pooling parameters for pooling. + + Attributes: + additional_data: Any additional data needed for pooling. 
+ """ + additional_data: Optional[Any] = None + + def clone(self) -> "PoolingParams": + """Returns a deep copy of the PoolingParams instance.""" + return PoolingParams(additional_data=self.additional_data, ) + + def __repr__(self) -> str: + return (f"PoolingParams(" + f"additional_metadata={self.additional_data})") diff --git a/vllm/prompt_adapter/__init__.py b/vllm/prompt_adapter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/prompt_adapter/__pycache__/__init__.cpython-310.pyc b/vllm/prompt_adapter/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..b77191e Binary files /dev/null and b/vllm/prompt_adapter/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/prompt_adapter/__pycache__/layers.cpython-310.pyc b/vllm/prompt_adapter/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000..abe0d5f Binary files /dev/null and b/vllm/prompt_adapter/__pycache__/layers.cpython-310.pyc differ diff --git a/vllm/prompt_adapter/__pycache__/models.cpython-310.pyc b/vllm/prompt_adapter/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000..fb21ff6 Binary files /dev/null and b/vllm/prompt_adapter/__pycache__/models.cpython-310.pyc differ diff --git a/vllm/prompt_adapter/__pycache__/request.cpython-310.pyc b/vllm/prompt_adapter/__pycache__/request.cpython-310.pyc new file mode 100644 index 0000000..4cbd187 Binary files /dev/null and b/vllm/prompt_adapter/__pycache__/request.cpython-310.pyc differ diff --git a/vllm/prompt_adapter/__pycache__/utils.cpython-310.pyc b/vllm/prompt_adapter/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..a143fd4 Binary files /dev/null and b/vllm/prompt_adapter/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/prompt_adapter/__pycache__/worker_manager.cpython-310.pyc b/vllm/prompt_adapter/__pycache__/worker_manager.cpython-310.pyc new file mode 100644 index 0000000..3cc1bba Binary files /dev/null and b/vllm/prompt_adapter/__pycache__/worker_manager.cpython-310.pyc differ diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py new file mode 100644 index 0000000..27a61e6 --- /dev/null +++ b/vllm/prompt_adapter/layers.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import PromptAdapterConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + + +@dataclass +class PromptAdapterMapping(AdapterMapping): + pass + + +class VocabParallelEmbeddingWithPromptAdapter(nn.Module): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.emb_layer = self.base_layer + if 'LoRA' in base_layer.__class__.__name__: + self.emb_layer = self.base_layer.base_layer + + def create_prompt_adapter_weights( + self, prompt_adapter_config: PromptAdapterConfig): + self.embeddings_tensors = torch.zeros( + ( + prompt_adapter_config.max_prompt_adapters, + prompt_adapter_config.max_prompt_adapter_token, + self.emb_layer.embedding_dim, + ), + dtype=self.emb_layer.weight.dtype, + device=self.emb_layer.weight.device, + ) + self.adapter_lengths = torch.zeros( + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, + device=self.emb_layer.weight.device) + + self.indices_gpu: torch.Tensor + self.embedding_indices_gpu: torch.Tensor + + def reset_prompt_adapter(self, index: int): + self.embeddings_tensors[index] = 0 + + def 
set_prompt_adapter( + self, + index: int, + adapter_model: Optional[torch.Tensor], + ): + self.reset_prompt_adapter(index) + if adapter_model is not None: + length = adapter_model.shape[0] + self.embeddings_tensors[index, :length] = adapter_model + self.adapter_lengths[index] = length + + def set_mapping( + self, + prompt_indices: torch.Tensor, + prompt_embedding_indices: torch.Tensor, + ): + self.indices_gpu = prompt_indices.to( + device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to( + device=self.emb_layer.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = self.base_layer(x) + if self.embedding_indices_gpu.ndim > 1: + valid_mask = self.indices_gpu != -1 + gathered_embeddings = self.embeddings_tensors[ + self.embedding_indices_gpu[:, 0], + self.embedding_indices_gpu[:, 1]] + + # Update hidden states + hidden_states[valid_mask] = gathered_embeddings + return hidden_states \ No newline at end of file diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py new file mode 100644 index 0000000..18a5f86 --- /dev/null +++ b/vllm/prompt_adapter/models.py @@ -0,0 +1,355 @@ +import logging +import math +from typing import Any, Callable, Dict, List, Optional, Type + +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.layers import ( + VocabParallelEmbeddingWithPromptAdapter) # yapf: disable +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.utils import load_peft_weights + +logger = logging.getLogger(__name__) + +_GLOBAL_PROMPT_ADAPTER_ID = 0 + + +def get_prompt_adapter_id(): + global _GLOBAL_PROMPT_ADAPTER_ID + _GLOBAL_PROMPT_ADAPTER_ID += 1 + return _GLOBAL_PROMPT_ADAPTER_ID + + +def convert_to_embedding_indices(indices): + embedding_indices = [] + count = 0 + + for value in indices: + if value == -1: + count = 0 + else: + embedding_indices.append([value, count]) + count += 1 + + return torch.tensor(embedding_indices) + + +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], +) -> torch.Tensor: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a + batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter + ids to PromptAdapter indices. + + Returns: + pa_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. 
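+
+    Illustrative example (values chosen purely for illustration): with
+    prompt_adapter_index_to_id = [7, None] and index_mapping (0, 7, 7, 0),
+    id 0 (no adapter) maps to -1 and id 7 maps to slot 0, giving
+    pa_indices = [-1, 0, 0, -1] and embedding indices [[0, 0], [0, 1]]
+    (slot index, position within that adapter's run). Note that the
+    function actually returns a tuple (pa_indices, pa_embedding_mapping).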
+ """ + id_to_index = { + id_: idx + for idx, id_ in enumerate(prompt_adapter_index_to_id) + if id_ is not None + } + pa_indices = ([ + id_to_index.get(id_, -1) if id_ > 0 else -1 + for id_ in mapping.index_mapping + ]) + + pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + pa_indices = torch.tensor(pa_indices) + return pa_indices, pa_embedding_mapping + + +class PromptAdapterModel(AdapterModel): + + def __init__(self, + prompt_adapter_id=None, + num_virtual_tokens=None, + prompt_embedding=None) -> None: + self.id = prompt_adapter_id + self.prompt_embedding = prompt_embedding + self.num_virtual_tokens = num_virtual_tokens + + @classmethod + def from_local_checkpoint( + cls, + adapter_model_path: str, + prompt_adapter_id: int, + num_virtual_tokens: int, + config: PromptAdapterConfig, + device: str = "cuda", + ) -> "PromptAdapterModel": + + if num_virtual_tokens > config.max_prompt_adapter_token: + raise ValueError( + f'num_virtual_tokens ({num_virtual_tokens}) should be <= ' + f'max_prompt_adapter_token({config.max_prompt_adapter_token})') + + adapters_weights = load_peft_weights(adapter_model_path, device) + prompt_embedding = adapters_weights["prompt_embeddings"].to( + config.prompt_adapter_dtype) + + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) + + +class PromptAdapterModelManager(AdapterModelManager): + """A manager that manages multiple Prompt Adapter models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + """Create a PromptAdapterModel and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + prompt_adapter_config: the PromptAdapter config, + """ + self.model: nn.Module = model + # Dict instead of a Set for compatibility with LRUCache. + self.prompt_adapter_index_to_id: List[ + Optional[int]] = [None] * self.prompt_adapter_slots + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.prompt_adapter_config = prompt_adapter_config + self.model.prompt_adapter_manager = self + self.adapter_type = 'PromptAdapter' + + self.base_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([]) + + self.modules: Dict[str, nn.Module] = {} + self._create_prompt_adapter_modules() + self._last_mapping: Optional[PromptAdapterMapping] = None + + @property + def prompt_adapter_slots(self) -> int: + return self.prompt_adapter_config.max_prompt_adapters + + @property + def adapter_slots(self) -> int: + return self.prompt_adapter_slots + + @property + def capacity(self) -> int: + return self.prompt_adapter_config.max_cpu_prompt_adapters + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + """Move PromptAdapter into a GPU buffer + to be used in the forward pass.""" + if prompt_adapter_id in self._active_adapters: + return False + first_free_slot = next( + ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( + self.prompt_adapter_index_to_id) if prompt_adapter_id is None), + None) + if first_free_slot is None: + raise ValueError("No free prompt_adapter slots") + index, _ = first_free_slot + self._active_adapters[prompt_adapter_id] = None + prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) + logger.debug("Activating prompt_adapter. 
int id: %d, slot index: %d", + prompt_adapter_model.id, index) + self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id + for _, v in self.modules.items(): + v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) + return True + + def _deactivate_adapter(self, prompt_adapter_id: int): + try: + index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) + self.prompt_adapter_index_to_id[index] = None + for _, v in self.modules.items(): + v.reset_prompt_adapter(index) + except ValueError: + pass + + def _add_adapter(self, prompt_adapter: PromptAdapterModel): + self._registered_adapters[prompt_adapter.id] = prompt_adapter + + def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + base_indices, base_embedding_indices = convert_mapping( + mapping, self.prompt_adapter_index_to_id) + for k, v in self.modules.items(): + v.set_mapping(base_indices, base_embedding_indices) + + def _create_prompt_adapter_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if "VocabParallel" in module.__class__.__name__: + new_module = VocabParallelEmbeddingWithPromptAdapter(module) + new_module.create_prompt_adapter_weights( + self.prompt_adapter_config) + replaced_module = self.replace_submodule( + self.model, module_name, new_module) + self.register_module(module.__class__.__name__, + replaced_module) + replaced_module.set_mapping(self.base_indices, + self.base_embedding_indices) + break + + def replace_submodule(self, model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + def register_module(self, module_name: str, module: nn.Module): + self.modules[module_name] = module + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in PromptAdapterModelManager." 
+ "Use LRUCachePromptAdapterModelManager for pinning" + ) # type: ignore + + def remove_all_adapters(self): + """Remove all PromptAdapterModel from the manager.""" + self._registered_adapters.clear() + self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots + self._active_adapters.clear() + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: PromptAdapterModel) -> bool: + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): + + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[int], bool]): + super().__init__(capacity, deactivate_prompt_adapter_fn) + + +class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): + """A model manager that manages multiple prompt_adapters with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + self.prompt_adapter_config = prompt_adapter_config + super().__init__(model, max_num_seqs, max_num_batched_tokens, + prompt_adapter_config) + self._registered_adapters = PromptAdapterLRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters = PromptAdapterLRUCache( + self.prompt_adapter_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, PromptAdapterModel]: + """List all registered PromptAdapterModel.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + """Add a PromptAdapterModel to the manager.""" + if prompt_adapter.id not in self._registered_adapters: + self._add_adapter(prompt_adapter) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(prompt_adapter.id) + was_added = False + return was_added + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + if prompt_adapter_id not in self._active_adapters and len( + self._active_adapters) >= self.prompt_adapter_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(prompt_adapter_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(prompt_adapter_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) + self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) + return True + + def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): + try: + self._registered_adapters.pin(prompt_adapter_id) + except ValueError as err: + raise ValueError( + "Pinning failed. 
" + f"Prompt Adapter {prompt_adapter_id} is not registered." + ) from err + + def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): + if prompt_adapter_id not in self._active_adapters: + # move adapter to gpu if not already active + self.activate_adapter(prompt_adapter_id) + self._active_adapters.pin(prompt_adapter_id) + + +def create_prompt_adapter_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager, + **kwargs) -> PromptAdapterModelManager: + """Create a PromptAdapterModel for a given model.""" + prompt_adapter_manager = prompt_adapter_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + prompt_adapter_config=prompt_adapter_config, + **kwargs) + return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py new file mode 100644 index 0000000..775dd11 --- /dev/null +++ b/vllm/prompt_adapter/request.py @@ -0,0 +1,34 @@ +import msgspec + +from vllm.adapter_commons.request import AdapterRequest + + +class PromptAdapterRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True, # type: ignore[call-arg] + frozen=True): # type: ignore[call-arg] + """ + Request for a Prompt adapter. + """ + __metaclass__ = AdapterRequest + + prompt_adapter_name: str + prompt_adapter_id: int + prompt_adapter_local_path: str + prompt_adapter_num_virtual_tokens: int + + def __hash__(self): + return super().__hash__() + + @property + def adapter_id(self): + return self.prompt_adapter_id + + @property + def name(self): + return self.prompt_adapter_name + + @property + def local_path(self): + return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/utils.py b/vllm/prompt_adapter/utils.py new file mode 100644 index 0000000..4cde2a0 --- /dev/null +++ b/vllm/prompt_adapter/utils.py @@ -0,0 +1,95 @@ +# code borrowed from: https://github.com/huggingface/peft/blob/v0.12.0/src/peft/utils/save_and_load.py#L420 + +import os +from typing import Optional + +import torch +from huggingface_hub import file_exists, hf_hub_download +from huggingface_hub.utils import EntryNotFoundError +from safetensors.torch import load_file as safe_load_file + +from vllm.platforms import current_platform + +WEIGHTS_NAME = "adapter_model.bin" +SAFETENSORS_WEIGHTS_NAME = "adapter_model.safetensors" + + +# Get current device name based on available devices +def infer_device() -> str: + if current_platform.is_cuda_alike(): + return "cuda" + return "cpu" + + +def load_peft_weights(model_id: str, + device: Optional[str] = None, + **hf_hub_download_kwargs) -> dict: + r""" + A helper method to load the PEFT weights from the HuggingFace Hub or locally + + Args: + model_id (`str`): + The local path to the adapter weights or the name of the adapter to + load from the HuggingFace Hub. + device (`str`): + The device to load the weights onto. + hf_hub_download_kwargs (`dict`): + Additional arguments to pass to the `hf_hub_download` method when + loading from the HuggingFace Hub. 
+ """ + path = (os.path.join(model_id, hf_hub_download_kwargs["subfolder"]) + if hf_hub_download_kwargs.get("subfolder", None) is not None else + model_id) + + if device is None: + device = infer_device() + + if os.path.exists(os.path.join(path, SAFETENSORS_WEIGHTS_NAME)): + filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME) + use_safetensors = True + elif os.path.exists(os.path.join(path, WEIGHTS_NAME)): + filename = os.path.join(path, WEIGHTS_NAME) + use_safetensors = False + else: + token = hf_hub_download_kwargs.get("token", None) + if token is None: + token = hf_hub_download_kwargs.get("use_auth_token", None) + + hub_filename = (os.path.join(hf_hub_download_kwargs["subfolder"], + SAFETENSORS_WEIGHTS_NAME) + if hf_hub_download_kwargs.get("subfolder", None) + is not None else SAFETENSORS_WEIGHTS_NAME) + has_remote_safetensors_file = file_exists( + repo_id=model_id, + filename=hub_filename, + revision=hf_hub_download_kwargs.get("revision", None), + repo_type=hf_hub_download_kwargs.get("repo_type", None), + token=token, + ) + use_safetensors = has_remote_safetensors_file + + if has_remote_safetensors_file: + # Priority 1: load safetensors weights + filename = hf_hub_download( + model_id, + SAFETENSORS_WEIGHTS_NAME, + **hf_hub_download_kwargs, + ) + else: + try: + filename = hf_hub_download(model_id, WEIGHTS_NAME, + **hf_hub_download_kwargs) + except EntryNotFoundError: + raise ValueError( # noqa: B904 + f"Can't find weights for {model_id} in {model_id} or \ + in the Hugging Face Hub. " + f"Please check that the file {WEIGHTS_NAME} or \ + {SAFETENSORS_WEIGHTS_NAME} is present at {model_id}.") + + if use_safetensors: + adapters_weights = safe_load_file(filename, device=device) + else: + adapters_weights = torch.load(filename, + map_location=torch.device(device)) + + return adapters_weights diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py new file mode 100644 index 0000000..ddc1ef8 --- /dev/null +++ b/vllm/prompt_adapter/worker_manager.py @@ -0,0 +1,176 @@ +import logging +from typing import Any, Optional, Set, Type + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, + PromptAdapterModel, + PromptAdapterModelManager, + create_prompt_adapter_manager) +from vllm.prompt_adapter.request import PromptAdapterRequest + +logger = logging.getLogger(__name__) + + +class WorkerPromptAdapterManager(AbstractWorkerManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. 
+ + Every request, the requested prompt_adapters will be + loaded (unless they are already loaded), + and every other prompt_adapter will be unloaded.""" + + _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + device: torch.device, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel + ): + self._adapter_manager: PromptAdapterModelManager + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self._prompt_adapter_model_cls = prompt_adapter_model_cls + self.prompt_adapter_config = prompt_adapter_config + super().__init__(device) + + @property + def is_enabled(self) -> bool: + return True + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._manager_cls, + ) + self._adapter_manager = prompt_adapter_manager + return prompt_adapter_manager.model + + def _load_adapter( + self, prompt_adapter_request: PromptAdapterRequest + ) -> PromptAdapterModel: + try: + prompt_adapter = ( + self._prompt_adapter_model_cls.from_local_checkpoint( + prompt_adapter_request.prompt_adapter_local_path, + prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + num_virtual_tokens=prompt_adapter_request. + prompt_adapter_num_virtual_tokens, + config=self.prompt_adapter_config, + device=str(self.device), + )) + except Exception as e: + raise RuntimeError( + f"Loading prompt_adapter " + f"{prompt_adapter_request.prompt_adapter_local_path}" + f" failed") from e + return prompt_adapter + + def add_dummy_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return True + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Uses an LRU Cache. 
Every request, the requested + prompt_adapters will be loaded (unless they are already loaded) + and least recently used prompt_adapters will + be unloaded if the cache is above capacity.""" + + _prompt_adapter_manager_cls: Type[ + LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) + self._adapter_manager: LRUCachePromptAdapterModelManager = ( + prompt_adapter_manager) + return prompt_adapter_manager.model + + def _apply_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: + prompt_adapters_map = { + prompt_adapter_request.prompt_adapter_id: prompt_adapter_request + for prompt_adapter_request in prompt_adapter_requests + if prompt_adapter_request + } + if len(prompt_adapters_map + ) > self._adapter_manager.prompt_adapter_slots: + raise RuntimeError( + f"Number of requested prompt_adapters " + f"({len(prompt_adapters_map)}) is greater " + "than the number of GPU prompt_adapter slots " + f"({self._adapter_manager.prompt_adapter_slots}).") + for prompt_adapter in prompt_adapters_map.values(): + self.add_adapter(prompt_adapter) + + def add_adapter(self, + prompt_adapter_request: PromptAdapterRequest) -> bool: + if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( + ): + # Remove before we load the new prompt_adapter to save memory + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + self._adapter_manager.remove_oldest_adapter() + prompt_adapter = self._load_adapter(prompt_adapter_request) + loaded = self._adapter_manager.add_adapter(prompt_adapter) + else: + # If the prompt_adapter is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + prompt_adapter_request.prompt_adapter_id) is not None + self._adapter_manager.activate_adapter( + prompt_adapter_request.prompt_adapter_id) + return loaded diff --git a/vllm/py.typed b/vllm/py.typed new file mode 100644 index 0000000..33b3ad7 --- /dev/null +++ b/vllm/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The vllm package uses inline types. diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py new file mode 100644 index 0000000..4f2ae75 --- /dev/null +++ b/vllm/sampling_params.py @@ -0,0 +1,491 @@ +"""Sampling parameters for text generation.""" +import copy +from dataclasses import dataclass +from enum import Enum, IntEnum +from functools import cached_property +from typing import Any, Callable, Dict, List, Optional, Set, Union + +import msgspec +import torch +from pydantic import BaseModel +from typing_extensions import Annotated + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +_SAMPLING_EPS = 1e-5 +_MAX_TEMP = 1e-2 + + +class SamplingType(IntEnum): + GREEDY = 0 + RANDOM = 1 + RANDOM_SEED = 2 + + +LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], + Callable[[List[int], List[int], torch.Tensor], + torch.Tensor]] +"""LogitsProcessor is a function that takes a list +of previously generated tokens, the logits tensor +for the next token and, optionally, prompt tokens as a +first argument, and returns a modified tensor of logits +to sample from.""" + + +# maybe make msgspec? 
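+# Example (illustrative only): at most one of the guide fields may be set, e.g.
+#   GuidedDecodingParams(regex=r"[0-9]+")
+#   GuidedDecodingParams.from_optional(json={"type": "object"})
+# Setting more than one kind (say, both json and regex) raises ValueError in
+# __post_init__.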
+@dataclass +class GuidedDecodingParams: + """One of these fields will be used to build a logit processor.""" + json: Optional[Union[str, Dict]] = None + regex: Optional[str] = None + choice: Optional[List[str]] = None + grammar: Optional[str] = None + json_object: Optional[bool] = None + """These are other options that can be set""" + backend: Optional[str] = None + whitespace_pattern: Optional[str] = None + + @staticmethod + def from_optional( + json: Optional[Union[Dict, BaseModel, str]], + regex: Optional[str] = None, + choice: Optional[List[str]] = None, + grammar: Optional[str] = None, + json_object: Optional[bool] = None, + backend: Optional[str] = None, + whitespace_pattern: Optional[str] = None, + ) -> "GuidedDecodingParams": + # Extract json schemas from pydantic models + if isinstance(json, (BaseModel, type(BaseModel))): + json = json.model_json_schema() + return GuidedDecodingParams( + json=json, + regex=regex, + choice=choice, + grammar=grammar, + json_object=json_object, + backend=backend, + whitespace_pattern=whitespace_pattern, + ) + + def __post_init__(self): + """Validate that some fields are mutually exclusive.""" + guide_count = sum([ + self.json is not None, self.regex is not None, self.choice + is not None, self.grammar is not None, self.json_object is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding but multiple are " + f"specified: {self.__dict__}") + + +class RequestOutputKind(Enum): + # Return entire output so far in every RequestOutput + CUMULATIVE = 0 + # Return only deltas in each RequestOutput + DELTA = 1 + # Do not return intermediate RequestOuputs + FINAL_ONLY = 2 + + +class SamplingParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] + """Sampling parameters for text generation. + + Overall, we follow the sampling parameters from the OpenAI text completion + API (https://platform.openai.com/docs/api-reference/completions/create). + In addition, we support beam search, which is not supported by OpenAI. + + Args: + n: Number of output sequences to return for the given prompt. + best_of: Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. By default, + `best_of` is set to `n`. + presence_penalty: Float that penalizes new tokens based on whether they + appear in the generated text so far. Values > 0 encourage the model + to use new tokens, while values < 0 encourage the model to repeat + tokens. + frequency_penalty: Float that penalizes new tokens based on their + frequency in the generated text so far. Values > 0 encourage the + model to use new tokens, while values < 0 encourage the model to + repeat tokens. + repetition_penalty: Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens. + temperature: Float that controls the randomness of the sampling. Lower + values make the model more deterministic, while higher values make + the model more random. Zero means greedy sampling. + top_p: Float that controls the cumulative probability of the top tokens + to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + top_k: Integer that controls the number of top tokens to consider. Set + to -1 to consider all tokens. 
+ min_p: Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this. + seed: Random seed to use for the generation. + stop: List of strings that stop the generation when they are generated. + The returned output will not contain the stop strings. + stop_token_ids: List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens. + include_stop_str_in_output: Whether to include the stop strings in + output text. Defaults to False. + ignore_eos: Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated. + max_tokens: Maximum number of tokens to generate per output sequence. + min_tokens: Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated + logprobs: Number of log probabilities to return per output token. + When set to None, no probability is returned. If set to a non-None + value, the result includes the log probabilities of the specified + number of most likely tokens, as well as the chosen tokens. + Note that the implementation follows the OpenAI API: The API will + always return the log probability of the sampled token, so there + may be up to `logprobs+1` elements in the response. + prompt_logprobs: Number of log probabilities to return per prompt token. + detokenize: Whether to detokenize the output. Defaults to True. + skip_special_tokens: Whether to skip special tokens in the output. + spaces_between_special_tokens: Whether to add spaces between special + tokens in the output. Defaults to True. + logits_processors: List of functions that modify logits based on + previously generated tokens, and optionally prompt tokens as + a first argument. + truncate_prompt_tokens: If set to an integer k, will use only the last k + tokens from the prompt (i.e., left truncation). Defaults to None + (i.e., no truncation). + guided_decoding: If provided, the engine will construct a guided + decoding logits processor from these parameters. Defaults to None. + logit_bias: If provided, the engine will construct a logits processor + that applies these logit biases. Defaults to None. + allowed_token_ids: If provided, the engine will construct a logits + processor which only retains scores for the given token ids. + Defaults to None. + """ + + n: int = 1 + best_of: Optional[int] = None + _real_n: Optional[int] = None + presence_penalty: float = 0.0 + frequency_penalty: float = 0.0 + repetition_penalty: float = 1.0 + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = -1 + min_p: float = 0.0 + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = None + stop_token_ids: Optional[List[int]] = None + ignore_eos: bool = False + max_tokens: Optional[int] = 16 + min_tokens: int = 0 + logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = None + # NOTE: This parameter is only exposed at the engine level for now. + # It is not exposed in the OpenAI API server, as the OpenAI API does + # not support returning only a list of token IDs. + detokenize: bool = True + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + # Optional[List[LogitsProcessor]] type. We use Any here because + # Optional[List[LogitsProcessor]] type is not supported by msgspec. 
+ logits_processors: Optional[Any] = None + include_stop_str_in_output: bool = False + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=1)]] = None + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE + + # The below fields are not supposed to be used as an input. + # They are set in post_init. + output_text_buffer_length: int = 0 + _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + + # Fields used to construct logits processors + guided_decoding: Optional[GuidedDecodingParams] = None + logit_bias: Optional[Dict[int, float]] = None + allowed_token_ids: Optional[List[int]] = None + + @staticmethod + def from_optional( + n: Optional[int] = 1, + best_of: Optional[int] = None, + presence_penalty: Optional[float] = 0.0, + frequency_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 1.0, + top_k: int = -1, + min_p: float = 0.0, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: Optional[int] = 16, + min_tokens: int = 0, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + detokenize: bool = True, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=1)]] = None, + output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, + guided_decoding: Optional[GuidedDecodingParams] = None, + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None, + allowed_token_ids: Optional[List[int]] = None, + ) -> "SamplingParams": + if logit_bias is not None: + logit_bias = { + int(token): bias + for token, bias in logit_bias.items() + } + + return SamplingParams( + n=1 if n is None else n, + best_of=best_of, + presence_penalty=0.0 + if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 + if frequency_penalty is None else frequency_penalty, + repetition_penalty=1.0 + if repetition_penalty is None else repetition_penalty, + temperature=1.0 if temperature is None else temperature, + top_p=1.0 if top_p is None else top_p, + top_k=top_k, + min_p=min_p, + seed=seed, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_stop_str_in_output, + ignore_eos=ignore_eos, + max_tokens=max_tokens, + min_tokens=min_tokens, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + detokenize=detokenize, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, + truncate_prompt_tokens=truncate_prompt_tokens, + output_kind=output_kind, + guided_decoding=guided_decoding, + logit_bias=logit_bias, + allowed_token_ids=allowed_token_ids, + ) + + def __post_init__(self) -> None: + # how we deal with `best_of``: + # if `best_of`` is not set, we default to `n`; + # if `best_of`` is set, we set `n`` to `best_of`, + # and set `_real_n`` to the original `n`. 
+ # when we return the result, we will check + # if we need to return `n` or `_real_n` results + if self.best_of: + if self.best_of < self.n: + raise ValueError( + f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}.") + self._real_n = self.n + self.n = self.best_of + if 0 < self.temperature < _MAX_TEMP: + logger.warning( + "temperature %s is less than %s, which may cause numerical " + "errors nan or inf in tensors. We have maxed it out to %s.", + self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature = max(self.temperature, _MAX_TEMP) + if self.seed == -1: + self.seed = None + else: + self.seed = self.seed + if self.stop is None: + self.stop = [] + elif isinstance(self.stop, str): + self.stop = [self.stop] + else: + self.stop = list(self.stop) + if self.stop_token_ids is None: + self.stop_token_ids = [] + else: + self.stop_token_ids = list(self.stop_token_ids) + self.logprobs = 1 if self.logprobs is True else self.logprobs + self.prompt_logprobs = (1 if self.prompt_logprobs is True else + self.prompt_logprobs) + + # Number of characters to hold back for stop string evaluation + # until sequence is finished. + if self.stop and not self.include_stop_str_in_output: + self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 + + self._verify_args() + + if self.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. + self.top_p = 1.0 + self.top_k = -1 + self.min_p = 0.0 + self._verify_greedy_sampling() + # eos_token_id is added to this by the engine + self._all_stop_token_ids = set(self.stop_token_ids) + + def _verify_args(self) -> None: + if not isinstance(self.n, int): + raise ValueError(f"n must be an int, but is of " + f"type {type(self.n)}") + if self.n < 1: + raise ValueError(f"n must be at least 1, got {self.n}.") + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError("presence_penalty must be in [-2, 2], got " + f"{self.presence_penalty}.") + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError("frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}.") + if not 0.0 < self.repetition_penalty <= 2.0: + raise ValueError("repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}.") + if self.temperature < 0.0: + raise ValueError( + f"temperature must be non-negative, got {self.temperature}.") + if not 0.0 < self.top_p <= 1.0: + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, " + f"got {self.top_k}.") + if not isinstance(self.top_k, int): + raise TypeError( + f"top_k must be an integer, got {type(self.top_k).__name__}") + if not 0.0 <= self.min_p <= 1.0: + raise ValueError("min_p must be in [0, 1], got " + f"{self.min_p}.") + if self.max_tokens is not None and self.max_tokens < 1: + raise ValueError( + f"max_tokens must be at least 1, got {self.max_tokens}.") + if self.min_tokens < 0: + raise ValueError(f"min_tokens must be greater than or equal to 0, " + f"got {self.min_tokens}.") + if self.max_tokens is not None and self.min_tokens > self.max_tokens: + raise ValueError( + f"min_tokens must be less than or equal to " + f"max_tokens={self.max_tokens}, got {self.min_tokens}.") + if self.logprobs is not None and self.logprobs < 0: + raise ValueError( + f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: + raise ValueError(f"prompt_logprobs must be non-negative, got " + 
f"{self.prompt_logprobs}.") + if (self.truncate_prompt_tokens is not None + and self.truncate_prompt_tokens < 1): + raise ValueError(f"truncate_prompt_tokens must be >= 1, " + f"got {self.truncate_prompt_tokens}") + assert isinstance(self.stop, list) + if any(not stop_str for stop_str in self.stop): + raise ValueError("stop cannot contain an empty string.") + if self.stop and not self.detokenize: + raise ValueError( + "stop strings are only supported when detokenize is True. " + "Set detokenize=True to use stop.") + if self.best_of != self._real_n and self.output_kind == ( + RequestOutputKind.DELTA): + raise ValueError("best_of must equal n to use output_kind=DELTA") + + def _verify_greedy_sampling(self) -> None: + if self.n > 1: + raise ValueError("n must be 1 when using greedy sampling, " + f"got {self.n}.") + + def update_from_generation_config( + self, + generation_config: Dict[str, Any], + model_eos_token_id: Optional[int] = None) -> None: + """Update if there are non-default values from generation_config""" + + if model_eos_token_id is not None: + # Add the eos token id into the sampling_params to support + # min_tokens processing. + self._all_stop_token_ids.add(model_eos_token_id) + + # Update eos_token_id for generation + if (eos_ids := generation_config.get("eos_token_id")) is not None: + # it can be either int or list of int + eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids) + if model_eos_token_id is not None: + # We don't need to include the primary eos_token_id in + # stop_token_ids since it's handled separately for stopping + # purposes. + eos_ids.discard(model_eos_token_id) + if eos_ids: + self._all_stop_token_ids.update(eos_ids) + if not self.ignore_eos: + eos_ids.update(self.stop_token_ids) + self.stop_token_ids = list(eos_ids) + + @cached_property + def sampling_type(self) -> SamplingType: + if self.temperature < _SAMPLING_EPS: + return SamplingType.GREEDY + if self.seed is not None: + return SamplingType.RANDOM_SEED + return SamplingType.RANDOM + + @property + def all_stop_token_ids(self) -> Set[int]: + return self._all_stop_token_ids + + def clone(self) -> "SamplingParams": + """Deep copy excluding LogitsProcessor objects. + + LogitsProcessor objects are excluded because they may contain an + arbitrary, nontrivial amount of data. 
+ See https://github.com/vllm-project/vllm/issues/3087 + """ + + logit_processor_refs = None if self.logits_processors is None else { + id(lp): lp + for lp in self.logits_processors + } + return copy.deepcopy(self, memo=logit_processor_refs) + + def __repr__(self) -> str: + return ( + f"SamplingParams(n={self.n}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"seed={self.seed}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"include_stop_str_in_output={self.include_stop_str_in_output}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"min_tokens={self.min_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens}, " + f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " + f"guided_decoding={self.guided_decoding}") + + +class BeamSearchParams( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True): # type: ignore[call-arg] + """Beam search parameters for text generation.""" + beam_width: int + max_tokens: int + ignore_eos: bool = False + temperature: float = 0.0 + length_penalty: float = 1.0 diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py new file mode 100644 index 0000000..eb491dd --- /dev/null +++ b/vllm/scalar_type.py @@ -0,0 +1,35 @@ +from ._core_ext import NanRepr, ScalarType + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, + NanRepr.EXTD_RANGE_MAX_MIN.value) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value) + + # "gptq" types + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/vllm/scripts.py b/vllm/scripts.py new file mode 100644 index 0000000..4e4c071 --- /dev/null +++ b/vllm/scripts.py @@ -0,0 +1,201 @@ +# The CLI entrypoint to vLLM. 
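+#
+# Example invocations (illustrative; the model tag is a placeholder):
+#
+#     vllm serve facebook/opt-125m
+#     vllm complete --url http://localhost:8000/v1
+#     vllm chat --system-prompt "You are a helpful assistant."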
+import argparse +import os +import signal +import sys +from typing import List, Optional + +import uvloop +from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +from vllm.engine.arg_utils import EngineArgs +from vllm.entrypoints.openai.api_server import run_server +from vllm.entrypoints.openai.cli_args import (make_arg_parser, + validate_parsed_serve_args) +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +def register_signal_handlers(): + + def signal_handler(sig, frame): + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTSTP, signal_handler) + + +def serve(args: argparse.Namespace) -> None: + # The default value of `--model` + if args.model != EngineArgs.model: + raise ValueError( + "With `vllm serve`, you should provide the model as a " + "positional argument instead of via the `--model` option.") + + # EngineArgs expects the model name to be passed as --model. + args.model = args.model_tag + + uvloop.run(run_server(args)) + + +def interactive_cli(args: argparse.Namespace) -> None: + register_signal_handlers() + + base_url = args.url + api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY") + openai_client = OpenAI(api_key=api_key, base_url=base_url) + + if args.model_name: + model_name = args.model_name + else: + available_models = openai_client.models.list() + model_name = available_models.data[0].id + + print(f"Using model: {model_name}") + + if args.command == "complete": + complete(model_name, openai_client) + elif args.command == "chat": + chat(args.system_prompt, model_name, openai_client) + + +def complete(model_name: str, client: OpenAI) -> None: + print("Please enter prompt to complete:") + while True: + input_prompt = input("> ") + + completion = client.completions.create(model=model_name, + prompt=input_prompt) + output = completion.choices[0].text + print(output) + + +def chat(system_prompt: Optional[str], model_name: str, + client: OpenAI) -> None: + conversation: List[ChatCompletionMessageParam] = [] + if system_prompt is not None: + conversation.append({"role": "system", "content": system_prompt}) + + print("Please enter a message for the chat model:") + while True: + input_message = input("> ") + conversation.append({"role": "user", "content": input_message}) + + chat_completion = client.chat.completions.create(model=model_name, + messages=conversation) + + response_message = chat_completion.choices[0].message + output = response_message.content + + conversation.append(response_message) # type: ignore + print(output) + + +def _add_query_options( + parser: FlexibleArgumentParser) -> FlexibleArgumentParser: + parser.add_argument( + "--url", + type=str, + default="http://localhost:8000/v1", + help="url of the running OpenAI-Compatible RESTful API server") + parser.add_argument( + "--model-name", + type=str, + default=None, + help=("The model name used in prompt completion, default to " + "the first model in list models API call.")) + parser.add_argument( + "--api-key", + type=str, + default=None, + help=( + "API key for OpenAI services. If provided, this api key " + "will overwrite the api key obtained through environment variables." + )) + return parser + + +def env_setup(): + # The safest multiprocessing method is `spawn`, as the default `fork` method + # is not compatible with some accelerators. The default method will be + # changing in future versions of Python, so we should use it explicitly when + # possible. 
+ # + # We only set it here in the CLI entrypoint, because changing to `spawn` + # could break some existing code using vLLM as a library. `spawn` will cause + # unexpected behavior if the code is not protected by + # `if __name__ == "__main__":`. + # + # References: + # - https://docs.python.org/3/library/multiprocessing.html#contexts-and-start-methods + # - https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing + # - https://pytorch.org/docs/stable/multiprocessing.html#sharing-cuda-tensors + # - https://docs.habana.ai/en/latest/PyTorch/Getting_Started_with_PyTorch_and_Gaudi/Getting_Started_with_PyTorch.html?highlight=multiprocessing#torch-multiprocessing-for-dataloaders + if "VLLM_WORKER_MULTIPROC_METHOD" not in os.environ: + logger.debug("Setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + +def main(): + env_setup() + + parser = FlexibleArgumentParser(description="vLLM CLI") + subparsers = parser.add_subparsers(required=True, dest="subparser") + + serve_parser = subparsers.add_parser( + "serve", + help="Start the vLLM OpenAI Compatible API server", + usage="vllm serve [options]") + serve_parser.add_argument("model_tag", + type=str, + help="The model tag to serve") + serve_parser.add_argument( + "--config", + type=str, + default='', + required=False, + help="Read CLI options from a config file." + "Must be a YAML with the following options:" + "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server" + ) + serve_parser = make_arg_parser(serve_parser) + serve_parser.set_defaults(dispatch_function=serve) + + complete_parser = subparsers.add_parser( + "complete", + help=("Generate text completions based on the given prompt " + "via the running API server"), + usage="vllm complete [options]") + _add_query_options(complete_parser) + complete_parser.set_defaults(dispatch_function=interactive_cli, + command="complete") + + chat_parser = subparsers.add_parser( + "chat", + help="Generate chat completions via the running API server", + usage="vllm chat [options]") + _add_query_options(chat_parser) + chat_parser.add_argument( + "--system-prompt", + type=str, + default=None, + help=("The system prompt to be added to the chat template, " + "used for models that support system prompts.")) + chat_parser.set_defaults(dispatch_function=interactive_cli, command="chat") + + args = parser.parse_args() + if args.subparser == "serve": + validate_parsed_serve_args(args) + + # One of the sub commands should be executed. 
+ if hasattr(args, "dispatch_function"): + args.dispatch_function(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/vllm/sequence.py b/vllm/sequence.py new file mode 100644 index 0000000..3bb35ea --- /dev/null +++ b/vllm/sequence.py @@ -0,0 +1,1380 @@ +"""Sequence and its related classes.""" +import copy +import enum +from abc import ABC, abstractmethod +from array import array +from collections import defaultdict +from dataclasses import dataclass +from functools import cached_property, reduce +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional +from typing import Sequence as GenericSequence +from typing import Set, Tuple, Union, cast + +import msgspec +import torch + +from vllm.inputs import EncoderDecoderLLMInputs, LLMInputs +from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics + +if TYPE_CHECKING: + from vllm.multimodal.base import MultiModalDataDict + +VLLM_TOKEN_ID_ARRAY_TYPE = "l" + +VLLM_INVALID_TOKEN_ID = -1 + + +# We use dataclass for now because it is used for +# openai server output, and msgspec is not serializable. +# TODO(sang): Fix it. +@dataclass +class Logprob: + """Infos for supporting OpenAI compatible logprobs and token ranks. + + Attributes: + logprob: The logprob of chosen token + rank: The vocab rank of chosen token (>=1) + decoded_token: The decoded chosen token index + """ + logprob: float + rank: Optional[int] = None + decoded_token: Optional[str] = None + + +# {token_id -> logprob} per each sequence group. None if the corresponding +# sequence group doesn't require prompt logprob. +PromptLogprobs = List[Optional[Dict[int, Logprob]]] +# {token_id -> logprob} for each sequence group. +SampleLogprobs = List[Dict[int, Logprob]] + + +class SequenceStatus(enum.IntEnum): + """Status of a sequence.""" + WAITING = 0 + RUNNING = 1 + SWAPPED = 2 + # Note: anything after SWAPPED (2) will be considered + # as a finished status. + FINISHED_STOPPED = 3 + FINISHED_LENGTH_CAPPED = 4 + FINISHED_ABORTED = 5 + FINISHED_IGNORED = 6 + + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status > SequenceStatus.SWAPPED + + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == SequenceStatus.FINISHED_ABORTED: + finish_reason = "abort" + elif status == SequenceStatus.FINISHED_IGNORED: + # The ignored sequences are the sequences whose prompt lengths + # are longer than the model's length cap. Therefore, the stop + # reason should also be "length" as in OpenAI API. + finish_reason = "length" + else: + finish_reason = None + return finish_reason + + +class SequenceStage(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + + +@dataclass +class RequestMetrics: + """Metrics associated with a request. + + Attributes: + arrival_time: The time when the request arrived. + first_scheduled_time: The time when the request was first scheduled. + first_token_time: The time when the first token was generated. + time_in_queue: The time the request spent in the queue. + finished_time: The time when the request was finished. 
+ scheduler_time: The time spent in the scheduler when this request was + being considered by the scheduler. + model_forward_time: The time spent in the model forward pass when this + request was in the batch. + model_execute_time: The time spent in the model execute function. This + will include model forward, block/sync across + workers, cpu-gpu sync time and sampling time. + """ + arrival_time: float + last_token_time: float + first_scheduled_time: Optional[float] + first_token_time: Optional[float] + time_in_queue: Optional[float] + finished_time: Optional[float] = None + scheduler_time: Optional[float] = None + model_forward_time: Optional[float] = None + model_execute_time: Optional[float] = None + + +class SequenceDataDelta( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta SequenceData to send to workers per step.""" + # A new token to be appended to existing SequenceData. + new_output_token_ids: List[int] + # Overwriting existing `cumulative_logprob` + new_cumulative_logprob: float + # Overwriting existing `num_computed_tokens`. + new_num_computed_tokens: int + # Overwriting existing `stage`. + new_stage: SequenceStage + + +class SequenceData(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] + """Data associated with a sequence. + + Args: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. Set to an empty list if + None. + + Attributes: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. + """ + # NOTE: we cannot use Union[List, array] because msgspec cannot support + # union of 2 list types. + _prompt_token_ids: array + _output_token_ids: array = msgspec.field( + default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, [])) + + ### The below fields should not be passed as an argument ### + _cumulative_logprob: float = 0.0 + _prompt_token_ids_tuple: Tuple[int, + ...] = msgspec.field(default_factory=tuple) + # The number of tokens that are computed (that run against the model). + _num_computed_tokens: int = 0 + _stage: SequenceStage = SequenceStage.PREFILL + _cached_all_token_ids: List[int] = msgspec.field(default_factory=list) + + # It is used to get delta input. It is reset when `get_delta_and_reset` + # is called. + _new_appended_tokens: List[int] = msgspec.field(default_factory=list) + + # It is used to compute mrope_position_ids. + _mrope_position_delta: Optional[int] = None + + @staticmethod + def from_token_counts(*token_counts: Tuple[int, int]) -> "SequenceData": + if len(token_counts) == 0: + return SequenceData.from_seqs([]) + + arrs = [ + array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count + for token_id, count in token_counts + ] + + return SequenceData(reduce(array.__add__, arrs)) + + @staticmethod + def from_seqs( + prompt_token_ids: GenericSequence[int], + output_token_ids: Optional[GenericSequence[int]] = None, + ) -> "SequenceData": + prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + prompt_token_ids) + + if output_token_ids is None: + return SequenceData(prompt_token_ids_arr) + + output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, + output_token_ids) + + return SequenceData(prompt_token_ids_arr, + _output_token_ids=output_token_ids_arr) + + def __post_init__(self) -> None: + assert self._prompt_token_ids.typecode == "l" + assert self._output_token_ids.typecode == "l" + self._prompt_token_ids_tuple: Tuple[int, ...] 
= tuple( + self._prompt_token_ids) + self._update_cached_all_tokens() + + def _update_cached_all_tokens(self): + assert isinstance(self._prompt_token_ids, array) + assert isinstance(self._output_token_ids, array) + self._cached_all_token_ids: List[int] = list(self._prompt_token_ids + + self._output_token_ids) + + @property + def cumulative_logprob(self) -> float: + return self._cumulative_logprob + + @property + def prompt_token_ids(self) -> Tuple[int, ...]: + return self._prompt_token_ids_tuple + + @prompt_token_ids.setter + def prompt_token_ids(self, new_prompt_token_ids) -> None: + raise NotImplementedError + + @property + def prompt_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ + return self._prompt_token_ids + + @property + def output_token_ids(self) -> Tuple[int, ...]: + return tuple(self._output_token_ids) + + @output_token_ids.setter + def output_token_ids(self, new_output_token_ids: List[int]) -> None: + self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids) + self._update_cached_all_tokens() + + @property + def output_token_ids_array(self) -> array: + """Return the prompt token ids in array type. + + Note that the array is in "I" type, and it is not compatible + with torch.long (2 bytes vs 4 bytes). So beware of the usage. + """ + assert isinstance(self._output_token_ids, array) + return self._output_token_ids + + @property + def mrope_position_delta(self) -> Optional[int]: + return self._mrope_position_delta + + @mrope_position_delta.setter + def mrope_position_delta(self, new_mrope_position_delta): + self._mrope_position_delta = new_mrope_position_delta + + def append_token_id(self, token_id: int, logprob: float) -> None: + self._output_token_ids.append(token_id) + self._new_appended_tokens.append(token_id) + self._cached_all_token_ids.append(token_id) + self._cumulative_logprob += logprob + + def get_len(self) -> int: + return len(self._output_token_ids) + len(self._prompt_token_ids) + + def get_prompt_len(self) -> int: + return len(self._prompt_token_ids) + + def get_output_len(self) -> int: + return len(self._output_token_ids) + + def get_token_ids(self) -> List[int]: + return self._cached_all_token_ids + + def get_prefix_token_ids( + self, num_tokens: int + ) -> Tuple[Tuple[int, ...], Optional[Tuple[int, ...]]]: + """Get prefix tokens, and make the return value hashable""" + prompt_length = self.get_prompt_len() + if num_tokens > prompt_length: + return (self._prompt_token_ids_tuple, + tuple(self._output_token_ids[:num_tokens - prompt_length])) + else: + return (self._prompt_token_ids_tuple[:num_tokens], None) + + def get_num_computed_tokens(self) -> int: + """Return the number of prefill tokens that are already computed.""" + return self._num_computed_tokens + + def update_num_computed_tokens(self, num_new_computed_tokens: int): + """Update number of tokens computed so far.""" + self._num_computed_tokens += num_new_computed_tokens + assert self._num_computed_tokens <= self.get_len(), ( + self._num_computed_tokens, self.get_len()) + # If all tokens are computed, it means it is in decoding phase. + if self.get_num_uncomputed_tokens() == 0: + self._stage = SequenceStage.DECODE + + def reset_state_for_recompute(self) -> None: + """Reset the number of computed tokens from this sequence. 
It is + supposed to be called when a sequence needs to be started from + the beginning again (e.g., sequence is preempted). + """ + self._num_computed_tokens = 0 + self._stage = SequenceStage.PREFILL + self._new_appended_tokens = [] + + def get_num_uncomputed_tokens(self) -> int: + """Return the number of prefill tokens that are not computed.""" + # we use `get_len()` which includes prompt_len + output_len instead + # of prompt_len here. This is because during recompute we need to + # prefill for both prompt and output. + return self.get_len() - self.get_num_computed_tokens() + + def get_last_token_id(self) -> int: + if not self._output_token_ids: + return self._prompt_token_ids[-1] + return self._output_token_ids[-1] + + def get_prompt_token_ids(self) -> Tuple[int, ...]: + return self.prompt_token_ids + + def get_output_token_ids(self) -> Tuple[int, ...]: + return self.output_token_ids + + def get_delta_and_reset(self) -> SequenceDataDelta: + delta = SequenceDataDelta(self._new_appended_tokens, + self._cumulative_logprob, + self.get_num_computed_tokens(), self.stage) + # Reset delta state. + self._new_appended_tokens = [] + return delta + + def apply_delta(self, delta: SequenceDataDelta): + self._num_computed_tokens = delta.new_num_computed_tokens + self._cumulative_logprob = delta.new_cumulative_logprob + self._stage = delta.new_stage + self._output_token_ids.extend(delta.new_output_token_ids) + self._cached_all_token_ids.extend(delta.new_output_token_ids) + + @property + def stage(self) -> SequenceStage: + return self._stage + + def __repr__(self) -> str: + return (f"SequenceData(" + f"prompt_token_ids={self._prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"get_num_computed_tokens={self.get_num_computed_tokens()}") + + +class Sequence: + """Stores the data, status, and block information of a sequence. + + The sequence is constructed from the LLMInputs instance passed + in through the `inputs` constructor argument. + + For encoder/decoder models, LLMInputs encapsulates both a + decoder and encoder prompt, creating an ambiguity about which + prompt to construct the sequence from. The `from_decoder_prompt` + constructor argument signals whether to construct the Sequence + from the LLMInputs decoder prompt, or encoder prompt. + + Args: + seq_id: The ID of the sequence. + inputs: The inputs of the sequence. + block_size: The block size of the sequence. Should be the same as the + block size used by the block manager and cache engine. + eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM. + lora_request: LoRA request. + prompt_adapter_request: Prompt Adapter request. + from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt + (True) or encoder prompt (False.) Must be True + for decoder-only model. + + """ + + def __init__( + self, + seq_id: int, + inputs: "LLMInputs", + block_size: int, + eos_token_id: Optional[int] = None, + lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + from_decoder_prompt: bool = True, + ) -> None: + self.seq_id = seq_id + self.inputs = inputs + self.block_size = block_size + self.eos_token_id = eos_token_id + self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request + self.from_decoder_prompt = from_decoder_prompt + + # For decoder-only models, a Sequence is constructed + # from an LLMInputs instance (the `inputs` arg.) 
+ # + # For encoder/decoder models the same `inputs` + # instance could be utilized to construct either an + # encoder sequence or a decoder sequence, because + # `LLMInputs` has both decoder- and encoder-oriented + # member variables (i.e. it encapsulates both an encoder + # and a decoder prompt.) The decision of which type of sequence + # to generate is determined by the `from_decoder_prompt` argument. + # + # When constructing a encoder sequence + # (`from_decoder_prompt` False) it matters that + # the `LLMInputs` instance stored in `inputs` is valid + # in the sense that its encoder-related member variables are + # populated; below, an exception is raised if this is + # not the case. + # + # When constructing a decoder sequence (`from_decoder_prompt` True) + # it does not matter whether `inputs` has its encoder-related + # member variables populated. + if not (from_decoder_prompt + or is_valid_encoder_decoder_llm_inputs(inputs)): + raise ValueError("Cannot extract encoder input prompt from " + f"invalid input {inputs}; did you forget the " + "encoder input prompt fields?") + + self.data = SequenceData.from_seqs(self.prompt_token_ids) + self.output_logprobs: SampleLogprobs = [] + self.output_text = "" + + self.status = SequenceStatus.WAITING + self.stop_reason: Union[int, str, None] = None + + # These are used to keep track of delta outputs + self._last_output_token_ids_offset: int = 0 + self._last_output_text_offset: int = 0 + + # Used for incremental detokenization + self.prefix_offset = 0 + self.read_offset = 0 + # Input + output tokens + self.tokens: Optional[List[str]] = None + + @property + def n_blocks(self) -> int: + return (self.get_len() + self.block_size - 1) // self.block_size + + @cached_property + def prompt(self) -> Optional[str]: + # Select decoder or encoder input prompt str, as appropriate + prompt_key: str = ("prompt" + if self.from_decoder_prompt else "encoder_prompt") + + return cast(Optional[str], self.inputs.get(prompt_key)) + + @cached_property + def prompt_token_ids(self) -> List[int]: + # Select decoder or encoder input prompt token ids, as appropriate + prompt_token_ids_key: str = ("prompt_token_ids" + if self.from_decoder_prompt else + "encoder_prompt_token_ids") + + # Cache computed prompt token ids + return cast(List[int], self.inputs.get(prompt_token_ids_key)) + + @property + def multi_modal_data(self) -> "MultiModalDataDict": + if self.inputs.get("multi_modal_data") and self.inputs.get( + "encoder_multi_modal_data"): + raise ValueError( + "Multi-modal data in both encoder and decoder is not supported." + ) + inputs = self.inputs + return self.inputs.get("multi_modal_data") or (cast( + EncoderDecoderLLMInputs, + inputs).get("encoder_multi_modal_data")) or {} + + @property + def mm_processor_kwargs(self) -> Dict[str, Any]: + return self.inputs.get("mm_processor_kwargs") or {} + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + def get_output_text_to_return(self, buffer_length: int, + delta: bool) -> str: + """If delta is True, only new text since the last call to + this method is returned""" + + # We return the full output text if the sequence is finished. 
+ truncate = buffer_length and not self.is_finished() + if not delta: + return self.output_text[:-buffer_length] if truncate else ( + self.output_text) + length = len(self.output_text) + if truncate: + length -= buffer_length + last_offset = self._last_output_text_offset + if last_offset < length: + self._last_output_text_offset = length + return self.output_text[last_offset:length] + return "" + + def get_output_token_ids_to_return( + self, delta: bool) -> Union[GenericSequence[int], int]: + """If delta is True, only new tokens since the last call to + this method are returned""" + if not delta: + return self.get_output_token_ids() + + output_len = self.get_output_len() + + # Get the number of new tokens + num_new_tokens = output_len - self._last_output_token_ids_offset + self._last_output_token_ids_offset = output_len + + # Return new tokens + if num_new_tokens == 1: + # Optimization for single decode token case + # (which is what we have most of the time) + return self.data._cached_all_token_ids[-1] + + return self.data._cached_all_token_ids[-num_new_tokens:] + + def hash_of_block(self, logical_idx: int) -> int: + # TODO This can produce incorrect hash when block size > prompt size + + # Compute the number of tokens in the sequence + # TODO: The current hashing function is O(L^2). We should optimize + # this in the future. + num_tokens = self.num_hashed_tokens_of_block(logical_idx) + hashed_tokens = self.data.get_prefix_token_ids(num_tokens) + return hash((hashed_tokens, self.lora_int_id)) + + def num_hashed_tokens_of_block(self, logical_idx: int): + return logical_idx * self.block_size + self.block_size + + def reset_state_for_recompute(self): + """Reset the sequence states for recomputation.""" + self.data.reset_state_for_recompute() + + def append_token_id(self, token_id: int, logprobs: Dict[int, + Logprob]) -> None: + assert token_id in logprobs + self.output_logprobs.append(logprobs) + self.data.append_token_id(token_id, logprobs[token_id].logprob) + + def get_len(self) -> int: + return self.data.get_len() + + def get_prompt_len(self) -> int: + return self.data.get_prompt_len() + + def get_output_len(self) -> int: + return self.data.get_output_len() + + def get_token_ids(self) -> List[int]: + return self.data.get_token_ids() + + def get_prompt_token_ids(self) -> Tuple[int, ...]: + return self.data.get_prompt_token_ids() + + def get_last_token_id(self) -> int: + return self.data.get_last_token_id() + + def get_output_token_ids(self) -> Tuple[int, ...]: + return self.data.get_output_token_ids() + + def get_cumulative_logprob(self) -> float: + return self.data.cumulative_logprob + + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.status) + + def fork(self, new_seq_id: int) -> "Sequence": + new_seq = copy.deepcopy(self) + new_seq.seq_id = new_seq_id + return new_seq + + def get_num_new_tokens(self) -> int: + """Get the number of new tokens to be computed. + + Returns: + The new number of tokens to be computed. I.e., 1 for decode, or + the remaining prompt size for prefill. 
+ """ + if self.data.stage == SequenceStage.DECODE: + return 1 + return self.data.get_num_uncomputed_tokens() + + def is_prefill(self) -> bool: + return self.data.stage == SequenceStage.PREFILL + + def __repr__(self) -> str: + return (f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={self.n_blocks}, ") + + +class SequenceGroupState(msgspec.Struct, + omit_defaults=True): # type: ignore[call-arg] + """Mutable state tied to a specific sequence group""" + + # for multi-step decoding + num_steps: int = 1 + current_step: int = 0 + + @property + def remaining_steps(self) -> int: + return self.num_steps - self.current_step + + +class SequenceGroup: + """A group of sequences that are generated from the same prompt. + + Args: + request_id: The ID of the request. + seqs: The list of sequences. + sampling_params: The sampling parameters used to generate the outputs. + arrival_time: The arrival time of the request. + lora_request: LoRA request. + embeddings: The embeddings vectors of the prompt of the sequence group + for an embedding model. + pooling_params: The pooling parameters used to generate the pooling + for an embedding model. + encoder_seq: Optional, the single encoder sequence. Should be None + unless you are working with an encoder/decoder model. + trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request. + priority: User-defined priority of the request. + """ + + def __init__( + self, + request_id: str, + seqs: List[Sequence], + arrival_time: float, + sampling_params: Optional[SamplingParams] = None, + lora_request: Optional[LoRARequest] = None, + embeddings: Optional[List[float]] = None, + pooling_params: Optional[PoolingParams] = None, + encoder_seq: Optional[Sequence] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + ) -> None: + self.request_id = request_id + self.seqs = seqs + self.arrival_time = arrival_time + self.is_single_seq = len(seqs) == 1 + self.seqs_dict = {seq.seq_id: seq for seq in seqs} + + self.sampling_params = sampling_params + self.metrics = RequestMetrics(arrival_time=arrival_time, + last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None) + self.lora_request = lora_request + self.prompt_logprobs: Optional[PromptLogprobs] = None + self.state = SequenceGroupState() + self.embeddings = embeddings + self.pooling_params = pooling_params + self.prompt_adapter_request = prompt_adapter_request + self.encoder_seq = encoder_seq + self.trace_headers = trace_headers + self.priority = priority + + self.cached_request_output = None + + @property + def prompt(self) -> Optional[str]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return self.seqs[0].prompt + + @property + def prompt_token_ids(self) -> List[int]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return self.seqs[0].prompt_token_ids + + @property + def encoder_prompt(self) -> Optional[str]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt is distinct + # from the decoder's. + return (self.encoder_seq.prompt + if self.encoder_seq is not None else None) + + @property + def encoder_prompt_token_ids(self) -> Optional[List[int]]: + # There are either 0 or 1 encoder sequences + # If one is present, its prompt token ids are + # distinct from the decoder's. 
+ return (self.encoder_seq.prompt_token_ids + if self.encoder_seq is not None else None) + + @property + def multi_modal_data(self) -> "MultiModalDataDict": + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return self.seqs[0].multi_modal_data + + @property + def mm_processor_kwargs(self) -> Dict[str, Any]: + # As with multi-modal data, all sequences in the group should have the + # same processor kwargs (i.e., mm_processor_kwargs are optionally + # provided per request; note that are independent of whether the model + # decoder-only or an encoder-decoder). + return self.seqs[0].mm_processor_kwargs + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + if self.prompt_adapter_request else 0 + + def init_multi_step(self, num_steps: int) -> None: + self.state.num_steps = num_steps + self.state.current_step = 0 + + def init_multi_step_from_lookahead_slots(self, num_lookahead_slots: int, + num_scheduler_steps: int, + is_multi_step: bool, + enable_chunking: bool) -> None: + + if not is_multi_step: + self.init_multi_step(num_steps=num_scheduler_steps) + return + + # Multi-Step case + is_prefill = self.is_prefill() + + # The asserts below reflect the expectations of the current system. + if is_prefill and enable_chunking: + assert num_lookahead_slots == num_scheduler_steps + self.init_multi_step(num_steps=num_lookahead_slots) + else: + is_decode: bool = not is_prefill + # If it is a prefill, num_lookahead_slots must be 0 + assert num_lookahead_slots == 0 or is_decode + # If it is a decode, num_lookahead_slots + 1 must match + # the scheduler steps. + assert num_lookahead_slots + 1 == num_scheduler_steps or is_prefill + self.init_multi_step(num_steps=num_lookahead_slots + 1) + + def get_last_latency(self, now: float) -> Optional[float]: + """Sets the last token time for Request level timings.""" + # If still in prefill phase, raise Error. + if self.is_prefill(): + raise ValueError( + "seq_group.get_last_latency() should not be called " + "if the seq_group is in prefill phase.") + + # Otherwise return token latency. + latency = now - self.metrics.last_token_time + self.metrics.last_token_time = now + return latency + + def maybe_set_first_token_time(self, time: float) -> None: + """Sets the first token time for Request level timings.""" + # Note: in a case where a sequence_group is swapped and + # recomputed, the time between iterations is counted + # in TPOT, rather than recalculating TTFT (since from the ) + # POV of the user, there is simply a long generation delay. 
+ if (self.metrics.first_token_time is None + and self.seqs[0].get_output_len() == 1): + self.metrics.first_token_time = time + + def maybe_set_first_scheduled_time(self, time: float) -> None: + """Sets the first scheduled time and time in queue for Request + level timings.""" + if self.metrics.first_scheduled_time is None: + self.metrics.first_scheduled_time = time + self.metrics.time_in_queue = time - self.metrics.arrival_time + + def set_finished_time(self, time: Optional[float]) -> None: + """Sets the finished time for Request level timings.""" + self.metrics.finished_time = time + + def get_max_num_running_seqs(self) -> int: + """The maximum number of sequences running in parallel in the remaining + lifetime of the request.""" + if self.sampling_params: + n = self.sampling_params.n + assert isinstance(n, int) + if n > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `n` sequences + # running. + return n + # At sampling stages, return the number of actual sequences + # that are not finished yet. + return self.num_unfinished_seqs() + + def get_seqs( + self, + status: Optional[SequenceStatus] = None, + ) -> List[Sequence]: + if status is None: + return self.seqs + + if self.is_single_seq: + return self.seqs if self.seqs[0].status == status else [] + + return [seq for seq in self.seqs if seq.status == status] + + def is_encoder_decoder(self) -> bool: + return self.encoder_seq is not None + + def get_encoder_seq(self) -> Optional[Sequence]: + return self.encoder_seq + + def get_unfinished_seqs(self) -> List[Sequence]: + if self.is_single_seq: + return self.seqs if not self.seqs[0].is_finished() else [] + + return [seq for seq in self.seqs if not seq.is_finished()] + + def get_finished_seqs(self) -> List[Sequence]: + if self.is_single_seq: + return self.seqs if self.seqs[0].is_finished() else [] + + return [seq for seq in self.seqs if seq.is_finished()] + + def update_num_computed_tokens(self, num_new_computed_tokens: int): + """Update number of tokens computed so far.""" + for seq in self.seqs: + if not seq.is_finished(): + seq.data.update_num_computed_tokens(num_new_computed_tokens) + + def get_num_uncomputed_tokens(self) -> int: + num_uncomputed_tokens = 0 + for seq in self.seqs: + if not seq.is_finished(): + num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() + return num_uncomputed_tokens + + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + # Optimization. We don't need to call get_seqs if we don't need to + # filter by states. 
+ if status is None: + return len(self.seqs) + + if self.is_single_seq: + return 1 if self.seqs[0].status == status else 0 + + return len(self.get_seqs(status)) + + def num_unfinished_seqs(self) -> int: + if self.is_single_seq: + return 1 if not self.seqs[0].is_finished() else 0 + + return len(self.get_unfinished_seqs()) + + def num_finished_seqs(self) -> int: + if self.is_single_seq: + return 1 if self.seqs[0].is_finished() else 0 + + return len(self.get_finished_seqs()) + + def find(self, seq_id: int) -> Sequence: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + return self.seqs_dict[seq_id] + + def add(self, seq: Sequence) -> None: + if seq.seq_id in self.seqs_dict: + raise ValueError(f"Sequence {seq.seq_id} already exists.") + self.seqs_dict[seq.seq_id] = seq + self.seqs.append(seq) + self.is_single_seq = len(self.seqs) == 1 + + def remove(self, seq_id: int) -> None: + seq = self.seqs_dict.pop(seq_id, None) + if seq is None: + raise ValueError(f"Sequence {seq_id} not found.") + self.seqs.remove(seq) + self.is_single_seq = len(self.seqs) == 1 + + def is_finished(self) -> bool: + if self.is_single_seq: + return self.seqs[0].is_finished() + + return all(seq.is_finished() for seq in self.seqs) + + def is_prefill(self) -> bool: + # Every sequence should be in the same stage. + return self.seqs[0].is_prefill() + + def __repr__(self) -> str: + return (f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs)})") + + +class SequenceGroupMetadataDelta( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Delta of SequenceGroupMetadata. + + After sending the first SequenceGroupMetadata, vLLM scheduler + only sends delta to reduce the data payload size. + """ + seq_data_delta: Dict[int, SequenceDataDelta] + request_id: str + block_tables: Dict[int, List[int]] + is_prompt: bool + do_sample: bool = True + token_chunk_size: Optional[int] = None + computed_block_nums: Optional[List[int]] = None + state: Optional[SequenceGroupState] = msgspec.field( + default_factory=lambda: SequenceGroupState()) + + +class SequenceGroupMetadata( + msgspec.Struct, + tag=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """Metadata for a sequence group. Used to create `AttentionMetadata`. + + Args: + request_id: The ID of the request. + is_prompt: Whether the request is at prompt stage. + seq_data: The sequence data. (Seq id -> sequence data) + sampling_params: The sampling parameters used to generate the outputs. + block_tables: The block tables. (Seq id -> list of physical block + numbers) + do_sample: True if sampling is required. Sampling is not required when + e.g., prefill is chunked, and the current iteration only computes + query tokens for prefill, we don't need sampling. + token_chunk_size: The number of tokens to be processed (per sequence). + None if chunking is not required. + lora_request: LoRA request. + computed_block_nums: The block numbers that are already computed, + used in prefix caching. + state: Internal state tied to this sequence group. + multi_modal_data: Multi modal data. + mm_processor_kwargs: Multimodal input processor / mapper overrides. + encoder_seq_data: Optional sequence data for encoder prompt + (SequenceGroup.encoder_seq). Should be None + unless you are working with an encoder/decoder + model. 
+        cross_block_table: Optional cross-attention block table associated
+                           with the encoder prompt
+                           (SequenceGroup.encoder_seq). Should be None
+                           unless you are working with an encoder/decoder
+                           model.
+        prompt_adapter_request: Prompt Adapter request.
+    """
+
+    request_id: str
+    is_prompt: bool
+    seq_data: Dict[int, SequenceData]
+    sampling_params: Optional[SamplingParams]
+    block_tables: Dict[int, List[int]]
+    do_sample: bool = True
+    pooling_params: Optional[PoolingParams] = None
+    lora_request: Optional[LoRARequest] = None
+    computed_block_nums: Optional[List[int]] = None
+    state: Optional[SequenceGroupState] = msgspec.field(
+        default_factory=lambda: SequenceGroupState())
+    # "MultiModalDataDict" types. We have to use Any because msgspec
+    # doesn't allow a union of 2 different dicts.
+    multi_modal_data: Optional[Any] = None
+    mm_processor_kwargs: Optional[Dict[str, Any]] = None
+    encoder_seq_data: Optional[SequenceData] = None
+    cross_block_table: Optional[List[int]] = None
+    prompt_adapter_request: Optional[PromptAdapterRequest] = None
+    token_chunk_size: Optional[int] = None
+
+    ### Stateful fields that are lazily defined. ###
+    # The number of speculative tokens adopted in this request.
+    # None means speculative decoding is not used.
+    # Zero means speculative decoding is disabled for some reason.
+    # TODO: We should maintain this state outside of the sequence group.
+    num_speculative_tokens: Optional[int] = None
+
+    def __post_init__(self):
+        if self.seq_data is not None and self.token_chunk_size is None:
+            if self.is_prompt:
+                self.token_chunk_size = next(iter(
+                    self.seq_data.values())).get_len()
+            else:
+                self.token_chunk_size = 1
+
+    @property
+    def lora_int_id(self) -> int:
+        return self.lora_request.lora_int_id if self.lora_request else 0
+
+    @property
+    def prompt_adapter_id(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_id \
+            if self.prompt_adapter_request else 0
+
+    @property
+    def prompt_adapter_num_virtual_tokens(self) -> int:
+        return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \
+            if self.prompt_adapter_request else 0
+
+    # Multi-Step Chunked-Prefill property
+    @property
+    def is_single_step_prompt(self) -> bool:
+        # do_sample is true only when the token_chunk_size matches the
+        # num_uncomputed_tokens of the sequence. This indicates that
+        # the prompt will finish processing in a single `execute_model`
+        # step.
+        return self.is_prompt and self.do_sample
+
+    def get_first_seq_id(self) -> int:
+        # This is an efficient way of fetching the seq_id when
+        # we know this SequenceGroup has only one sequence.
+ return next(iter(self.seq_data)) + + def apply_delta(self, + sequence_group_metadata_delta: SequenceGroupMetadataDelta): + for id, delta in sequence_group_metadata_delta.seq_data_delta.items(): + self.seq_data[id].apply_delta(delta) + assert self.request_id == sequence_group_metadata_delta.request_id + self.block_tables = sequence_group_metadata_delta.block_tables + self.token_chunk_size = sequence_group_metadata_delta.token_chunk_size + self.do_sample = sequence_group_metadata_delta.do_sample + self.is_prompt = sequence_group_metadata_delta.is_prompt + + def finish_step(self) -> None: + assert self.state is not None + assert self.state.current_step < self.state.num_steps, \ + f"current step {self.state.current_step}, num_steps {self.state.num_steps}" # noqa + self.state.current_step += 1 + + +class SequenceOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """The model output associated with a sequence. + + Args: + parent_seq_id: The ID of the parent sequence (for forking in beam + search). + output_token: The output token ID. + logprobs: The logprobs of the output token. + (Token id -> logP(x_i+1 | x_0, ..., x_i)) + """ + parent_seq_id: int + output_token: int + logprobs: Dict[int, Logprob] + + def __repr__(self) -> str: + return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + f"output_token={self.output_token}, " + f"logprobs={self.logprobs})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SequenceOutput): + raise NotImplementedError() + equal = (self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token) + log_probs_equal = other.logprobs == self.logprobs + return equal and log_probs_equal + + +class SequenceGroupOutput(ABC): + """The base class for model outputs associated with a sequence group.""" + + @abstractmethod + def __repr__(self) -> str: + pass + + @abstractmethod + def __eq__(self, other: object) -> bool: + pass + + +class CompletionSequenceGroupOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + __metaclass__ = SequenceGroupOutput + """The model output associated with a completion sequence group.""" + samples: List[SequenceOutput] + # Prompt logprob for each prompt query token. 
+ prompt_logprobs: Optional[PromptLogprobs] + + def __repr__(self) -> str: + return (f"CompletionSequenceGroupOutput(samples={self.samples}, " + f"prompt_logprobs={self.prompt_logprobs})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CompletionSequenceGroupOutput): + raise NotImplementedError() + return (self.samples == other.samples + and self.prompt_logprobs == other.prompt_logprobs) + + +class EmbeddingSequenceGroupOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True, # type: ignore[call-arg] +): + """The model output associated with an embedding sequence group.""" + __metaclass__ = SequenceGroupOutput + embeddings: List[int] + + def __repr__(self) -> str: + return (f"EmbeddingSequenceGroupOutput(" + f"embeddings_shape={len(self.embeddings)})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, EmbeddingSequenceGroupOutput): + raise NotImplementedError() + return self.embeddings == other.embeddings + + +# cannot use msgspec.Struct here because Dynamo does not support it +@dataclass +class IntermediateTensors: + """For all pipeline stages except the last, we need to return the hidden + states and residuals to be sent to the next stage. This data structure + contains the hidden states and residuals for a request. + """ + + tensors: Dict[str, torch.Tensor] + + def __getitem__(self, key: Union[str, slice]): + if isinstance(key, str): + return self.tensors[key] + elif isinstance(key, slice): + return self.__class__({k: v[key] for k, v in self.tensors.items()}) + + def __setitem__(self, key: str, value): + self.tensors[key] = value + + def __len__(self): + return len(self.tensors) + + def __eq__(self, other: object): + return isinstance(other, self.__class__) and self + + def __repr__(self) -> str: + return f"IntermediateTensors(tensors={self.tensors})" + + +class PoolerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """The output from a pooling operation in the embedding model.""" + outputs: List[EmbeddingSequenceGroupOutput] + + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + +def get_all_seq_ids( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. + """ + return [seq_id for sg in seq_group_metadata_list for seq_id in sg.seq_data] + + +def get_all_seq_ids_and_request_ids( + seq_group_metadata_list: List[SequenceGroupMetadata] +) -> Tuple[List[int], Dict[str, Set[int]]]: + """Given a list of SequenceGroupMetadata, create a list of all + sequence ids. + """ + seq_ids: List[int] = [] + request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) + for sg in seq_group_metadata_list: + for seq_id in sg.seq_data: + seq_ids.append(seq_id) + request_id_seq_ids_mapping[sg.request_id].add(seq_id) + return seq_ids, request_id_seq_ids_mapping + + +class HiddenStates(msgspec.Struct, array_like=True, + omit_defaults=True): # type: ignore[call-arg] + """Hidden states corresponding to in-progress sequences. + Used in speculative decoding to pass hidden states from + the target model to the proposer model. 
+
+    seq_ids are the sequence ids of each entry of the batch
+    dimension of the hidden_states tensor"""
+    # Scorer hidden states. For prefill step, it is used for hidden states of
+    # all tokens, whereas for decode step, it is used for last accepted tokens.
+    hidden_states: torch.Tensor
+    # The sequence group metadata list. Only needed for decode step.
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    # Scorer hidden states of the 2nd last token proposed by the proposer (
+    # irrespective of whether it was accepted or not). Only used for cases when
+    # last proposed token is accepted (i.e., in case of bonus tokens). For the
+    # case of no bonus tokens, these are ignored.
+    second_last_token_hidden_states: Optional[torch.Tensor] = None
+
+    _seq_ids: List[int] = msgspec.field(default_factory=list)
+
+    def __post_init__(self):
+        if self.seq_group_metadata_list is not None:
+            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
+
+    @property
+    def seq_ids(self) -> List[int]:
+        return self._seq_ids
+
+    def update(self,
+               hidden_states: torch.Tensor,
+               seq_group_metadata_list: List[SequenceGroupMetadata],
+               second_last_token_hidden_states: Optional[torch.Tensor] = None):
+        """Update hidden states from target model invocation. Only used for
+        decode steps."""
+        assert len(seq_group_metadata_list) == len(hidden_states)
+        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
+        self.hidden_states = torch.cat([self.hidden_states, hidden_states])
+
+        if self.second_last_token_hidden_states is not None:
+            # Adding dummy hidden_states to this to maintain the same shape
+            self.second_last_token_hidden_states = torch.cat([
+                self.second_last_token_hidden_states,
+                torch.zeros_like(hidden_states)
+                if second_last_token_hidden_states is None else
+                second_last_token_hidden_states
+            ])
+
+    def prune(self,
+              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
+        """Prune to provided list of sequence ids. Only used for decode steps.
+        """
+        # Currently this prunes all seq_ids not present in
+        # seq_group_metadata_list, which might cause problems where a sequence
+        # may be "paused" then "resumed" later. This should only prune sequences
+        # which are confirmed to be aborted.
+        seq_ids = get_all_seq_ids(seq_group_metadata_list)
+        if seq_ids != self._seq_ids:
+            # Batch contents changed - prune removed sequences.
+            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
+            self.hidden_states = self.hidden_states[index]
+            if self.second_last_token_hidden_states is not None:
+                self.second_last_token_hidden_states = self\
+                    .second_last_token_hidden_states[index]
+            self._seq_ids = seq_ids
+
+    def expand_with_bonus_tokens(
+            self, seq_with_bonus_token_in_last_step: set) -> None:
+        """Expand hidden states for sequences with bonus tokens.
This is in + alignment with `MultiStepWorker._expand_execute_model_request`.""" + if self.second_last_token_hidden_states is None \ + or not seq_with_bonus_token_in_last_step: + return + + index = [] + for seq_id in self._seq_ids: + i = self._seq_ids.index(seq_id) + if seq_id in seq_with_bonus_token_in_last_step: + index.append(i + len(self._seq_ids)) + index.append(i) + + self.hidden_states = torch.cat( + [self.hidden_states, self.second_last_token_hidden_states])[index] + + +class ExecuteModelRequest( + msgspec.Struct, + array_like=True, # type: ignore[call-arg] + omit_defaults=True): # type: ignore[call-arg] + """The model execution request, containing CPU metadata only. The LLM + engine should create an instance of this class for each request batch.""" + # The sequence group metadata list. + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, + int]] = msgspec.field(default_factory=list) + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, + int]] = msgspec.field(default_factory=list) + # Blocks to copy. Source to dest block. + blocks_to_copy: List[Tuple[int, int]] = msgspec.field(default_factory=list) + # Virtual engine ID for pipeline parallel. + virtual_engine: int = 0 + # The number of slots for lookahead decoding. + num_lookahead_slots: int = 0 + # The number of requests in the running queue. + running_queue_size: int = 0 + # Optional hidden states from prior step. + previous_hidden_states: Optional[HiddenStates] = None + # The number of forward steps to run. + num_steps: int = 1 + # Finished request ids since last step. + finished_requests_ids: List[str] = msgspec.field(default_factory=list) + # The last sampled token ids for multi step decoding. 
+ last_sampled_token_ids: Optional[torch.Tensor] = None + # Async callback + async_callback: Optional[Callable] = None + + @property + def is_first_multi_step(self) -> bool: + # TODO(will) make this be able to handle batches with variable number of + # steps + assert len(self.seq_group_metadata_list) > 0 + first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None + return first_seq_group.state.current_step == 0 + + @property + def is_last_step(self) -> bool: + # TODO(will) make this be able to handle batches with variable number of + # steps + assert len(self.seq_group_metadata_list) > 0 + first_seq_group = self.seq_group_metadata_list[0] + assert first_seq_group.state is not None + return first_seq_group.state.remaining_steps == 1 + + @property + def current_step(self) -> int: + # TODO(will) make this be able to handle batches with variable number of + # steps + assert len(self.seq_group_metadata_list) > 0 + state = self.seq_group_metadata_list[0].state + assert state is not None + return state.current_step + + def clone( + self, seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]] + ) -> "ExecuteModelRequest": + """Clone the request with a new sequence group metadata list.""" + return ExecuteModelRequest( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=self.blocks_to_swap_in.copy(), + blocks_to_swap_out=self.blocks_to_swap_out.copy(), + blocks_to_copy=self.blocks_to_copy.copy(), + virtual_engine=self.virtual_engine, + num_lookahead_slots=self.num_lookahead_slots, + running_queue_size=self.running_queue_size, + previous_hidden_states=self.previous_hidden_states, + num_steps=self.num_steps, + finished_requests_ids=self.finished_requests_ids, + last_sampled_token_ids=self.last_sampled_token_ids.clone() + if self.last_sampled_token_ids is not None else None, + async_callback=self.async_callback) diff --git a/vllm/spec_decode/__init__.py b/vllm/spec_decode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/spec_decode/__pycache__/__init__.cpython-310.pyc b/vllm/spec_decode/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..213d9b8 Binary files /dev/null and b/vllm/spec_decode/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/batch_expansion.cpython-310.pyc b/vllm/spec_decode/__pycache__/batch_expansion.cpython-310.pyc new file mode 100644 index 0000000..c07646f Binary files /dev/null and b/vllm/spec_decode/__pycache__/batch_expansion.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/draft_model_runner.cpython-310.pyc b/vllm/spec_decode/__pycache__/draft_model_runner.cpython-310.pyc new file mode 100644 index 0000000..a298a18 Binary files /dev/null and b/vllm/spec_decode/__pycache__/draft_model_runner.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/interfaces.cpython-310.pyc b/vllm/spec_decode/__pycache__/interfaces.cpython-310.pyc new file mode 100644 index 0000000..ed50914 Binary files /dev/null and b/vllm/spec_decode/__pycache__/interfaces.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/medusa_worker.cpython-310.pyc b/vllm/spec_decode/__pycache__/medusa_worker.cpython-310.pyc new file mode 100644 index 0000000..d887416 Binary files /dev/null and b/vllm/spec_decode/__pycache__/medusa_worker.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/metrics.cpython-310.pyc b/vllm/spec_decode/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 
0000000..f37fc63 Binary files /dev/null and b/vllm/spec_decode/__pycache__/metrics.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-310.pyc b/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-310.pyc new file mode 100644 index 0000000..da7501c Binary files /dev/null and b/vllm/spec_decode/__pycache__/mlp_speculator_worker.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/mqa_scorer.cpython-310.pyc b/vllm/spec_decode/__pycache__/mqa_scorer.cpython-310.pyc new file mode 100644 index 0000000..344f6a5 Binary files /dev/null and b/vllm/spec_decode/__pycache__/mqa_scorer.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/multi_step_worker.cpython-310.pyc b/vllm/spec_decode/__pycache__/multi_step_worker.cpython-310.pyc new file mode 100644 index 0000000..3eae3f9 Binary files /dev/null and b/vllm/spec_decode/__pycache__/multi_step_worker.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/ngram_worker.cpython-310.pyc b/vllm/spec_decode/__pycache__/ngram_worker.cpython-310.pyc new file mode 100644 index 0000000..0da7d4e Binary files /dev/null and b/vllm/spec_decode/__pycache__/ngram_worker.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-310.pyc b/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-310.pyc new file mode 100644 index 0000000..2b3eb51 Binary files /dev/null and b/vllm/spec_decode/__pycache__/proposer_worker_base.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-310.pyc b/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-310.pyc new file mode 100644 index 0000000..c47e835 Binary files /dev/null and b/vllm/spec_decode/__pycache__/smaller_tp_proposer_worker.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-310.pyc b/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-310.pyc new file mode 100644 index 0000000..a4e914a Binary files /dev/null and b/vllm/spec_decode/__pycache__/spec_decode_worker.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/target_model_runner.cpython-310.pyc b/vllm/spec_decode/__pycache__/target_model_runner.cpython-310.pyc new file mode 100644 index 0000000..af40176 Binary files /dev/null and b/vllm/spec_decode/__pycache__/target_model_runner.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/top1_proposer.cpython-310.pyc b/vllm/spec_decode/__pycache__/top1_proposer.cpython-310.pyc new file mode 100644 index 0000000..0e72f46 Binary files /dev/null and b/vllm/spec_decode/__pycache__/top1_proposer.cpython-310.pyc differ diff --git a/vllm/spec_decode/__pycache__/util.cpython-310.pyc b/vllm/spec_decode/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000..2f06c7f Binary files /dev/null and b/vllm/spec_decode/__pycache__/util.cpython-310.pyc differ diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py new file mode 100644 index 0000000..59e71cc --- /dev/null +++ b/vllm/spec_decode/batch_expansion.py @@ -0,0 +1,455 @@ +from array import array +from itertools import chain, count +from typing import Iterator, List, Optional, Tuple + +import torch + +from vllm import SamplingParams +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE, + ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import 
(SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len + +SeqId = int +TargetSeqId = int +TokenId = int + +DEFAULT_SIMPLE_SAMPLING_PARAMS = SamplingParams() + + +class BatchExpansionTop1Scorer(SpeculativeScorer): + """Implements a speculative scorer that uses batch expansion to get + probabilities of speculative tokens according to the scoring model. + + Batch expansion converts a list of sequences and multiple query positions + to a new batch of sequences, each with a single query position. This allows + for MQA-like scoring in speculative decoding without requiring an MQA + kernel. + + It is strictly less efficient than MQA scoring. + + It only supports scoring the top1 proposal tokens of the proposer, instead + of topk/tree. + """ + + @nvtx_range("BatchExpansionTop1Scorer.score_proposals") + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + """Score the proposed tokens via the scorer model. + + This converts each input sequence to a set of k+1 target sequences. The + target sequences have the unique continuations to be scored and a + unique sequence ID that is different from all input sequence ids. + + If a speculative sequence length would exceed the max model length, then + no speculation is produced for that sequence. + + Args: + execute_model_req: The execution request. + proposals: The speculative proposals to score. + Returns: + SpeculativeScores: The scores of each speculative token, along with + which sequences were ignored during scoring. + """ + + # TODO(cade) perform this on GPU to remove blocking call. + proposal_lens_list = proposals.proposal_lens.tolist() + proposal_token_ids_list = proposals.proposal_token_ids.tolist() + + # Filter the list to ignore invalid proposals. 
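+        # (Note, based on the docstring above: rows containing
+        # VLLM_INVALID_TOKEN_ID mark sequences for which no speculation was
+        # produced, e.g. because the speculation would exceed the max model
+        # length; such rows are dropped here and those sequences fall back to
+        # the non-speculative path.)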
+ proposal_token_ids_list_without_skips = [ + proposals for proposals in proposal_token_ids_list + if VLLM_INVALID_TOKEN_ID not in proposals + ] + + (spec_indices, non_spec_indices, target_seq_group_metadata_list, + num_scoring_tokens) = self._expand_batch( + seq_group_metadata_list=execute_model_req.seq_group_metadata_list, + proposal_token_ids_list=proposal_token_ids_list_without_skips, + proposal_lens_list=proposal_lens_list, + ) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + assert len(target_sampler_output) == 1, "expected single-step output" + target_sampler_output = target_sampler_output[0] + + if not non_spec_indices: + # All sequence groups in batch have spec decoding enabled + contracted = self._contract_batch_all_spec( + target_sampler_output=target_sampler_output, + proposals=proposals, + ) + else: + # Batch has a mix of spec decode enabled and disabled seq groups + contracted = self._contract_batch( + contracted_bs=len(execute_model_req.seq_group_metadata_list), + target_sampler_output=target_sampler_output, + proposals=proposals, + num_scoring_tokens=num_scoring_tokens, + non_spec_indices=non_spec_indices, + spec_indices=spec_indices, + k=execute_model_req.num_lookahead_slots, + ) + + all_tokens, all_probs, spec_logprobs, all_hidden_states = contracted + return SpeculativeScores( + probs=all_probs, + token_ids=all_tokens, + logprobs=spec_logprobs, + hidden_states=all_hidden_states, + ) + + def _expand_batch( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids_list: List[List[TokenId]], + proposal_lens_list: List[int], + ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: + """Given the input sequences and potentially multiple corresponding + proposal tokens, create a new batch where each sequence has a single + query token. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + (spec_seqs, spec_indices), (non_spec_seqs, non_spec_indices) = \ + split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) + + target_seq_group_metadata_list = self._create_scoring_model_input( + seq_group_metadata_list=spec_seqs, + proposal_token_ids=proposal_token_ids_list, + # NOTE: We determine the seq ids in the expanded batch using the + # full seq_group_metadata_list, instead of only spec_seqs. + target_seq_ids_iter=self._create_target_seq_id_iterator( + seq_ids=get_all_seq_ids(seq_group_metadata_list)), + ) + + num_scoring_tokens = len(target_seq_group_metadata_list) + target_seq_group_metadata_list.extend(non_spec_seqs) + + return (spec_indices, non_spec_indices, target_seq_group_metadata_list, + num_scoring_tokens) + + def _contract_batch( + self, contracted_bs: int, target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, num_scoring_tokens: int, + non_spec_indices: List[int], spec_indices: List[int], k: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + + contracted_bs is the original batch size, and the batch size that the + target_sampler_output will be contracted to. 
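+
+        For example (hypothetical sizes): with contracted_bs=3, k=2, two
+        speculative sequences and one non-speculative sequence, the expanded
+        batch seen by the scorer had 2 * (2 + 1) + 1 = 7 rows. The six
+        speculative rows are folded back into shape [2, 3], and the single
+        non-speculative sample fills position 0 of its original sequence.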
+ """ + (target_token_ids, target_probs, target_logprobs, target_hidden_states, + non_spec_target_token_ids, non_spec_target_probs, + non_spec_target_logprobs, + non_spec_target_hidden_states) = self._split_scoring_output( + target_sampler_output, num_scoring_tokens) + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. + expanded_batch_size, k = proposals.proposal_token_ids.shape + + # The number of tokens in the expanded batch used for speculation is + # equal to the total expanded batch size minus the number of samples for + # non-speculative sequences. + non_spec_expanded_bs = len(non_spec_target_token_ids) + spec_expanded_bs = expanded_batch_size - non_spec_expanded_bs + + target_token_ids = target_token_ids.reshape(spec_expanded_bs, k + 1) + target_probs = target_probs.reshape(*target_token_ids.shape, + self._vocab_size) + target_logprobs = target_logprobs.reshape(target_probs.shape) + + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + *target_token_ids.shape, target_hidden_states.shape[-1]) + + all_tokens = target_token_ids.new_full(size=(contracted_bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + + if target_sampler_output.hidden_states is not None: + all_hidden_states = target_hidden_states.new_zeros( + size=(contracted_bs, k + 1, target_hidden_states.shape[-1])) + else: + all_hidden_states = None + + if non_spec_indices: + all_tokens[non_spec_indices, :1] = \ + non_spec_target_token_ids.unsqueeze(1) + all_probs[non_spec_indices, :1, :] = \ + non_spec_target_probs.unsqueeze(1) + all_logprobs[non_spec_indices, :1, :] = \ + non_spec_target_logprobs.unsqueeze(1) + if all_hidden_states is not None: + assert non_spec_target_hidden_states is not None + all_hidden_states[non_spec_indices, :1, :] = \ + non_spec_target_hidden_states.unsqueeze(1) + + if spec_indices: + all_tokens[spec_indices] = target_token_ids + all_probs[spec_indices] = target_probs + all_logprobs[spec_indices] = target_logprobs + if all_hidden_states is not None: + all_hidden_states[spec_indices] = target_hidden_states + + return all_tokens, all_probs, all_logprobs, all_hidden_states + + def _contract_batch_all_spec( + self, + target_sampler_output: SamplerOutput, + proposals: SpeculativeProposals, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor]]: + """Contract the expanded batch back into its original size. + This maps the scores of speculative tokens back to their original + sequences. + + It assumes all sequences in the batch were previously expanded. + """ + + # Map distinct sequences used to score each token + # of shape [batch_size * k + 1] back to [batch_size, k + 1]. 
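+        # E.g., with contracted_bs=2 and k=3 the scorer produced 2 * (3 + 1) = 8
+        # flat samples, which are reshaped below to [2, 4] (with an extra vocab
+        # or hidden dimension for probs, logprobs and hidden states).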
+ contracted_bs, k = proposals.proposal_token_ids.shape + + # Reshape tensors to original batch size + target_token_ids = target_sampler_output.sampled_token_ids.reshape( + contracted_bs, k + 1) + target_probs = target_sampler_output.sampled_token_probs.reshape( + *target_token_ids.shape, self._vocab_size) + target_logprobs = target_sampler_output.logprobs.reshape( + target_probs.shape) + target_hidden_states = target_sampler_output.hidden_states + if target_hidden_states is not None: + target_hidden_states = target_hidden_states.reshape( + *target_token_ids.shape, target_hidden_states.shape[-1]) + + return (target_token_ids, target_probs, target_logprobs, + target_hidden_states) + + def _create_scoring_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given the original input sequences and proposed tokens from the draft + model, create a list of target sequences that can be used for scoring. + + target_seq_ids_iter provides sequence ids for the expanded batch, + fulfilling the requirement that no seq id in the expanded batch is equal + to the seq id in the original batch. + """ + + if not seq_group_metadata_list: + return [] + + target_seq_group_metadata = list( + chain.from_iterable( + self._create_target_seq_group_metadata( + seq_group_metadata, + proposal_token_ids, + i, + target_seq_ids_iter, + ) for i, seq_group_metadata in enumerate( + seq_group_metadata_list))) + + return target_seq_group_metadata + + def _create_target_seq_group_metadata( + self, + input_seq_group_metadata: SequenceGroupMetadata, + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] + batch_index: int, + target_seq_ids_iter: Iterator[TargetSeqId], + ) -> List[SequenceGroupMetadata]: + """Given an input sequence group metadata and a list of draft tokens, + create a list of target SequenceGroupMetadata, one for each + token id that needs to be scored. + + Naive speculative decoding requires K target model scores, one for each + draft model token. However one can add a bonus token such that if each + token is accepted, then a final token may be sampled from the model. + This function creates K+1 target SequenceGroupMetadata to take + advantage of the bonus token. + """ + assert not input_seq_group_metadata.is_prompt, ( + "Speculating on " + "prompts not yet supported") + assert len(input_seq_group_metadata.seq_data) == 1, ( + "Beam search " + "not supported in speculative decoding") + input_seq_id = next(iter(input_seq_group_metadata.seq_data.keys())) + + token_ids_to_score = self._get_token_ids_to_score( + proposal_token_ids[batch_index]) + + # Use simpler sampling parameters apart from for final token + # (in particular don't do seeded sampling) since those sampled tokens + # aren't used. + # We don't replace the sampling_params in the greedy case because + # this also controls whether the probs get modified in the sampler + # (see use of _modify_greedy_probs_inplace there). 
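+        # E.g., a greedy request (temperature == 0) keeps its original
+        # sampling_params at every expanded position, while a sampled request
+        # keeps them only for the final (bonus) position and uses the default
+        # simple params elsewhere.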
+ sampling_params = input_seq_group_metadata.sampling_params + non_bonus_sampling_params = DEFAULT_SIMPLE_SAMPLING_PARAMS \ + if sampling_params.temperature else sampling_params + + target_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + last_index = len(token_ids_to_score) - 1 + for i, token_ids in enumerate(token_ids_to_score): + target_sampling_params = sampling_params if i == last_index \ + else non_bonus_sampling_params + target_seq_group_metadata_list.append( + self._create_single_target_seq_group_metadata( + input_seq_group_metadata, + input_seq_id, + next(target_seq_ids_iter), + token_ids, + sampling_params=target_sampling_params, + )) + + return target_seq_group_metadata_list + + @staticmethod + def _create_single_target_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, + seq_id: SeqId, + target_seq_id: TargetSeqId, + token_ids: List[TokenId], + sampling_params: SamplingParams, + ) -> SequenceGroupMetadata: + """Create a single target SequenceGroupMetadata. + + Args: + seq_group_metadata: The metadata for the input sequence. + seq_id: The input sequence ID. + target_seq_id: The corresponding target sequence ID. + token_ids: The list of token ids that are to be appended to the + input sequence. + """ + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_token_ids = seq_data.prompt_token_ids_array + new_output_token_ids = [*seq_data.get_output_token_ids(), *token_ids] + + new_seq_data_dict = { + target_seq_id: + SequenceData( + prompt_token_ids, + _output_token_ids=array(VLLM_TOKEN_ID_ARRAY_TYPE, + new_output_token_ids), + ), + } + # This is a hack. Technically, spec decoding should compute + # num_lookahead slots at one shot, but instead, it expands the batch + # and evaluate one by one right now. context_len is seq_len - 1 because + # the kv cache is filled by a previous batch in the batch expansion. + for data in new_seq_data_dict.values(): + data.update_num_computed_tokens(data.get_len() - 1) + + return SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + + @staticmethod + def _split_scoring_output( + sampler_output: SamplerOutput, num_scoring_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, torch.Tensor, + torch.Tensor, Optional[torch.Tensor]]: + """Split the target model output into speculative and non-speculative + output. + """ + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + # + # First samples are from speculative scoring, latter samples are non- + # speculative samples. 
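+        # E.g., with 3 speculative sequences, k=4 and 2 non-speculative
+        # sequences, sampled_token_ids holds 3 * (4 + 1) + 2 = 17 samples and
+        # num_scoring_tokens == 15, so split_sizes == (15, 2).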
+ split_sizes = (num_scoring_tokens, + sampler_output.sampled_token_ids.numel() - + num_scoring_tokens) + (spec_probs, non_spec_probs + ) = sampler_output.sampled_token_probs.split(split_sizes) + (spec_sampled_tokens, non_spec_sampled_tokens + ) = sampler_output.sampled_token_ids.flatten().split(split_sizes) + ( + spec_logprobs, + non_spec_logprobs, + ) = sampler_output.logprobs.split(split_sizes) + + if sampler_output.hidden_states is not None: + ( + spec_hidden_states, + non_spec_hidden_states, + ) = sampler_output.hidden_states.split(split_sizes) + else: + spec_hidden_states, non_spec_hidden_states = None, None + + return (spec_sampled_tokens, spec_probs, spec_logprobs, + spec_hidden_states, non_spec_sampled_tokens, non_spec_probs, + non_spec_logprobs, non_spec_hidden_states) + + @staticmethod + def _create_target_seq_id_iterator( + seq_ids: List[SeqId]) -> Iterator[TargetSeqId]: + """Create an iterator for creating target sequence ids. + Target sequence ids are distinct from sequence ids because we create a + distinct target sequence id for each proposal token to be scored. + + This implementation increments a counter starting at 1 + max of all + provided input sequence ids. + """ + return count(start=max(seq_ids) + 1) + + @staticmethod + def _get_token_ids_to_score( + full_spec_token_ids: List[TokenId] # shape: [k] + ) -> List[List[TokenId]]: + """Given an int tensor of proposal token ids, return a list of + token ids that should be scored. + + Returns k+1 output lists. The additional one is used for generating the + bonus token. + + Example: + Input: [0, 1, 2, 3] (k=4) + Output: (k+1 lists) + [] + [0] + [0, 1] + [0, 1, 2] + [0, 1, 2, 3] + """ + empty_token_ids: List[TokenId] = [] + + token_ids_to_score = [empty_token_ids] + token_ids_to_score.extend(full_spec_token_ids[:i + 1] + for i in range(len(full_spec_token_ids))) + return token_ids_to_score diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py new file mode 100644 index 0000000..aaf6ec5 --- /dev/null +++ b/vllm/spec_decode/draft_model_runner.py @@ -0,0 +1,327 @@ +from typing import List, Optional + +import torch + +from vllm.forward_context import set_forward_context +from vllm.model_executor.layers.sampler import SamplerOutput + +try: + try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata + except (ModuleNotFoundError, ImportError): + # vllm_flash_attn is not installed, try the ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) +except (ModuleNotFoundError, ImportError) as err: + raise RuntimeError( + "Draft model speculative decoding currently only supports" + "CUDA and ROCm flash attention backend.") from err + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalInputs +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, + ModelRunner) + +logger = init_logger(__name__) + +# A flag to enable debug prints for the updated input tensors +# before each step. +debug_advance_input = False +# A flag to allow GPU advance step for draft model runner. +# Set to False for debugging. 
+allow_gpu_advance_step = True + + +class TP1DraftModelRunner(ModelRunner): + """Specialized model runner for speculative decoding draft model. + Since the draft model always execute k forward passes consecutively to + generate k speculative tokens in a single speculative decoding step, + we could get rid of most CPU-GPU synchronization and data transfer + overheads by keeping model input and output tensors on GPU all the time. + + TODOs: + 1. Currently supports only flash-attn, add support for other attn_backends. + 2. Support TP > 1 (this requires some designs because we do not expect + any broadcasting inside execute_model). + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + ): + if return_hidden_states: + raise ValueError( + "return_hidden_states is not supported for TP1DraftModelRunner." + ) + + super().__init__( + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + cache_config=cache_config, + load_config=load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + return_hidden_states=return_hidden_states, + observability_config=observability_config, + ) + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + + def _gpu_advance_step( + self, model_input: ModelInputForGPUWithSamplingMetadata, + last_output: SamplerOutput + ) -> ModelInputForGPUWithSamplingMetadata: + # Currently, we expect "decode mode" only + assert not model_input.is_prompt + + # Get num_seqs + num_seqs = len(model_input.seq_lens) + num_queries = len(model_input.query_lens) + + # Get output tokens GPU tensor + sampled_token_ids = last_output.sampled_token_ids + assert sampled_token_ids is not None + + # Update attn_metadata + attn_metadata = model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + + attn_metadata.advance_step(model_input, sampled_token_ids, + self.block_size, num_seqs, num_queries) + + # Update sampling_metadata + sampling_metadata = model_input.sampling_metadata + self._update_sampling_metadata(sampling_metadata, num_seqs, + num_queries) + + # Create new input + new_model_input = self._model_input_cls( + input_tokens=model_input.input_tokens, + input_positions=model_input.input_positions, + attn_metadata=attn_metadata, + seq_lens=attn_metadata.seq_lens, + query_lens=model_input.query_lens, + lora_mapping=model_input.lora_mapping, + lora_requests=model_input.lora_requests, + 
            multi_modal_kwargs=model_input.multi_modal_kwargs,
+            sampling_metadata=model_input.sampling_metadata,
+            is_prompt=False,
+        )
+
+        # Ensure we skip CPU samples
+        assert new_model_input.sampling_metadata.skip_sampler_cpu_output is True
+        # We can reuse sampling tensors since every decode iteration is the same
+        new_model_input.sampling_metadata.reuse_sampling_tensors = True
+
+        if debug_advance_input:
+            logger.debug("NEW INPUT: ")
+            logger.debug("  input_tokens = %s", new_model_input.input_tokens)
+            logger.debug("  input_positions = %s",
+                         new_model_input.input_positions)
+            logger.debug("  seq_lens = %d", new_model_input.seq_lens)
+            logger.debug("  query_lens = %d", new_model_input.query_lens)
+            logger.debug("  attn_metadata:")
+            logger.debug("    seq_lens_tensor: %s",
+                         attn_metadata.seq_lens_tensor)
+            logger.debug("    slot_mapping: %s", attn_metadata.slot_mapping)
+            logger.debug("    block_tables: %s", attn_metadata.block_tables)
+
+        return new_model_input
+
+    def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
+        """Determines if draft_model_runner GPU multi-step can be used.
+        Currently required conditions are:
+            1. Only decodes
+            2. Only flash-attn
+            3. No LORA
+            4. No prompt_adapter_config
+        """
+        if not allow_gpu_advance_step:
+            return False
+
+        # We allow multi-step GPU only in decode mode
+        for seq_group in execute_model_req.seq_group_metadata_list:
+            if seq_group.is_prompt:
+                return False
+
+        # TODO: Add support for other attn backends
+        if self.attn_backend.get_name() != "flash-attn":
+            return False
+
+        # TODO: Add support for LORA
+        if self.lora_config:
+            return False
+
+        # TODO: Add soft-tuning prompt adapter support
+        return not self.prompt_adapter_config
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        previous_hidden_states: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[List[SamplerOutput]]:
+        """Executes num_steps forward passes with advancement of input tensors
+        on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
+
+        Optimizations used:
+            1. Input tensors are updated on the GPU directly
+            2. Skips GPU=>CPU serialization of sampler outputs (we don't need
+                them since we do batch expansion later that uses GPU outputs)
+            3. Reuses sampling tensors (since we run only decodes and they have
+                a repeating sampling logic)
+        """
+
+        # When num_steps == 1, we execute the fallback here for the GPU
+        # advance_step, which runs prepare_inputs on CPU and for each spec
+        # iteration invokes this function only once
+        # (Look at multi-step-worker code)
+        is_fallback = num_steps == 1
+        if not is_fallback:
+            # Since we do not broadcast data inside execute_model anymore,
+            # we need to figure out the best way to support TP > 1 in this
+            # case, because we will at least need to broadcast the sampled
+            # tokens to all workers.
+ if not self.is_driver_worker: + raise ValueError("TP1DraftModelRunner only supports TP=1.") + + # Sanity + if self.lora_config is not None: + raise ValueError("TP1DraftModelRunner has no support for LORA") + if self.prompt_adapter_config is not None: + raise ValueError("TP1DraftModelRunner has no support for " + "prompt_adapter_config") + if model_input.multi_modal_kwargs: + raise ValueError( + "TP1DraftModelRunner has no support for multi_modal_kwargs" + ) + else: + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + self.attn_state.begin_forward(model_input) + + # Detect exec mode + assert model_input.attn_metadata is not None + use_cuda_graph = False + if model_input.attn_metadata.num_prefills > 0: + # In this case, execute_model(..) was called directly + if num_steps > 1: + raise ValueError( + "execute_model(..) of draft_model_runner can be called " + "directly only with a single-step prefill") + else: + # We can skip CPU samples for spec token generation. + # (We do allow CPU samples for num_steps == 1 to support the + # fallback case, where supports_gpu_multi_step(..) does not pass) + model_input.sampling_metadata.skip_sampler_cpu_output = ( + not is_fallback) + + # Attn attr defines if we use cuda graphs + use_cuda_graph = model_input.attn_metadata.use_cuda_graph + + # Get model + if use_cuda_graph: + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = (self.graph_runners[model_input.virtual_engine] + [graph_batch_size]) + + if previous_hidden_states is not None: + hidden_states = torch.cat([ + previous_hidden_states, + torch.empty([ + graph_batch_size - previous_hidden_states.shape[0], + *previous_hidden_states.shape[1:] + ], + dtype=previous_hidden_states.dtype, + device=previous_hidden_states.device) + ]) + else: + hidden_states = None + else: + model_executable = self.model + hidden_states = previous_hidden_states + + outputs: List[SamplerOutput] = [] + for step in range(num_steps): + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + + kwargs = {"previous_hidden_states": hidden_states} \ + if previous_hidden_states is not None else {} + + # Run model + with set_forward_context(model_input.attn_metadata): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **kwargs, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Sample the next token. 
+ outputs.append( + self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + )) + + # Prepare inputs for the next step + if step != num_steps - 1: + model_input = self._gpu_advance_step(model_input, outputs[-1]) + + return outputs diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py new file mode 100644 index 0000000..029f564 --- /dev/null +++ b/vllm/spec_decode/interfaces.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Set + +import torch + +from vllm.sequence import ExecuteModelRequest +from vllm.worker.worker_base import WorkerBase + + +@dataclass +class SpeculativeProposals: + """Datastructure used to represent proposal tokens from some proposer. It + also tracks how many speculative tokens each sequence has. + """ + + # Speculative proposal tokens. + proposal_token_ids: torch.Tensor + + # Probabilities of the proposal tokens according to the proposer. + proposal_probs: torch.Tensor + + # The valid length of each proposal; can be zero. + proposal_lens: torch.Tensor + + # A flag to mark that there's no available proposals + no_proposals: bool = False + + def __repr__(self): + return (f"SpeculativeProposals(" + f"proposal_token_ids={self.proposal_token_ids}, " + f"proposal_probs={self.proposal_probs.shape}, " + f"proposal_lens={self.proposal_lens})") + + +@dataclass +class SpeculativeScores: + """Datastructure used to represent the scores of speculative tokens + according to the scoring model. + """ + + # Probabilities of the speculative tokens according to the scoring model. + probs: torch.Tensor + + # Log-probabilities of the speculative tokens according to the scoring + # model. These values can be used to generate Logprob objects that are + # returned to the user. + logprobs: torch.Tensor + + # Token ids sampled from the scoring model. Used for speculative bonus + # tokens and also non-speculative normal decoding. + token_ids: torch.Tensor + + # Optional last hidden states from the scoring model. + hidden_states: Optional[torch.Tensor] = None + + def __repr__(self): + return (f"SpeculativeScores(" + f"probs={self.probs.shape}, " + f"token_ids={self.token_ids.shape})") + + +class SpeculativeProposer(ABC): + + @abstractmethod + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + # If set, this contains all sequence IDs that were assigned + # bonus tokens in their last forward pass. 
+ seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + raise NotImplementedError + + +class SpeculativeScorer(ABC): + + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): + self._scorer_worker = scorer_worker + self._device = device + self._vocab_size = vocab_size + + @abstractmethod + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + raise NotImplementedError diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py new file mode 100644 index 0000000..0d233f3 --- /dev/null +++ b/vllm/spec_decode/medusa_worker.py @@ -0,0 +1,136 @@ +import weakref +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker import Worker + + +class MedusaWorker(NonLLMProposerWorkerBase, Worker): + """Worker for Medusa. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Lazy initialization list. + self._proposer: Top1Proposer + + def init_device(self): + super().init_device() + + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self): + pass + + def set_should_modify_greedy_probs_inplace(self): + pass + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. + Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For medusa worker, this indicator shall be False. + """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + seq_lens, query_lens = self._prepare_input_tensors( + seq_group_metadata_list) + + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory, generators) + + model_outputs = self.model_runner.model.generate_proposals( + previous_hidden_states=execute_model_req.previous_hidden_states. 
+ hidden_states, + sampling_metadata=sampling_metadata) + + return model_outputs, False + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[List[int], List[int]]: + if not seq_group_metadata_list: + return [], [] + + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + seq_lens.append(seq_len) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + query_lens.append(1) + + return seq_lens, query_lens + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MedusaWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MedusaWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MedusaWorker does not support beam search.") diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py new file mode 100644 index 0000000..89ccaba --- /dev/null +++ b/vllm/spec_decode/metrics.py @@ -0,0 +1,196 @@ +import time +from typing import Callable, Optional + +import msgspec +import torch + +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler) +from vllm.utils import is_pin_memory_available + + +class SpecDecodeWorkerMetrics( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """Dataclass holding metrics emitted from the spec decode worker. + """ + + # The empirical acceptance rate of the proposal method on a per-token basis. + # This is useful for evaluating how well the proposal method aligns with the + # scoring method. + draft_acceptance_rate: float + + # The empirical efficiency, measured as the number of tokens emitted by the + # system divided by the number of tokens that could be emitted by the system + # if the proposal method were perfect. + system_efficiency: float + + # The number of speculative tokens produced by the proposal method. + draft_tokens: int + + # The number of tokens emitted by the entire system. + emitted_tokens: int + + # The number of tokens accepted by the scoring model and verification + # routine, e.g. Llama2-70B and lossless rejection sampling. + # + # NOTE: Any token accepted by the verification routine is considered + # accepted (regardless of if the speculative prefix is also accepted). The + # user will usually see less accepted tokens. This metric is helpful when + # evaluating alignment of the proposal method with the scoring model. 
+ accepted_tokens: int + + # The number of speculative tokens per sequence. + num_spec_tokens: int + + +Timer = Callable[[], float] + + +class AsyncMetricsCollector: + """Class which copies rejection/typical-acceptance sampler metrics + from the device to CPU on a non-default Torch stream. + """ + + def __init__(self, + spec_decode_sampler: SpecDecodeBaseSampler, + timer: Optional[Timer] = None, + collect_interval_s: float = 5.0): + self.spec_decode_sampler = spec_decode_sampler + self._timer = time.time if timer is None else timer + + self._rank: Optional[int] = None + + # We don't have a device set yet. + self._copy_stream: Optional[torch.cuda.Stream] = None + + self._in_flight_copy: Optional[torch.cuda.Event] = None + + pin_memory = is_pin_memory_available() + self._aggregate_num_accepted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_emitted_tokens = torch.tensor( + 0, dtype=torch.long, device="cpu", pin_memory=pin_memory) + self._aggregate_num_draft_tokens = 0 + + self._rejsample_metrics_collect_interval_s = collect_interval_s + self._last_metrics_collect_time = self._timer() + + def init_gpu_tensors(self, rank: int) -> None: + self._rank = rank + self._copy_stream = torch.cuda.Stream() + + def maybe_collect_rejsample_metrics( + self, k: int) -> Optional[SpecDecodeWorkerMetrics]: + + # If a copy was initiated in the previous call, collect and return. + if self._in_flight_copy is not None: + ready_event = self._in_flight_copy + self._in_flight_copy = None + return self._collect_rejsample_metrics(k, ready_event) + + # Otherwise, check if we should start a new copy. + if self._should_collect_rejsample_metrics(self._timer()): + assert self._in_flight_copy is None + self._in_flight_copy = self._copy_rejsample_metrics_async() + + return None + + def _should_collect_rejsample_metrics(self, now: float) -> bool: + """Return whether or not this iteration should print sampling + metrics. + """ + if self._rank != 0: + return False + + return now - self._last_metrics_collect_time >= self._rejsample_metrics_collect_interval_s # noqa: E501 + + def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: + """Copy rejection/typical-acceptance sampling metrics + (number of accepted tokens, etc) to CPU asynchronously. + + Returns a CUDA event recording when the copy is complete. + """ + assert self._copy_stream is not None + self._copy_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self._copy_stream): + self._aggregate_num_accepted_tokens.copy_( + self.spec_decode_sampler.num_accepted_tokens, + non_blocking=True) + self._aggregate_num_emitted_tokens.copy_( + self.spec_decode_sampler.num_emitted_tokens, non_blocking=True) + # Number of draft tokens is calculated on CPU, so no copy is + # required. + self._aggregate_num_draft_tokens = ( + self.spec_decode_sampler.num_draft_tokens) + + aggregate_metrics_ready = torch.cuda.Event() + aggregate_metrics_ready.record(self._copy_stream) + + return aggregate_metrics_ready + + def _collect_rejsample_metrics( + self, k: int, + ready_event: torch.cuda.Event) -> SpecDecodeWorkerMetrics: + """Create metrics object from statistics copied asynchronously. + + Args: + k: int. The number of speculative tokens; used to determine system + efficiency. + ready_event: torch.cuda.Event. The CUDA event recording when the + async GPU->CPU copy is complete. 
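+
+        Returns:
+            SpecDecodeWorkerMetrics where draft_acceptance_rate is
+            accepted_tokens / draft_tokens and system_efficiency is
+            emitted_tokens / max_num_emitted_tokens; each falls back to NaN
+            when its denominator is zero.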
+ """ + + ready_event.synchronize() + + # update time of last collection + self._last_metrics_collect_time = self._timer() + + accepted_tokens = self._aggregate_num_accepted_tokens.item() + emitted_tokens = self._aggregate_num_emitted_tokens.item() + draft_tokens = self._aggregate_num_draft_tokens + + max_num_emitted_tokens = self.get_max_num_emitted_tokens( + draft_tokens, k) + + if draft_tokens > 0: + draft_acceptance_rate = accepted_tokens / draft_tokens + else: + draft_acceptance_rate = float("nan") + + if max_num_emitted_tokens > 0: + system_efficiency = emitted_tokens / max_num_emitted_tokens + else: + system_efficiency = float("nan") + + return SpecDecodeWorkerMetrics( + num_spec_tokens=k, + draft_acceptance_rate=draft_acceptance_rate, + system_efficiency=system_efficiency, + accepted_tokens=accepted_tokens, + draft_tokens=draft_tokens, + emitted_tokens=emitted_tokens, + ) + + @staticmethod + def get_max_num_emitted_tokens(draft_tokens: int, k: int) -> int: + """Calculate the number of emitted tokens, assuming all tokens are + accepted. + + This is equal to the number of sequences that have been speculated on, + times (speculation len + 1). The +1 comes from the bonus token. + """ + # Determine the number of sequences that have been speculated on. Since + # the batch size can be variable, we divide by k. + assert draft_tokens % k == 0 + total_num_spec_seqs = draft_tokens // k + + # A single sequence may emit k accepted tokens and one bonus token in + # the best case. + num_emitted_per_seq_if_all_accepted = k + 1 + + # The max num of emitted tokens is the number of speculated sequences + # times the max emitted per seq. + return total_num_spec_seqs * num_emitted_per_seq_if_all_accepted diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py new file mode 100644 index 0000000..fc41bb8 --- /dev/null +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -0,0 +1,91 @@ +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase + + +class MLPSpeculatorWorker(NonLLMProposerWorkerBase, MultiStepWorker): + """Worker for MLPSpeculator models. + + Not currently compatible with LoRA or chunked prefill. + """ + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # Unused parameter. MLPSpeculatorWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass to generate sample_len future tokens. + Returns the list of sampler output, one per layer, along with indicator + of whether torch tensor in sampler output need to be transposed in + latter sampler_output_to_torch logic. + + For mlp spec worker, this indicator shall be True. 
+ """ + self._raise_if_unsupported(execute_model_req) + + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + (input_tokens, seq_lens, + query_lens) = self._prepare_input_tensors(seq_group_metadata_list) + + generators = self.model_runner.get_generators( + execute_model_req.finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, seq_lens, query_lens, self.device, + self.model_runner.pin_memory, generators) + + model_outputs = self.model_runner.model.generate_proposals( + input_ids=input_tokens, + previous_hidden_states=execute_model_req.previous_hidden_states. + hidden_states, + num_predict_tokens=sample_len, + sampling_metadata=sampling_metadata) + + assert len(model_outputs) == sample_len + + return model_outputs, True + + def _prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, List[int], List[int]]: + if not seq_group_metadata_list: + return torch.empty(0, device=self.device), [], [] + + input_tokens: List[int] = [] + seq_lens: List[int] = [] + query_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + is_prompt = seq_group_metadata.is_prompt + + for seq_data in seq_group_metadata.seq_data.values(): + seq_data_len = seq_data.get_len() + if is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min( + seq_data_len, + context_len + seq_group_metadata.token_chunk_size) + tokens = seq_data.get_token_ids()[context_len:seq_len] + seq_lens.append(seq_len) + input_tokens.extend(tokens) + query_lens.append(seq_len - context_len) + else: + seq_lens.append(seq_data_len) + input_tokens.append(seq_data.get_last_token_id()) + query_lens.append(1) + + input_tokens_tensor = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + return input_tokens_tensor, seq_lens, query_lens diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py new file mode 100644 index 0000000..f35a8a0 --- /dev/null +++ b/vllm/spec_decode/mqa_scorer.py @@ -0,0 +1,106 @@ +from vllm.sequence import (ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) + +SeqId = int +TargetSeqId = int + + +class MQAScorer(SpeculativeScorer): + + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + target_seq_group_metadata_list = [] + target_seq_id_start = max( + get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 + all_proposal_tokens = proposals.proposal_token_ids.tolist() + all_proposal_lengths = proposals.proposal_lens.tolist() + for i, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data_dict = seq_group_metadata.seq_data + assert len(seq_data_dict) == 1 + seq_id = next(iter(seq_data_dict.keys())) + + seq_data: SequenceData = seq_data_dict[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + output_token_ids = seq_data.get_output_token_ids() + proposal_token_ids = all_proposal_tokens[ + i][:all_proposal_lengths[i]] + new_output_token_ids = [*output_token_ids, *proposal_token_ids] + + target_seq_id = target_seq_id_start + i + new_seq_data = SequenceData.from_seqs( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) + + # Ensure that the new 
sequence has at least one token + # because we only use mqa scorer in the decoding stage. + assert len(output_token_ids) >= 1 + new_seq_data_dict = {target_seq_id: new_seq_data} + + new_seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + target_seq_group_metadata_list.append(new_seq_group_metadata) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + + target_sampler_output = target_sampler_output[0] + + k = execute_model_req.num_lookahead_slots + bs = len(execute_model_req.seq_group_metadata_list) + target_token_ids = target_sampler_output.sampled_token_ids + target_probs = target_sampler_output.sampled_token_probs + target_logprobs = target_sampler_output.logprobs + # If all requests have the same number of query tokens, we can avoid + # the for loop to build output for better performance. + if min(all_proposal_lengths) == k: + bs, _ = proposals.proposal_token_ids.shape + all_tokens = target_token_ids.reshape(bs, k + 1) + all_probs = target_probs.reshape(bs, k + 1, self._vocab_size) + all_logprobs = target_logprobs.reshape(bs, k + 1, self._vocab_size) + else: + all_tokens = target_token_ids.new_full(size=(bs, k + 1), + fill_value=-1) + all_probs = target_probs.new_zeros(*all_tokens.shape, + self._vocab_size) + all_logprobs = target_logprobs.new_full(size=all_probs.shape, + fill_value=-float("inf")) + target_token_ids = target_token_ids.flatten() + start_loc = 0 + for i, proposed_len in enumerate(all_proposal_lengths): + output_len = proposed_len + 1 + end_loc = start_loc + output_len + all_tokens[ + i, :output_len] = target_token_ids[start_loc:end_loc] + all_probs[i, :output_len] = target_probs[start_loc:end_loc] + all_logprobs[ + i, :output_len] = target_logprobs[start_loc:end_loc] + start_loc = end_loc + + hidden_states = None + if target_sampler_output.hidden_states is not None: + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) + + return SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=hidden_states) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py new file mode 100644 index 0000000..4b53fbe --- /dev/null +++ b/vllm/spec_decode/multi_step_worker.py @@ -0,0 +1,366 @@ +import copy +import weakref +from typing import Dict, List, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, + SequenceGroupMetadata) +from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer +from vllm.worker.worker import Worker + + +class MultiStepWorker(Worker, ProposerWorkerBase): + """The MultiStepWorker is equivalent to a Worker except that it allows + multiple forward passes in a single call, assuming the scheduler has + allocated enough space to store the additional KV. This reduces overhead + by invoking the scheduler less. 
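+
+    For example (illustrative): with sample_len=3, the draft model performs
+    three forward passes, each conditioned on the token sampled in the
+    previous pass, and sampler_output() returns three SamplerOutputs.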
+ + The MultiStepWorker does not support cache swap operations, or beam search. + Cache swap operations do not require large modifications. On the other hand, + beam search requires memory allocations during sequence forks and thus + requires more thought for MultiStepWorker support. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Lazy initialization list. + self._proposer: SpeculativeProposer + + def init_device(self) -> None: + super().init_device() + + self._proposer = Top1Proposer( + weakref.proxy(self), # type: ignore[arg-type] + self.device, + self.vocab_size, + max_proposal_len=self.max_model_len, + ) + + def set_include_gpu_probs_tensor(self) -> None: + # Need include_gpu_probs_tensor for MultiStepWorker + self.model_runner.model.sampler.include_gpu_probs_tensor = True + + def set_should_modify_greedy_probs_inplace(self) -> None: + self.model_runner.model.sampler.should_modify_greedy_probs_inplace = ( + True) + + @torch.inference_mode() + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + """Run the model forward pass sample_len times. Returns the list of + sampler output, one per model forward pass, along with indicator of + whether torch tensor in sampler output need to be transposed in latter + sampler_output_to_torch logic. + + For multi step worker, this indicator shall be True. + """ + self._raise_if_unsupported(execute_model_req) + # Expand the batch for sequences with a bonus token. + # Perform a forward pass on the expanded batch and filter the + # response to retain only the original sequences' responses. + expanded_request, indices_of_seq_with_bonus_tokens =\ + self._expand_execute_model_request( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + # Run model sample_len times. + model_outputs: List[SamplerOutput] = [] + if isinstance( + self.model_runner, TP1DraftModelRunner + ) and self.model_runner.supports_gpu_multi_step(expanded_request): + # Here we run the draft_model_runner with multi-step prepare + # on the GPU directly + expanded_request.num_steps = sample_len + model_outputs = self.execute_model( + execute_model_req=expanded_request) + else: + # Here we run multi-step directly, with every step prepared + # on the CPU. + # TODO: Remove this branch once DraftModelRunner supports TP>1 + # and other restrictions that are part of DraftModelRunner's + # supports_gpu_multi_step(..) + for _ in range(sample_len): + model_output: List[SamplerOutput] = super().execute_model( + execute_model_req=expanded_request) + assert (len(model_output) == 1 + ), "composing multistep workers not supported" + model_output = model_output[0] + + self._append_new_tokens( + model_output, expanded_request.seq_group_metadata_list) + model_outputs.append(model_output) + + filtered_model_outputs = self._filter_model_output( + model_outputs, indices_of_seq_with_bonus_tokens) + return filtered_model_outputs, True + + @staticmethod + def _expand_execute_model_request( + execute_model_req: ExecuteModelRequest, + seq_with_bonus_token_in_last_step: set, + ) -> Tuple[ExecuteModelRequest, List[int]]: + """ + Expands the execute model request based on sequences with bonus + tokens. + + For each sequence with a bonus token, this method creates a new + sequence without the bonus token and adds it to the execute model + request. The original sequence groups are also retained. 
The indices + of the original sequence groups are returned for further processing. + + Args: + execute_model_req (ExecuteModelRequest): The original execute + model request. + seq_with_bonus_token_in_last_step (set): Set of sequence IDs that + contain bonus tokens. + + Returns: + Tuple[ExecuteModelRequest, List[int]]: The updated execute model + request with expanded sequences and a list of indices corresponding + to the original sequence groups. + """ + updated_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + updated_execute_model_req = execute_model_req.clone( + updated_seq_group_metadata_list) + indices_of_original_sequence_groups = [] + for seq_group in execute_model_req.seq_group_metadata_list: + seq_group_has_bonus_tokens = False + for seq_id, _ in seq_group.seq_data.items(): + # Identify sequences with bonus tokens in the sequence group. + if seq_id in seq_with_bonus_token_in_last_step: + seq_group_has_bonus_tokens = True + break + if seq_group_has_bonus_tokens: + #Create new sequences without the last bonus token. These new + # sequence have the same sequence id as the original sequence. + # We create a new sequence group and add them there. + updated_seq_group_without_bonus_token = \ + MultiStepWorker._copy_seq_metadata_excluding_last_token( + seq_group, seq_with_bonus_token_in_last_step) + updated_seq_group_metadata_list.append( + updated_seq_group_without_bonus_token) + # Add the original sequence group. + updated_seq_group_metadata_list.append( + MultiStepWorker._shallow_copy_seq_group_metadata(seq_group)) + # Record the index of the original sequence group. + indices_of_original_sequence_groups.append( + len(updated_seq_group_metadata_list) - 1) + + updated_execute_model_req.seq_group_metadata_list =\ + updated_seq_group_metadata_list + + if isinstance(updated_execute_model_req.previous_hidden_states, + HiddenStates): + updated_execute_model_req.previous_hidden_states\ + .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step) + + return updated_execute_model_req, indices_of_original_sequence_groups + + @staticmethod + def _filter_model_output( + expanded_batch_outputs: List[SamplerOutput], + output_indices_to_retain: List[int]) -> List[SamplerOutput]: + """ + Filters the model output to include only the specified sequence + outputs. This method contracts the expanded batch output from the + model to retain the outputs of only those sequences indicated by the + provided indices. + + Args: + expanded_batch_output (List[SamplerOutput]): The expanded output + batch from the model. + output_indices_to_retain (List[int]): Indices of the model outputs + to retain. + + Returns: + List[SamplerOutput]: A list containing the filtered model + outputs for the specified indices. + """ + return [ + SamplerOutput( + outputs=[ + expanded_batch_output.outputs[i] + for i in output_indices_to_retain + ] if len(expanded_batch_output.outputs) > 0 else [], + sampled_token_probs=( + expanded_batch_output. + sampled_token_probs[output_indices_to_retain] + if expanded_batch_output.sampled_token_probs is not None + else None), + logprobs=( + expanded_batch_output.logprobs[output_indices_to_retain] + if expanded_batch_output.logprobs is not None else None), + sampled_token_ids=(expanded_batch_output. 
+ sampled_token_ids[output_indices_to_retain] + if expanded_batch_output.sampled_token_ids + is not None else None)) + for expanded_batch_output in expanded_batch_outputs + ] + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: set, + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + @staticmethod + def _append_new_tokens( + model_output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done outside of the worker, but it is + required if the worker is to perform multiple forward passes. + """ + for seq_group_metadata, sequence_group_outputs in zip( + seq_group_metadata_list, model_output): + seq_group_metadata.is_prompt = False + + for seq_output in sequence_group_outputs.samples: + # NOTE: Beam search is not supported, so we can assume that + # parent_seq_id == seq_id. + seq = seq_group_metadata.seq_data[seq_output.parent_seq_id] + + token_id = seq_output.output_token + token_logprob = seq_output.logprobs[token_id] + + seq.append_token_id(token_id, token_logprob.logprob) + seq.update_num_computed_tokens(1) + + @staticmethod + def _shallow_copy_seq_group_metadata( + seq_group_metadata: SequenceGroupMetadata, ) -> SequenceGroupMetadata: + """Copy input data structures to remove side-effects when input data + structures are shared with other modules. + + Helpful when the vLLM scheduler runs in the same process as the worker. + The alternative is deep-copying (or other form of deep copy); this has + performance downsides. + """ + # Shallow-copy the SequenceGroupMetadata. This allows us to + # append tokens and change is_prompt without external side-effects. + # We must shallow-copy seq_group_metadata as is_prompt could change. + new_seq_group_metadata = copy.copy(seq_group_metadata) + + # We must shallow-copy seq_data as we will append token ids + new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + new_seq_data[seq_id] = copy.copy(old_seq_data) + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:] + + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata + + @staticmethod + def _copy_seq_metadata_excluding_last_token( + seq_group_metadata: SequenceGroupMetadata, + seq_ids_to_copy: Set[int], + ) -> SequenceGroupMetadata: + """ + Creates a shallow copy of the given SequenceGroupMetadata, retaining + only the sequence IDs specified in seq_ids_to_copy. For each of these + sequence IDs, all output_token_ids except the last one are copied. + Sequence IDs not in seq_ids_to_copy are excluded from the copy. + + Parameters: + seq_group_metadata (SequenceGroupMetadata): The original sequence + group metadata. + seq_ids_to_copy (Set[int]): The set of sequence IDs to include in the + copy. + + Returns: + SequenceGroupMetadata: A shallow copy of the sequence group metadata + with the specified modifications. + """ + # Shallow-copy the SequenceGroupMetadata. + new_seq_group_metadata = copy.copy(seq_group_metadata) + # Shallow-copy seq_data and modify the output_token_ids. 
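+        # Illustrative example (made-up values): if a copied sequence's
+        # output_token_ids are [11, 13, 17] and 17 was the bonus token from
+        # the last step, the copy keeps [11, 13] and num_computed_tokens is
+        # decremented by one to stay consistent with the shortened sequence.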
+ new_seq_data: Dict[int, SequenceData] = {} + for seq_id, old_seq_data in seq_group_metadata.seq_data.items(): + if (seq_id in seq_ids_to_copy): + new_seq_data[seq_id] = copy.copy(old_seq_data) + # Copy all the output token ids except the last. + # Also reduce num_computed_tokens by 1 since we are not + # including the last output token. + # NOTE: num_computed_tokens is not directly used by the + # speculative decoding workers, as it is only relevant for + # chunked prefill, which is disabled for speculative decoding. + # However, to maintain consistency in num_computed_tokens, + # we update it here. + new_seq_data[seq_id].output_token_ids =\ + old_seq_data.output_token_ids[:-1] + new_seq_data[seq_id].update_num_computed_tokens(-1) + new_seq_group_metadata.seq_data = new_seq_data + return new_seq_group_metadata + + def _assert_enough_kv_space( + self, seq_group_metadata_list: List[SequenceGroupMetadata], + num_steps: int) -> None: + """Assert there are enough physical blocks per sequence to store the + current KV plus additional KV from num_steps tokens. + """ + assert self.model_runner.block_size is not None + for seq_group_metadata in seq_group_metadata_list: + # Only one seq_id is guaranteed because there is no beam search. + seq_id = list(seq_group_metadata.seq_data.keys())[0] + seq = seq_group_metadata.seq_data[seq_id] + + # After num_steps, the seq len will be the current seq len + # plus one token per step. + final_seq_len = seq.get_len() + num_steps + + # We will have final_seq_len - 1 KV because vLLM saves KV for a + # token in the iteration after the token was generated. + required_num_kv_slots = final_seq_len - 1 + + # The allocated number of kv slots is the number of allocated blocks + # times the number of slots of block. + number_physical_blocks = len( + seq_group_metadata.block_tables[seq_id]) + allocated_kv_slots = (number_physical_blocks * + self.model_runner.block_size) + + if required_num_kv_slots > allocated_kv_slots: + request_id = seq_group_metadata.request_id + raise ValueError( + "The worker attempted to run " + f"{num_steps} times but found insufficient KV space for " + f"{request_id=} {seq_id=}. ({allocated_kv_slots=} " + f"{required_num_kv_slots=}).") + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """MultiStepWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "MultiStepWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "MultiStepWorker does not support beam search.") diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py new file mode 100644 index 0000000..36e5e17 --- /dev/null +++ b/vllm/spec_decode/ngram_worker.py @@ -0,0 +1,169 @@ +import weakref +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase +from vllm.spec_decode.top1_proposer import Top1Proposer + + +class NGramWorker(NonLLMProposerWorkerBase): + """NGramWorker provides a light drafter without need for model. 
+
+    Currently, NGramWorker only implements prompt-lookup decoding; in the
+    future it may also support RAG-style drafters and other proposal schemes
+    that do not rely on an LLM to produce proposals.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Get local_rank/vocab_size from the kwargs.
+        self.local_rank = kwargs["local_rank"]
+        self.vocab_size = kwargs["model_config"].get_vocab_size()
+
+        # Lazy initialization list.
+        self._proposer: Top1Proposer
+
+    def set_ngram_window_size(self, ngram_prompt_lookup_min: int,
+                              ngram_prompt_lookup_max: int):
+        # Search for a valid candidate window between
+        # ngram_prompt_lookup_min and ngram_prompt_lookup_max.
+        self.ngram_prompt_lookup_max = ngram_prompt_lookup_max
+        self.ngram_prompt_lookup_min = ngram_prompt_lookup_min
+
+    def init_device(self):
+        self.device = torch.device(f"cuda:{self.local_rank}")
+        self.load_model = lambda *args, **kwargs: None
+
+        # Currently NGramWorker only supports Top1Proposer.
+        self._proposer = Top1Proposer(
+            weakref.proxy(self),  # type: ignore[arg-type]
+            device=self.device,
+            vocab_size=self.vocab_size,
+        )
+
+    def sampler_output(
+        self,
+        execute_model_req: ExecuteModelRequest,
+        sample_len: int,
+        # Unused parameter. NGramWorker does not use the KV Cache and
+        # therefore does not need this parameter.
+        seq_ids_with_bonus_token_in_last_step: Set[int],
+    ) -> Tuple[Optional[List[Optional[SamplerOutput]]], bool]:
+        """Use n-gram matching to pick proposal candidates. Returns the list
+        of sampler outputs, one per SequenceGroupMetadata.
+
+        The ngram worker already produces its outputs in the transposed
+        layout internally, so the indicator passed to sampler_output_to_torch
+        is False.
+        """
+        self._raise_if_unsupported(execute_model_req)
+
+        has_spec_out = False
+        token_id_list: List[Optional[torch.Tensor]] = []
+        token_prob_list: List[Optional[torch.Tensor]] = []
+        for idx, seq_group_metadata in enumerate(
+                execute_model_req.seq_group_metadata_list):
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+
+            input_ids = torch.as_tensor(seq_data.get_token_ids(),
+                                        dtype=torch.long,
+                                        device=self.device)
+            input_length = seq_data.get_len()
+
+            for ngram_size in range(
+                    min(self.ngram_prompt_lookup_max, input_length - 1),
+                    self.ngram_prompt_lookup_min - 1,
+                    -1,
+            ):
+                ngram_tensor = input_ids[-ngram_size:]
+                if ngram_size == 1:
+                    # Do not match the n-gram against itself; unfold/all are
+                    # unnecessary for size 1.
+                    matches = (input_ids[:-1] == ngram_tensor)
+                else:
+                    windows = input_ids.unfold(dimension=0,
+                                               size=ngram_size,
+                                               step=1)
+                    # Do not match the n-gram against itself.
+                    matches = (windows[:-1] == ngram_tensor).all(dim=-1)
+
+                # first_match includes "values" (bool), indicating whether
+                # the match is found, and "indices", indicating the index
+                # of the first match.
+                # Note that "first_match.values.item()" triggers a GPU-CPU
+                # sync, which is a bit inefficient, but we have not found
+                # a better way to do this.
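+                # Illustrative example (made-up values): if input_ids ends in
+                # [..., 3, 5] and ngram_size == 2, the suffix n-gram is
+                # (3, 5); if an earlier window also equals (3, 5), the tokens
+                # that followed that earlier occurrence become the proposal
+                # built below.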
+ first_match = matches.max(dim=-1) + if first_match.values.item(): + proposal_start_idx = first_match.indices.add_(ngram_size) + spec_indices = ( + proposal_start_idx).repeat(sample_len) + torch.arange( + sample_len, device=self.device) + spec_indices.clamp_(max=input_ids.shape[-1] - 1) + res = input_ids.gather(dim=-1, index=spec_indices) + token_id_list.append(res) + token_prob_list.append( + torch.nn.functional.one_hot( + res, + num_classes=self.vocab_size).to(torch.float32)) + has_spec_out = True + break + else: + token_id_list.append(None) + token_prob_list.append(None) + + if not has_spec_out: + return None, False + + outputs: List[Optional[SamplerOutput]] = [] + for idx in range(len(execute_model_req.seq_group_metadata_list)): + if token_id_list[idx] is None: + outputs.append(None) + else: + outputs.append( + SamplerOutput( + outputs=None, + sampled_token_probs=token_prob_list[idx], + logprobs=torch.zeros((sample_len, self.vocab_size), + dtype=torch.float32, + device=self.device), + sampled_token_ids=token_id_list[idx], + )) + + return outputs, False + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + # Unused parameter. NGramWorker does not use the KV Cache and + # therefore does not need this parameter. + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. + """ + return self._proposer.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def _raise_if_unsupported( + self, + execute_model_req: ExecuteModelRequest, + ) -> None: + """NGramWorker does not yet implement support for cache swap + operations or beam search. + """ + if any([ + execute_model_req.blocks_to_swap_in, + execute_model_req.blocks_to_swap_out, + execute_model_req.blocks_to_copy + ]): + raise NotImplementedError( + "NGramWorker does not support cache operations") + + if any( + len(seq_group_metadata.seq_data.keys()) != 1 + for seq_group_metadata in + execute_model_req.seq_group_metadata_list): + raise NotImplementedError( + "NGramWorker does not support beam search.") diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py new file mode 100644 index 0000000..28a5375 --- /dev/null +++ b/vllm/spec_decode/proposer_worker_base.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, Set, Tuple + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposer +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class ProposerWorkerBase(LoraNotSupportedWorkerBase, SpeculativeProposer): + """Interface for proposer workers""" + + @abstractmethod + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + # A set containing all sequence IDs that were assigned bonus tokens + # in their last forward pass. This set is used to backfill the KV cache + # with the key-value pairs of the penultimate token in the sequences. + # This parameter is only used by the MultiStepWorker, which relies on + # the KV cache for token generation. It is not used by workers that + # do not utilize the KV cache. 
+ seq_ids_with_bonus_token_in_last_step: Set[int] + ) -> Tuple[Optional[List[SamplerOutput]], bool]: + raise NotImplementedError + + def set_include_gpu_probs_tensor(self) -> None: + """Implementation optional""" + pass + + def set_should_modify_greedy_probs_inplace(self) -> None: + """Implementation optional""" + pass + + +class NonLLMProposerWorkerBase(ProposerWorkerBase, ABC): + """Proposer worker which does not use a model with kvcache""" + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """get_spec_proposals is used to get the proposals""" + return [] + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """This is never called on the proposer, only the target model""" + raise NotImplementedError + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + pass + + def get_cache_block_size_bytes(self) -> int: + return 0 diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py new file mode 100644 index 0000000..8896b7d --- /dev/null +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -0,0 +1,161 @@ +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.distributed.parallel_state import (get_tp_group, + init_model_parallel_group, + patch_tensor_parallel_group) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.interfaces import SpeculativeProposals +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase + +logger = init_logger(__name__) + + +class SmallerTpProposerWorker(ProposerWorkerBase): + """Class which allows a speculative draft model to run with smaller tensor + parallel degree than target model. + This reduces the communication overhead of small draft models. + + To implement this feature, this class differs behavior based on is_dummy + flag, where dummy means worker that does not participate draft generation. + Participating workers use a smaller tp group by patching vLLM's tensor + parallel group temporarily during forward passes of draft models. + """ + + @classmethod + def maybe_wrap_worker(cls, worker, draft_tensor_parallel_size: int, + target_tensor_parallel_size: int): + """Wrap the worker in a SmallerTpProposerWorker if necessary. + """ + if draft_tensor_parallel_size == target_tensor_parallel_size: + return worker + + # gpu ranks that will generate draft tokens together + draft_ranks = list(range(draft_tensor_parallel_size)) + + logger.info("Wrapping {%s} in {%s}", type(worker), cls) + return cls(worker, draft_ranks) + + def __init__(self, worker: MultiStepWorker, draft_ranks: List[int]): + """Create a SmallerTpProposerWorker. + + Args: + worker (MultiStepWorker): an actual worker wrapped with this class + draft_ranks (List[int]): if this value is given, only the GPU ranks + written in this value participate in draft generation + """ + self._worker = worker + self._draft_ranks = draft_ranks + + # init during init_device + self._is_dummy = False + self._tp_group = None + + def _patch_tensor_parallel_group(self): + """Temporarily patch the global tp group state with its own tp group + state. 
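+
+        Used as a context manager around calls into the wrapped draft worker,
+        e.g. "with self._patch_tensor_parallel_group(): ..." as in
+        init_device and load_model below.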
+ """ + return patch_tensor_parallel_group(self._tp_group) + + def init_device(self) -> None: + self._is_dummy = get_tp_group().rank not in self._draft_ranks + + # dummy workers do nothing + if self._is_dummy: + return + + # creates tp process group containing only a subset of gpu ranks + local_rank = get_tp_group().local_rank + tp_backend = torch.distributed.get_backend(get_tp_group().device_group) + self._tp_group = init_model_parallel_group([self._draft_ranks], + local_rank, tp_backend) + + with self._patch_tensor_parallel_group(): + self._worker.init_device() + + def set_include_gpu_probs_tensor(self) -> None: + if self._is_dummy: + return + + # Need include_gpu_probs_tensor for multi_step_worker + self._worker.set_include_gpu_probs_tensor() + + def set_should_modify_greedy_probs_inplace(self) -> None: + if self._is_dummy: + return + + self._worker.set_should_modify_greedy_probs_inplace() + + def load_model(self) -> None: + if self._is_dummy: + return + + with self._patch_tensor_parallel_group(): + self._worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + if self._is_dummy: + # this case is not used now + return -1, -1 + + with self._patch_tensor_parallel_group(): + return self._worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + if self._is_dummy: + return + + with self._patch_tensor_parallel_group(): + self._worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def sampler_output( + self, + execute_model_req: ExecuteModelRequest, + sample_len: int, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> Tuple[List[SamplerOutput], bool]: + # Do not check _is_dummy, as it's always called by get_spec_proposals + return self._worker.sampler_output( + execute_model_req, sample_len, + seq_ids_with_bonus_token_in_last_step) + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Produce speculations given an input batch of sequences. The number of + speculative tokens per sequence is determined by max_proposal_len. 
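+
+        On ranks that are not part of the draft tensor-parallel group
+        (_is_dummy), an empty SpeculativeProposals is returned without
+        calling the wrapped worker.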
+ """ + if self._is_dummy: + return SpeculativeProposals(None, None, None) + + with self._patch_tensor_parallel_group(): + return self._worker.get_spec_proposals( + execute_model_req, seq_ids_with_bonus_token_in_last_step) + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + if self._is_dummy: + return [] + + with self._patch_tensor_parallel_group(): + return self._worker.execute_model(execute_model_req) + + def get_cache_block_size_bytes(self) -> int: + if self._is_dummy: + # by returning zero, target worker can use the entire kv cache space + return 0 + + return self._worker.get_cache_block_size_bytes() + + @property + def vocab_size(self) -> int: + return self._worker.vocab_size diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py new file mode 100644 index 0000000..50d2767 --- /dev/null +++ b/vllm/spec_decode/spec_decode_worker.py @@ -0,0 +1,1074 @@ +from collections import defaultdict +from functools import cached_property +from typing import Any, Dict, List, Optional, Set, Tuple, Type + +import torch + +from vllm.config import ParallelConfig, SpeculativeConfig +from vllm.distributed.communication_op import broadcast_tensor_dict +from vllm.logger import init_logger +from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.spec_decode_base_sampler import ( + SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) +from vllm.model_executor.layers.typical_acceptance_sampler import ( + TypicalAcceptanceSampler) +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, + CompletionSequenceGroupOutput, ExecuteModelRequest, + HiddenStates, SequenceGroupMetadata, + get_all_seq_ids_and_request_ids) +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) +from vllm.spec_decode.medusa_worker import MedusaWorker +from vllm.spec_decode.metrics import AsyncMetricsCollector +from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker +from vllm.spec_decode.mqa_scorer import MQAScorer +from vllm.spec_decode.multi_step_worker import MultiStepWorker +from vllm.spec_decode.ngram_worker import NGramWorker +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.smaller_tp_proposer_worker import SmallerTpProposerWorker +from vllm.spec_decode.target_model_runner import TargetModelRunner +from vllm.spec_decode.util import (Timer, create_logprobs_output, + create_sequence_group_output, + get_all_num_logprobs, + get_sampled_token_logprobs, nvtx_range, + split_batch_by_proposal_len) +from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase, WorkerBase + +logger = init_logger(__name__) + + +def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": + """Helper method that is the entrypoint for Executors which use + WorkerWrapper. It constructs a SpecDecodeWorker from the speculative config. 
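+
+    The target (scorer) worker is built from the original args/kwargs, while
+    a copy of the kwargs with the draft model_config/parallel_config
+    substituted in is used to build the proposer worker.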
+ """ + assert "speculative_config" in kwargs + speculative_config: SpeculativeConfig = kwargs.get("speculative_config") + assert speculative_config is not None + + draft_worker_kwargs = kwargs.copy() + + kwargs["model_runner_cls"] = TargetModelRunner + target_worker = Worker(*args, **kwargs) + # Set the disable_logprobs variable in the TargetModelRunner instance + # as per its value specified in the SpeculativeConfig. + target_worker.model_runner.disable_logprobs =\ + speculative_config.disable_logprobs + + # Override draft-model specific worker args. + draft_worker_kwargs.update( + model_config=speculative_config.draft_model_config, + parallel_config=speculative_config.draft_parallel_config, + ngram_prompt_lookup_max=speculative_config.ngram_prompt_lookup_max, + ngram_prompt_lookup_min=speculative_config.ngram_prompt_lookup_min, + # TODO allow draft-model specific load config. + #load_config=load_config, + ) + + spec_decode_worker = SpecDecodeWorker.create_worker( + scorer_worker=target_worker, + draft_worker_kwargs=draft_worker_kwargs, + disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, + disable_by_batch_size=speculative_config. + speculative_disable_by_batch_size, + draft_token_acceptance_method=speculative_config. + draft_token_acceptance_method, + typical_acceptance_sampler_posterior_threshold=speculative_config. + typical_acceptance_sampler_posterior_threshold, + typical_acceptance_sampler_posterior_alpha=speculative_config. + typical_acceptance_sampler_posterior_alpha, + disable_logprobs=speculative_config.disable_logprobs, + disable_log_stats=speculative_config.disable_log_stats, + ) + + return spec_decode_worker + + +# Reminder: Please update docs/source/serving/compatibility_matrix.rst +# If the feature combo become valid +class SpecDecodeWorker(LoraNotSupportedWorkerBase): + """Worker which implements speculative decoding. + + Speculative decoding reduces decoding per-token latency by using a proposal + method, such as a small draft model, to speculate ahead of a larger LLM. The + probabilities of the speculative tokens are then determined by the larger + LLM, after which some verification routine determines which (if any) of the + speculative tokens are accepted by the larger LLM. + + See https://github.com/vllm-project/vllm/pull/2188 and + https://github.com/vllm-project/vllm/pull/3103 for more info. + + The current implementation has the following limitations: + * Only draft-model proposal is implemented (contributions for more forms are + welcome!). + * Only top-1 proposal and scoring are implemented. Tree-attention is left as + future work. + * All sequences in a batch must have the same proposal length, or zero. This + can be improved by having per-sequence speculation in the future. + * The scoring forward pass is done without an MQA kernel, which is + suboptimal especially as the batch size, proposal length, and sequence + lengths grow. Contributions to add a MQA scoring are welcome once + correctness tests pass. + More info here https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit. 
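+
+    At a high level, each speculative decoding step runs: propose (draft
+    worker) -> score (scorer worker) -> verify (rejection or
+    typical-acceptance sampling) -> emit the accepted tokens.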
+ """ + + @classmethod + def create_worker( + cls, + scorer_worker: Worker, + draft_worker_kwargs: Dict[str, Any], + disable_mqa_scorer: bool, + disable_by_batch_size: Optional[int], + draft_token_acceptance_method: str, + typical_acceptance_sampler_posterior_threshold: float, + typical_acceptance_sampler_posterior_alpha: float, + disable_logprobs: bool, + disable_log_stats: bool, + ) -> "SpecDecodeWorker": + + allow_zero_draft_token_step = True + ngram_prompt_lookup_max = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_max")) + ngram_prompt_lookup_min = ( + draft_worker_kwargs.pop("ngram_prompt_lookup_min")) + if ngram_prompt_lookup_max > 0: + proposer_worker = NGramWorker(**draft_worker_kwargs) + proposer_worker.set_ngram_window_size(ngram_prompt_lookup_min, + ngram_prompt_lookup_max) + else: + draft_parallel_config: ParallelConfig = draft_worker_kwargs[ + 'parallel_config'] + draft_tp = draft_parallel_config.tensor_parallel_size + target_tp = scorer_worker.parallel_config.tensor_parallel_size + + if draft_worker_kwargs[ + "model_config"].hf_config.model_type == "mlp_speculator": + proposer_worker = MLPSpeculatorWorker(**draft_worker_kwargs) + elif draft_worker_kwargs[ + "model_config"].hf_config.model_type == "medusa": + proposer_worker = MedusaWorker(**draft_worker_kwargs) + else: + if draft_tp == 1: + draft_worker_kwargs[ + "model_runner_cls"] = TP1DraftModelRunner + else: + if draft_worker_kwargs[ + "model_config"].hf_config.model_type == "eagle": + raise NotImplementedError( + "EAGLE does not support TP > 1 yet") + + allow_zero_draft_token_step = False + proposer_worker = MultiStepWorker(**draft_worker_kwargs) + + proposer_worker = SmallerTpProposerWorker.maybe_wrap_worker( + proposer_worker, draft_tp, target_tp) + + logger.info("Configuring SpecDecodeWorker with proposer=%s", + type(proposer_worker)) + + spec_decode_sampler: SpecDecodeBaseSampler = None + if draft_token_acceptance_method == "rejection_sampler": + spec_decode_sampler = RejectionSampler() + elif draft_token_acceptance_method == "typical_acceptance_sampler": + spec_decode_sampler = TypicalAcceptanceSampler( + posterior_threshold=\ + typical_acceptance_sampler_posterior_threshold, + posterior_alpha=typical_acceptance_sampler_posterior_alpha, + ) + logger.info( + "[Speculative Decoding] Configuring" + " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + + if not disable_mqa_scorer: + if scorer_worker.model_runner.attn_backend.get_name( + ) != "flash-attn": + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") + + if "model_config" in draft_worker_kwargs and \ + draft_worker_kwargs["model_config"].max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") + + return SpecDecodeWorker( + proposer_worker, + scorer_worker, + disable_mqa_scorer=disable_mqa_scorer, + disable_logprobs=disable_logprobs, + disable_log_stats=disable_log_stats, + disable_by_batch_size=disable_by_batch_size, + spec_decode_sampler=spec_decode_sampler, + allow_zero_draft_token_step=allow_zero_draft_token_step) + + def __init__( + self, + 
proposer_worker: ProposerWorkerBase, + scorer_worker: WorkerBase, + spec_decode_sampler: SpecDecodeBaseSampler, + disable_mqa_scorer: bool = False, + disable_logprobs: bool = False, + disable_log_stats: bool = False, + metrics_collector: Optional[AsyncMetricsCollector] = None, + disable_by_batch_size: Optional[int] = None, + allow_zero_draft_token_step: Optional[bool] = True, + ): + """ + Create a SpecDecodeWorker. + + Args: + proposer_worker: A worker that can produce speculative tokens for + sequences. + scorer_worker: A worker that produces probabilities of speculative + tokens according to some base model. Typically a vanilla vLLM + Worker. + spec_decode_sampler: A Torch module used to perform acceptance + sampling of the draft tokens in the verification step of + speculative decoding. Currently we support two different + types of sampler namely RejectionSampler and + TypicalAcceptanceSampler. 'spec_decode_sampler' is either an + instance of RejectionSampler or TypicalAcceptanceSampler. + disable_mqa_scorer: If set to True, disable the MQA scorer and use + the BatchExpansionTop1Scorer instead. + disable_logprobs: If set to True, token log probabilities will + not be output in both the draft worker and the target worker. + If set to False, log probabilities will be output by both. + disable_log_stats: If set to True, disable periodic printing of + speculative stage times. + disable_by_batch_size: If the batch size is larger than this, + disable speculative decoding for new incoming requests. + metrics_collector: Helper class for collecting metrics; can be set + for testing purposes. + allow_zero_draft_token_step: whether to allow a step where the draft + model generates no draft token; should disallow when the tp of + draft model is larger than 1 (TODO: #5814) + """ + self.proposer_worker = proposer_worker + self.scorer_worker = scorer_worker + scorer_runner = getattr(self.scorer_worker, "model_runner", None) + self.generators = scorer_runner.get_generators( + ) if scorer_runner else None + self.disable_by_batch_size = disable_by_batch_size or float("inf") + self.spec_decode_sampler = spec_decode_sampler + self._allow_zero_draft_token_step = allow_zero_draft_token_step + self._metrics = AsyncMetricsCollector( + self.spec_decode_sampler + ) if metrics_collector is None else metrics_collector + # Tracks the sequence IDs that received a bonus token ID in + # their last forward pass. Needed only if KV cache is being + # used for token generation such as in the case of MultiStepWorker. + self._seq_with_bonus_token_in_last_step: Set[int] = set() + # Tracks the currently active request ids and the sequence IDs + # corresponding to them + self._request_id_seq_id_mapping: Dict[str, Set[int]] = defaultdict(set) + # Tracks if the proposer worker uses the KV cache or not. + + self.probs_dtype = self.spec_decode_sampler.probs_dtype + self.token_id_dtype = self.spec_decode_sampler.token_id_dtype + # Lazy initialization. + self.scorer: SpeculativeScorer + self.disable_mqa_scorer = disable_mqa_scorer + + # Hidden states from target model to pass to proposer + # in the subsequent step. + self.previous_hidden_states: Optional[HiddenStates] = None + self._disable_logprobs = disable_logprobs + self._disable_log_stats = disable_log_stats + + def init_device(self) -> None: + """Initialize both scorer and proposer models. + """ + # The scorer worker model is initialized first in case the proposer + # model has a smaller TP degree than the target worker. 
+ self.scorer_worker.init_device() + self.proposer_worker.init_device() + + # NOTE(cade): load_model is not part of the WorkerBase interface. + self.scorer_worker.load_model() + self.proposer_worker.load_model() + + self._metrics.init_gpu_tensors(self.rank) + self.spec_decode_sampler.init_gpu_tensors(self.rank) + + scorer_cls: Type[SpeculativeScorer] + if self.disable_mqa_scorer: + scorer_cls = BatchExpansionTop1Scorer + logger.info("[Speculative Decoding] Use batch " + "expansion for scoring proposals.") + else: + scorer_cls = MQAScorer + logger.info( + "[Speculative Decoding] Use MQA scorer for scoring proposals.") + + self.scorer = scorer_cls(scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) + + self._configure_model_sampler_for_spec_decode() + + def load_model(self, *args, **kwargs): + pass + + def _configure_model_sampler_for_spec_decode(self): + """Configure model sampler to emit GPU tensors. This allows spec decode + to keep data on device without transferring to CPU and serializing, + which significantly reduces overhead of sampling during verification. + + NOTE(cade): This breaks abstraction boundaries pretty badly. The better + design is to have the "move to CPU and serialize" sampling decision be + done outside of the model/sampler; this way the "last-mile" worker + object which interfaces with the scheduler can serialize and incur the + performance hit as necessary. This allows us to run the worker several + iterations in a row without incurring the "move to CPU and serialize" + performance penalty. + + Since this requires a large change to vLLM, we defer it to later and + temporarily accept this broken abstraction boundary. + + NOTE(cade): This will require a special check if the proposer worker + does not have a sampler (e.g. ngram speculation). + """ + (self.scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor + ) = True + (self.scorer_worker.model_runner.model.sampler. + should_modify_greedy_probs_inplace) = True + self.proposer_worker.set_include_gpu_probs_tensor() + self.proposer_worker.set_should_modify_greedy_probs_inplace() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of cache blocks to use. + + This is done by profiling the scorer model (which is typically the + larger of the two). Then the total memory which would be used by the + scorer cache is divided evenly between the proposer and scorer model KV, + such that the number of blocks is equal in both KV caches. + """ + num_gpu_blocks, num_cpu_blocks = ( + self.scorer_worker.determine_num_available_blocks()) + + scorer_cache_block_size_bytes = ( + self.scorer_worker.get_cache_block_size_bytes()) + proposer_cache_block_size_bytes = ( + self.proposer_worker.get_cache_block_size_bytes()) + + new_num_gpu_blocks = split_num_cache_blocks_evenly( + scorer_cache_block_size_bytes, proposer_cache_block_size_bytes, + num_gpu_blocks) + return new_num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the cache engine of the scorer and proposer workers. + """ + self.scorer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + """Perform speculative decoding on the input batch. 
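+
+        On the driver rank this also broadcasts per-step control metadata
+        (num_lookahead_slots, no_spec, disable_all_speculation) so that
+        non-driver ranks can mirror the proposer/scorer call order.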
+ """ + if self.rank != self._driver_rank: + self._run_non_driver_rank() + return [] + + if execute_model_req is None: + # This signals that there's no more requests to process for now. + # All workers are running infinite loop with broadcast_tensor_dict, + # and it stops the loop when the driver broadcasts an empty input. + # Send an empty input to notify all other workers to stop their + # execution loop. + broadcast_tensor_dict({}, src=0) + return [] + + self._track_finished_requests(execute_model_req) + disable_all_speculation = self._should_disable_all_speculation( + execute_model_req) + num_lookahead_slots = execute_model_req.num_lookahead_slots + + # Speculative decoding is disabled in the following cases: + # 1. Prefill phase: Speculative decoding is not + # used during the prefill phase. + # 2. Auto-disable enabled: The running queue size exceeds + # the specified threshold. + # 3. No request: There are no requests in the batch, or + # none of the requests in the batch have spec decoding enabled. + # In any of these cases, the proposer and scorer workers + # are called normally. + no_spec = num_lookahead_slots == 0 or disable_all_speculation or all( + sgm.num_speculative_tokens == 0 + for sgm in execute_model_req.seq_group_metadata_list) + + # Broadcast how many lookahead slots are scheduled for this step, and + # whether all speculation is disabled, to all non-driver workers. + + # This is required as if the number of draft model runs changes + # dynamically, the non-driver workers won't know unless we perform a + # communication to inform them. + + # no_spec is used to signal non-driver worker about prefill vs decode + # stage. This is needed to ensure that order of execution of proposer + # and scorer is same in both driver and non-driver workers (i.e., + # scorer -> proposer for prefill and proposer -> scorer in decode). This + # order is needed to support models like EAGLE that take scorer states + # as inputs. + broadcast_dict = dict( + num_lookahead_slots=num_lookahead_slots, + no_spec=no_spec, + disable_all_speculation=disable_all_speculation, + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) + + assert execute_model_req.seq_group_metadata_list is not None, ( + "speculative decoding requires non-None seq_group_metadata_list") + + self._maybe_disable_speculative_tokens( + disable_all_speculation, execute_model_req.seq_group_metadata_list) + + if no_spec: + return self._run_no_spec(execute_model_req, + skip_proposer=disable_all_speculation) + return self._run_speculative_decoding_step(execute_model_req, + num_lookahead_slots) + + @torch.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop to perform speculative decoding + in parallel worker.""" + while self._run_non_driver_rank(): + pass + + def _should_disable_all_speculation( + self, execute_model_req: ExecuteModelRequest) -> bool: + # When the batch size is too large, disable speculative decoding + # to stop trading off throughput for latency. + return (execute_model_req.running_queue_size >= + self.disable_by_batch_size) + + def _maybe_disable_speculative_tokens( + self, disable_all_speculation: bool, + seq_group_metadata_list: List[SequenceGroupMetadata]) -> None: + if not disable_all_speculation: + return + + for seq_group_metadata in seq_group_metadata_list: + # Once num_speculative_tokens is set to 0, the spec decode + # of this request will be disabled forever. 
+ # TODO(comaniac): We currently store spec decoding specific + # state in the global data structure, but we should maintain + # this state within spec decode worker. + seq_group_metadata.num_speculative_tokens = 0 + + def _serialize_sampler_output_no_logprobs( + self, execute_model_req: ExecuteModelRequest, + sampler_output: SamplerOutput) -> SamplerOutput: + """ + Creates and returns a `SamplerOutput` with only the token IDs being + serialized to CPU and populated in `CompletionSequenceGroupOutput`. + All other parameters in `CompletionSequenceGroupOutput` related to log + probabilities are skipped. + + Args: + execute_model_req (ExecuteModelRequest): The model request that + was executed. + sampler_output (SamplerOutput): The output from the sampler with + only GPU tensors populated. + + Returns: + SamplerOutput: A new `SamplerOutput` instance containing a list of + `CompletionSequenceGroupOutput` objects with only token IDs + populated. + """ + seq_output_prompt_logprobs = [ + seq.is_prompt and seq.sampling_params.prompt_logprobs is not None + and seq.sampling_params.prompt_logprobs > 0 + for seq in execute_model_req.seq_group_metadata_list + ] + # ignore slots for prompt tokens that are filled with INVALID_TOKEN_ID + sampled_token_ids_list = (sampler_output.sampled_token_ids[torch.where( + # subtracting is faster than testing for equality + sampler_output.sampled_token_ids - VLLM_INVALID_TOKEN_ID)[0]] \ + if any(seq_output_prompt_logprobs) else \ + sampler_output.sampled_token_ids).tolist() + + seq_data_entries = ( + (seq_id, seq_data) for sg in \ + execute_model_req.seq_group_metadata_list \ + for seq_id, seq_data in sg.seq_data.items() + ) + completion_seq_group_output_list: List[ + CompletionSequenceGroupOutput] = [] + for index, ((seq_id, seq_data), needs_prompt_logprobs) in \ + enumerate(zip(seq_data_entries, seq_output_prompt_logprobs)): + if needs_prompt_logprobs: + prompt_token_ids = seq_data.get_prompt_token_ids() + prompt_logprobs = [ + create_logprobs_output( + token_id=p_token_id, + token_id_logprob_rank=-1, + token_id_logprob=0.0, + topk_token_ids=[], + topk_logprobs=[], + ) + # no prompt logprobs for the first token + for p_token_id in prompt_token_ids[1:] + ] + else: + prompt_logprobs = None + + completion_seq_group_output_list.append( + create_sequence_group_output( + token_id=sampled_token_ids_list[index][0], + token_id_logprob_rank=-1, + token_id_logprob=0.0, + seq_id=seq_id, + topk_token_ids=[], + topk_logprobs=[], + prompt_logprobs=prompt_logprobs)) + return SamplerOutput(outputs=completion_seq_group_output_list) + + @nvtx_range("spec_decode_worker._run_no_spec") + def _run_no_spec(self, execute_model_req: ExecuteModelRequest, + skip_proposer: bool) -> List[SamplerOutput]: + """Run a single generation step without any speculation. The input is + sent to the proposer and scorer model so that the KV cache is consistent + between the two. When skip_proposer is True, the proposer model is + not called, meaning that the kv-cache in proposer for requests is not + updated, so they cannot enable spec decode in the rest decoding. + """ + + sampler_output = self.scorer_worker.execute_model(execute_model_req) + assert len(sampler_output) == 1 + sampler_output = sampler_output[0] + + # Store hidden states from target model execution. 
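+        # (Carried over as previous_hidden_states so that proposers which
+        # condition on target-model states, e.g. EAGLE-style drafters, can
+        # use them in the next step.)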
+        hidden_states = sampler_output.hidden_states
+        if hidden_states is not None:
+            # remove hidden_states for prompt tokens
+            if any(seq.is_prompt
+                   for seq in execute_model_req.seq_group_metadata_list):
+                hidden_states = hidden_states[
+                    torch.where(sampler_output.sampled_token_ids -
+                                VLLM_INVALID_TOKEN_ID)[0]]
+            if self.previous_hidden_states is None:
+                self.previous_hidden_states = HiddenStates(
+                    hidden_states, execute_model_req.seq_group_metadata_list)
+            else:
+                self.previous_hidden_states.update(
+                    hidden_states, execute_model_req.seq_group_metadata_list)
+
+        if not skip_proposer:
+            # We prepare the prefill hidden states here so that there is no
+            # additional complexity in worker for spec_decode vs non_spec_decode
+            # flow and execute_model doesn't need additional modifications.
+            execute_model_req.previous_hidden_states = \
+                prepare_prefill_hidden_states(
+                    sampler_output.prefill_hidden_states)
+
+            self.proposer_worker.execute_model(execute_model_req)
+
+        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
+            execute_model_req=execute_model_req, sampler_output=sampler_output)
+                                    if self._disable_logprobs else
+                                    sampler_output)
+
+        # Clear device tensors from sampler output. This reduces communication
+        # overhead when the engine runs in a different process than the workers.
+        sampler_output.sampled_token_probs = None
+        sampler_output.sampled_token_ids = None
+        sampler_output.logprobs = None
+        return [sampler_output_to_return]
+
+    def _run_non_driver_rank(self) -> bool:
+        """Run proposer and verifier model in non-driver workers. This is used
+        for both speculation cases (num_lookahead_slots>0) and non-speculation
+        cases (e.g. prefill).
+
+        Returns True if there are remaining sequences to process.
+        """
+        assert self.rank != self._driver_rank
+
+        data = broadcast_tensor_dict(src=self._driver_rank)
+        if not data:
+            return False
+        num_lookahead_slots = data["num_lookahead_slots"]
+
+        # In case of prefill, scorer_worker has to be run before proposer so
+        # that the hidden states can be propagated to proposer when needed.
+        if data["no_spec"]:
+            self.scorer_worker.execute_model()
+
+        if not data["disable_all_speculation"]:
+            # Even if num_lookahead_slots is zero, we want to run the
+            # proposer model as it may still need to update its KV cache.
+            #
+            # We run the proposer once per lookahead slot. In the future we
+            # should delegate how many times it runs to the proposer.
+            for _ in range(max(num_lookahead_slots, 1)):
+                self.proposer_worker.execute_model()
+
+        if not data["no_spec"]:
+            self.scorer_worker.execute_model()
+
+        return True
+
+    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
+    def _run_speculative_decoding_step(
+        self, execute_model_req: ExecuteModelRequest,
+        num_lookahead_slots: int) -> List[SamplerOutput]:
+        """Execute a single step of speculative decoding.
+
+        This invokes the proposer worker to get k speculative tokens for each
+        sequence, then scores each speculative token using the scoring worker.
+
+        Returns a list of SamplerOutput, each containing a single token per
+        sequence.
+        """
+        assert num_lookahead_slots == execute_model_req.num_lookahead_slots
+
+        # Pass last hidden states from target model to proposer
+        execute_model_req.previous_hidden_states = self.previous_hidden_states
+        self.previous_hidden_states = None
+
+        with Timer() as proposal_timer:
+            # Generate proposals using draft worker.
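+            # _seq_with_bonus_token_in_last_step holds the ids of sequences
+            # whose bonus token was accepted in the previous step (see
+            # _track_sequences_with_bonus_tokens below); the proposer uses it
+            # to take those tokens into account before proposing new ones.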
+ proposals = self.proposer_worker.get_spec_proposals( + execute_model_req, self._seq_with_bonus_token_in_last_step) + + if not self._allow_zero_draft_token_step and proposals.no_proposals: + #TODO: Fix it #5814 + raise RuntimeError("Cannot handle cases where distributed draft " + "workers generate no tokens") + + execute_model_req.previous_hidden_states = None + + with Timer() as scoring_timer: + proposal_scores = self.scorer.score_proposals( + execute_model_req, + proposals, + ) + + with Timer() as verification_timer: + accepted_token_ids, target_logprobs = self._verify_tokens( + execute_model_req.seq_group_metadata_list, proposal_scores, + proposals, execute_model_req.num_lookahead_slots) + + stage_times = (proposal_timer.elapsed_time_ms / num_lookahead_slots, + scoring_timer.elapsed_time_ms, + verification_timer.elapsed_time_ms) + + return self._create_output_sampler_list( + execute_model_req.seq_group_metadata_list, + accepted_token_ids, + target_logprobs=target_logprobs, + k=execute_model_req.num_lookahead_slots, + stage_times=stage_times) + + @nvtx_range("spec_decode_worker._verify_tokens") + def _verify_tokens( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_scores: SpeculativeScores, + proposals: SpeculativeProposals, + max_proposal_len: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Determine which speculative tokens are accepted using the + probabilities of each token according to the proposer and scorer models. + + Returns a tuple of Tensors, one for the accepted token ids and one for + the logprobs according to the scoring model. + """ + proposal_lens_list = proposals.proposal_lens.tolist() + + # vLLM currently only supports proposal lens equal to zero or the batch + # proposal len. This adds some complexity (splitting the batch into spec + # and non spec sequences) and should be removed in the future. It can be + # done by supporting per-sequence proposal lens. + (_, spec_indices), (_, non_spec_indices) = split_batch_by_proposal_len( + seq_group_metadata_list, proposal_lens_list) + original_indices = spec_indices + non_spec_indices + + # Get probabilities of target model, including bonus tokens. + proposal_verifier_probs = proposal_scores.probs[spec_indices] + + # Get non-speculative sampled tokens from target model. + non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] + + # Get bonus tokens from target model. + bonus_token_ids = proposal_scores.token_ids[spec_indices, -1:] + + # Get probabilities according to proposal method. + proposal_probs = proposals.proposal_probs[spec_indices] + + # Get proposed tokens. + proposal_token_ids = proposals.proposal_token_ids[spec_indices] + + # Sampler arguments + sampler_extra_kwargs: Dict[str, Any] = {} + if self.generators and isinstance(self.spec_decode_sampler, + SpecDecodeStochasticBaseSampler): + sampler_extra_kwargs["seeded_seqs"] = { + idx: self.generators[sgm.request_id] + for idx, sgm in enumerate(seq_group_metadata_list) + if sgm.sampling_params.seed is not None + } + + accepted_token_ids = self.spec_decode_sampler( + target_with_bonus_probs=proposal_verifier_probs, + bonus_token_ids=bonus_token_ids, + draft_probs=proposal_probs, + draft_token_ids=proposal_token_ids, + **sampler_extra_kwargs, + ) + # Append output tokens from non-speculative sequences to + # the accepted token ids tensor. 
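+        # Each non-speculative sequence contributes a single token this step,
+        # so expand its row to k+1 columns and mark every position after the
+        # first with -1 (the "no token" padding) before concatenating with the
+        # speculative rows.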
+ non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + + 1).clone() + non_spec_token_ids[:, 1:] = -1 + accepted_token_ids = torch.cat( + [accepted_token_ids, non_spec_token_ids]) + logprobs = proposal_scores.logprobs + # Rearrange so that results are in the order of the original seq group + # metadata. + accepted_token_ids[original_indices] = accepted_token_ids.clone() + + hidden_states = proposal_scores.hidden_states + if hidden_states is not None: + # Contract hidden states based on accepted tokens + hs_size = hidden_states.shape[-1] + + accepted_index = accepted_token_ids + 1 # Convert -1 to 0 + accepted_index = accepted_index.count_nonzero(dim=1).add_(-1) + index = accepted_index[:, None, None].expand(-1, 1, hs_size) + second_last_token_hidden_states = hidden_states[:, -2] # b x d + hidden_states = hidden_states.gather(1, index).squeeze(1) # b x d + # Store hidden states from target model for subsequent decode step + self.previous_hidden_states = HiddenStates( + hidden_states, seq_group_metadata_list, + second_last_token_hidden_states) + + return accepted_token_ids, logprobs + + def _create_output_sampler_list( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + accepted_token_ids: torch.Tensor, # shape: [batch_size, k+1] + target_logprobs: torch.Tensor, # shape: [batch_size, k+1, vocab_size] + k: int, + stage_times: Tuple[float, float, float], + ) -> List[SamplerOutput]: + """Given the accepted token ids, create a list of SamplerOutput. + + The output is padded with -1 tokens such that each sequence has + the same number of outputs. + """ + batch_size, num_steps = accepted_token_ids.shape + accepted_token_ids_by_step = accepted_token_ids.transpose(0, 1) + if self._disable_logprobs: + # We are skipping the logprobs. Hence don't serialize the + # logprobs related tensors from the GPU. Instead create + # empty/dummy lists. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, + topk_logprobs_by_step, topk_indices_by_step) =\ + self._create_dummy_logprob_lists( + batch_size, num_steps, + self.scorer_worker.model_config.max_logprobs) + else: + # Organize input tensors by step instead of by sequence. + target_logprobs_by_step = target_logprobs.transpose(0, 1) + # Serialize all tensors into Python lists. + (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, + topk_logprobs_by_step, topk_indices_by_step) =\ + self._create_logprob_lists_from_tensors( + target_logprobs_by_step, accepted_token_ids_by_step, + self.scorer_worker.model_config.max_logprobs) + + # Get the sequence ids and num_logprobs (sampling parameter) in the + # batch. + seq_ids, request_ids_seq_ids_mapping = get_all_seq_ids_and_request_ids( + seq_group_metadata_list) + + num_logprobs_per_seq = get_all_num_logprobs(seq_group_metadata_list) + + # Serialize tensor to CPU Python list. + accepted_token_ids_by_step = accepted_token_ids_by_step.tolist() + + # Construct the output on a per-step, per-sequence basis. + sampler_output_list: List[SamplerOutput] = [] + for step_index in range(num_steps): + if all(token_id == -1 + for token_id in accepted_token_ids_by_step[step_index]): + break + + step_output_token_ids: List[CompletionSequenceGroupOutput] = [] + for sequence_index in range(batch_size): + # Each sequence may have a different num_logprobs; retrieve it. 
+ num_logprobs = num_logprobs_per_seq[sequence_index] + step_output_token_ids.append( + create_sequence_group_output( + token_id=accepted_token_ids_by_step[step_index] + [sequence_index], + token_id_logprob_rank=accepted_token_id_ranks_by_step[ + step_index][sequence_index], + token_id_logprob=accepted_token_id_logprobs_by_step[ + step_index][sequence_index], + seq_id=seq_ids[sequence_index], + topk_token_ids=topk_indices_by_step[step_index] + [sequence_index][:num_logprobs], + topk_logprobs=topk_logprobs_by_step[step_index] + [sequence_index][:num_logprobs], + )) + sampler_output_list.append( + SamplerOutput(outputs=step_output_token_ids)) + + # Populate the data structures needed to keep track of sequences with + # bonus tokens. + self._track_sequences_with_bonus_tokens(seq_ids, + request_ids_seq_ids_mapping, + accepted_token_ids_by_step) + maybe_rejsample_metrics = ( + self._metrics.maybe_collect_rejsample_metrics(k)) + if maybe_rejsample_metrics is not None: + sampler_output_list[ + 0].spec_decode_worker_metrics = maybe_rejsample_metrics + + # Log time spent in each stage periodically. + # This is periodic because the rejection sampler emits metrics + # periodically. + self._maybe_log_stage_times(*stage_times) + + return sampler_output_list + + def _maybe_log_stage_times(self, average_time_per_proposal_tok_ms: float, + scoring_time_ms: float, + verification_time_ms: float) -> None: + """Log the speculative stage times. If stat logging is disabled, do + nothing. + """ + if self._disable_log_stats: + return + + logger.info( + "SpecDecodeWorker stage times: " + "average_time_per_proposal_tok_ms=%.02f " + "scoring_time_ms=%.02f verification_time_ms=%.02f", + average_time_per_proposal_tok_ms, scoring_time_ms, + verification_time_ms) + + def _create_dummy_logprob_lists( + self, + batch_size: int, + num_steps: int, + num_top_k: int, + ) -> Tuple[List[List[int]], List[List[float]], + List[List[List[Optional[float]]]], + List[List[List[Optional[int]]]]]: + """ + Creates and returns four dummy lists representing token probabilities + and their ranks. + + This method initializes and returns: + - The ranks of the accepted tokens, shaped (num_steps, batch_size) + - The log probabilities of the accepted tokens, + shaped (num_steps, batch_size) + - The log probabilities of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + - The token IDs of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + + Args: + batch_size (int): The size of the batch. + num_steps (int): The number of steps in the sequence. + num_top_k (int): The number of top-k token log probabilities to + return. + + Returns: + A tuple containing four dummy lists as described above. 
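+
+        Example (illustrative): with batch_size=2, num_steps=1 and
+        num_top_k=3, the returned ranks are [[-1, -1]], the accepted-token
+        logprobs are [[0.0, 0.0]], and both top-k lists are
+        [[[None, None, None], [None, None, None]]].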
+ """ + accepted_token_id_ranks_by_step = [[-1] * batch_size + for _ in range(num_steps)] + accepted_token_id_logprobs_by_step = [[0.0] * batch_size + for _ in range(num_steps)] + topk_logprobs_by_step: List[List[List[Optional[float]]]] = [[ + [None] * num_top_k for _ in range(batch_size) + ] for _ in range(num_steps)] + topk_indices_by_step: List[List[List[Optional[int]]]] = [[ + [None] * num_top_k for _ in range(batch_size) + ] for _ in range(num_steps)] + return (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, topk_logprobs_by_step, + topk_indices_by_step) + + def _create_logprob_lists_from_tensors( + self, + target_logprobs_by_step: torch.Tensor, + accepted_token_ids_by_step: torch.Tensor, + num_top_k: int, + ) -> Tuple[List[List[int]], List[List[float]], + List[List[List[Optional[float]]]], + List[List[List[Optional[int]]]]]: + """ + Creates and returns four lists representing token probabilities and + their ranks. + + This method initializes and returns four lists containing: + - The ranks of the accepted tokens, shaped (num_steps, batch_size) + - The log probabilities of the accepted tokens, + shaped (num_steps, batch_size) + - The log probabilities of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + - The token IDs of the top k tokens, + shaped (num_steps, batch_size, num_top_k) + + Args: + target_logprobs_by_step (torch.Tensor): Tensor representing the + log probabilities of the target model, + shaped (num_steps, batch_size, vocab_size) + accepted_token_ids_by_step (torch.Tensor): Tensor representing + the accepted token_ids, shaped (num_steps, batch_size) + num_top_k (int): The number of top-k token log probabilities to + return. + + Returns: + A tuple containing the lists as described above. + """ + # Serialize all tensors to CPU Python lists. + # Get the logprobs/rank of the accepted tokens. + (accepted_token_id_ranks_by_step_tensor, + accepted_token_id_logprobs_by_step_tensor + ) = get_sampled_token_logprobs( + logprob_tensor=target_logprobs_by_step, + sampled_token_ids=accepted_token_ids_by_step, + ) + # Get the top-k logprobs (which may or may not include the + # logprob of the accepted token). + (topk_logprobs_by_step_tensor, + topk_indices_by_step_tensor) = target_logprobs_by_step.topk( + k=num_top_k, + dim=-1, + ) + accepted_token_id_ranks_by_step = ( + accepted_token_id_ranks_by_step_tensor.tolist()) + accepted_token_id_logprobs_by_step = ( + accepted_token_id_logprobs_by_step_tensor.tolist()) + topk_logprobs_by_step = topk_logprobs_by_step_tensor.tolist() + topk_indices_by_step = topk_indices_by_step_tensor.tolist() + return (accepted_token_id_ranks_by_step, + accepted_token_id_logprobs_by_step, topk_logprobs_by_step, + topk_indices_by_step) + + def _track_finished_requests(self, execute_model_req: ExecuteModelRequest): + """ + Removes the finished requests and their associated sequence ids from + internal book keeping data structures. + """ + for finished_request in execute_model_req.finished_requests_ids: + for seq_id in self._request_id_seq_id_mapping[finished_request]: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + del self._request_id_seq_id_mapping[finished_request] + + def _track_sequences_with_bonus_tokens( + self, seq_ids: List[int], + request_ids_seq_ids_mapping: Dict[str, Set[int]], + accepted_token_ids_by_step: List[List[int]]): + """ + Updates the internal data structures which keep track of sequences + which have been assigned bonus tokens in their last forward pass. 
+ """ + for seq_index, seq_id in enumerate(seq_ids): + last_token_id = accepted_token_ids_by_step[-1][seq_index] + if last_token_id == -1: + self._seq_with_bonus_token_in_last_step.discard(seq_id) + else: + self._seq_with_bonus_token_in_last_step.add(seq_id) + for request_id, sequences in request_ids_seq_ids_mapping.items(): + self._request_id_seq_id_mapping[request_id].update(sequences) + + @cached_property + def _vocab_size(self) -> int: + """Get the vocab size of the model and make sure it's consistent between + draft and target workers. + """ + vocab_sizes = [ + worker.vocab_size + for worker in [self.proposer_worker, self.scorer_worker] + ] + assert all(vocab_sizes[0] == vocab_size for vocab_size in vocab_sizes) + return vocab_sizes[0] + + @property + def rank(self): + return self.scorer_worker.rank + + @property + def device(self): + return self.scorer_worker.device + + @property + def _driver_rank(self) -> int: + return 0 + + def get_cache_block_size_bytes(self): + """Return the size of a cache block in bytes. + + This function is only used to compose workers within a SpecDecodeWorker. + We leave composing a SpecDecodeWorker within a SpecDecodeWorker + undefined for now, although it could be implemented in the future. + See https://arxiv.org/abs/2308.04623. + """ + raise NotImplementedError + + +def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int, + proposer_cache_block_size_bytes: int, + total_num_gpu_blocks: int) -> int: + """Given total_num_gpu_blocks, the number of GPU blocks that could be + allocate to the target model, this function calculates how many blocks + should be given to the draft and target model. + + Note that usually the block size, in bytes, of each model is different, + as it's a function of number of KV/layer, number of heads, and hidden + dimension size. + + Since the target and draft models allocate the same number of blocks, we + simply calculate the number of blocks where if allocated by both models, + the total memory usage from KV cache is no larger than the number of + blocks allocatable by the target model alone. + """ + new_num_gpu_blocks = int( + total_num_gpu_blocks * scorer_cache_block_size_bytes / + (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes)) + + return new_num_gpu_blocks + + +def prepare_prefill_hidden_states( + prefill_hidden_states: torch.Tensor) -> HiddenStates: + # For prefill step in proposer, we run the model for N-1 tokens + # because Nth token will be processed in the first decode step. For + # N-1 tokens, the input should be 0:N-1 hidden states which should + # be concatanated with 1:N token (since output of scorer has to be + # the input for proposer). Therefore, we shift the hidden states to + # align n-1th hidden state with nth token. 
+ return HiddenStates(prefill_hidden_states.roll( + shifts=1, dims=0)) if prefill_hidden_states is not None else None diff --git a/vllm/spec_decode/target_model_runner.py b/vllm/spec_decode/target_model_runner.py new file mode 100644 index 0000000..2bb7af7 --- /dev/null +++ b/vllm/spec_decode/target_model_runner.py @@ -0,0 +1,69 @@ +from typing import List, Optional + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.sequence import SequenceGroupMetadata +from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, + ModelRunner) + + +class TargetModelRunner(ModelRunner): + """Specialized model runner for speculative decoding target model. + In speculative decoding, the log probabilities selected finally may not + be the same ones as selected by the target model sampling. This means + that the time spent in the log probability calculation of the target model + is time wasted, since we calculate log probabilities after deciding which + tokens are accepted. For this reason disabling log probabilities in the + target model will make decode faster. The model runner sets the + SamplingMetadata parameters according to whether log probabilities are + requested or not. + """ + + def __init__(self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None): + # An internal boolean member variable to indicate if token log + # probabilities are needed or not. + self.disable_logprobs = True + super().__init__( + model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + cache_config=cache_config, + load_config=load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + return_hidden_states=return_hidden_states, + observability_config=observability_config, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPUWithSamplingMetadata: + model_input: ModelInputForGPUWithSamplingMetadata = super( + ).prepare_model_input(seq_group_metadata_list, virtual_engine, + finished_requests_ids) + # If token log probabilities is disabled then skip generating sampler + # CPU output. We directly serialize the GPU sampled_token_id tensors + # as needed. If log probabilities is enabled then synchronize all the + # sampling related tensors which includes the logprobs tensors. 
+ model_input.sampling_metadata.skip_sampler_cpu_output = ( + self.disable_logprobs) + return model_input diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py new file mode 100644 index 0000000..f6a52a5 --- /dev/null +++ b/vllm/spec_decode/top1_proposer.py @@ -0,0 +1,272 @@ +from typing import List, Optional, Set, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeProposer) +from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase +from vllm.spec_decode.util import sampler_output_to_torch + + +class Top1Proposer(SpeculativeProposer): + """Helper class which separates out sequences which would exceed the max + model length when speculated upon. + + This allows combinations of models such as JackFram/llama-68m draft with + meta-llama/Llama2-13b-chat-hf, as llama-68m has max_position_embeddings of + 2048 while Llama2-13b has max_position_embeddings of 4096. + + We treat the sequences which exceed the proposal draft model length as + "non-spec sequences". Essentially they skip the draft model and go through + normal decoding in the target model. + + Currently, only proposal_lens of 0 and k are supported, where k is a global + batch proposal length. In the future vLLM should support per-sequence + proposal lengths. + """ + + def __init__( + self, + worker: ProposerWorkerBase, + device: str, + vocab_size: int, + max_proposal_len: Optional[int] = None, + ): + self._worker = worker + self._device = device + self.max_proposal_len = max_proposal_len + self._vocab_size = vocab_size + + def get_spec_proposals( + self, + execute_model_req: ExecuteModelRequest, + seq_ids_with_bonus_token_in_last_step: Set[int], + ) -> SpeculativeProposals: + """Get speculative proposals given the input batch. + + Sequences which would exceed the max model length are skipped during + speculation. + """ + proposal_len = execute_model_req.num_lookahead_slots + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + # Split speculative- and non-speculative- sequences. + ( + proposal_lens, + nonzero_proposal_len_seqs, + nonzero_proposal_len_indices, + ) = self._split_by_proposal_len(seq_group_metadata_list, proposal_len) + + if nonzero_proposal_len_seqs: + # Speculate tokens using the draft worker for the speculative + # sequences. + # If sampler_transposed is true, then maybe_sampler_output's + # token_ids is like [batch] format in proposal_len size list, + # while if it is false, the format would be [proposal_len] + # in batch size list + hidden_states = execute_model_req.previous_hidden_states + if hidden_states is not None: + hidden_states.prune(nonzero_proposal_len_seqs) + nonzero_execute_model_req = ExecuteModelRequest( + seq_group_metadata_list=nonzero_proposal_len_seqs, + num_lookahead_slots=proposal_len, + previous_hidden_states=hidden_states, + ) + maybe_sampler_output, transposed = self._worker.sampler_output( + execute_model_req=nonzero_execute_model_req, + sample_len=proposal_len, + seq_ids_with_bonus_token_in_last_step=\ + seq_ids_with_bonus_token_in_last_step, + ) + ( + proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + ) = self._remove_no_proposal_seqs(proposal_lens, + maybe_sampler_output, + nonzero_proposal_len_indices, + transposed) + else: + # If no sequences can be speculated, set sampler output to None. 
+            maybe_sampler_output = None
+            transposed = False
+
+        # Combine speculative- and non-speculative sequences into the same
+        # representation.
+        proposal_tokens, proposal_probs, proposal_lens = self._merge_outputs(
+            batch_size=len(seq_group_metadata_list),
+            proposal_len=proposal_len,
+            maybe_sampler_output=maybe_sampler_output,
+            proposal_lens=proposal_lens,
+            nonzero_proposal_len_indices=nonzero_proposal_len_indices,
+            sampler_transposed=transposed,
+        )
+
+        proposals = SpeculativeProposals(
+            proposal_token_ids=proposal_tokens,
+            proposal_probs=proposal_probs,
+            proposal_lens=proposal_lens,
+            no_proposals=maybe_sampler_output is None)
+
+        return proposals
+
+    def _split_by_proposal_len(
+        self,
+        seq_group_metadata_list: List[SequenceGroupMetadata],
+        proposal_len: int,
+    ) -> Tuple[List[int], List[SequenceGroupMetadata], List[int]]:
+        """Split sequences into two groups:
+        1. Sequences with non-zero proposal length.
+        2. Sequences with zero proposal length (due to disabled speculation
+        or exceeding the maximum model length).
+        """
+
+        proposal_lens: List[int] = []
+        nonzero_proposal_len_seqs: List[SequenceGroupMetadata] = []
+        nonzero_proposal_len_indices: List[int] = []
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+            # The speculative decoding for this request has been disabled
+            # (e.g. due to high traffic).
+            if seq_group_metadata.num_speculative_tokens == 0:
+                proposal_lens.append(0)
+                continue
+
+            seq_data = next(iter(seq_group_metadata.seq_data.values()))
+            seq_len = seq_data.get_len()
+
+            # Currently only proposal lens of 0 or the global batch proposal len
+            # are supported.
+            # If max_proposal_len is defined, then we shall not exceed this
+            # quota for nonzero_proposal
+            new_k = 0
+            if (self.max_proposal_len is None
+                    or seq_len + proposal_len < self.max_proposal_len):
+                new_k = proposal_len
+                nonzero_proposal_len_seqs.append(seq_group_metadata)
+                nonzero_proposal_len_indices.append(i)
+            proposal_lens.append(new_k)
+            seq_group_metadata.num_speculative_tokens = new_k
+
+        return (
+            proposal_lens,
+            nonzero_proposal_len_seqs,
+            nonzero_proposal_len_indices,
+        )
+
+    @staticmethod
+    def _remove_no_proposal_seqs(proposal_lens, maybe_sampler_output,
+                                 nonzero_proposal_len_indices, transposed):
+        """Remove sequences from nonzero_proposal_len_indices and reset
+        their proposal_len to 0 when the draft worker does not provide a
+        proposal (maybe_sampler_output=None). This can avoid scoring overheads.
+        """
+
+        # If maybe_sampler_output is None, then the draft worker did not
+        # provide a proposal for any sequence and thus no action needed.
+        # Also we do not support transposed maybe_sampler_output for now
+        # because it seems not straightforward for draft workers outputting
+        # transposed sampler outputs to handle the case of no proposal.
+        if maybe_sampler_output is None or transposed:
+            return (proposal_lens, maybe_sampler_output,
+                    nonzero_proposal_len_indices)
+
+        new_proposal_lens: List[int] = []
+        new_nonzero_proposal_len_indices: List[int] = []
+        new_maybe_sampler_output: List[SamplerOutput] = []
+        nonzero_proposal_len_idx_ptr = 0
+        seq_idx = 0
+        while seq_idx < len(
+                proposal_lens) and nonzero_proposal_len_idx_ptr < len(
+                    nonzero_proposal_len_indices):
+            if seq_idx < nonzero_proposal_len_indices[
+                    nonzero_proposal_len_idx_ptr]:
+                # Sequence is not in the original nonzero_proposal_len_indices,
+                # meaning that it has a proposal length of 0 before sending to
+                # the draft worker.
+ assert proposal_lens[seq_idx] == 0 + new_proposal_lens.append(0) + else: + # Sequence is in the original nonzero_proposal_len_indices + if maybe_sampler_output[nonzero_proposal_len_idx_ptr] is None: + # but does not have a proposal from the draft worker. + new_proposal_lens.append(0) + else: + # and has a proposal from the draft worker. Add it to the + # new nonzero proposal list and keep the sampler output. + new_proposal_lens.append(proposal_lens[seq_idx]) + new_nonzero_proposal_len_indices.append(seq_idx) + new_maybe_sampler_output.append( + maybe_sampler_output[nonzero_proposal_len_idx_ptr]) + nonzero_proposal_len_idx_ptr += 1 + seq_idx += 1 + + # The remaining sequences should have proposal length of 0. + new_proposal_lens.extend(proposal_lens[seq_idx:]) + + # We assume sampler_output will not be a list of all Nones. + # In this case this function should not be called. + assert new_maybe_sampler_output + return (new_proposal_lens, new_maybe_sampler_output, + new_nonzero_proposal_len_indices) + + def _merge_outputs( + self, + batch_size: int, + proposal_len: int, + maybe_sampler_output: Optional[List[SamplerOutput]], + proposal_lens: List[int], + nonzero_proposal_len_indices: List[int], + sampler_transposed: bool, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """After speculations are produced, merge the speculation results with + the skipped sequences. + """ + if maybe_sampler_output is None: + # If no speculative tokens, the sampler output will be None. + # In this case we return empty proposals. + proposal_tokens = torch.tensor(-1, + dtype=torch.long, + device=self._device).expand( + batch_size, proposal_len) + proposal_probs = torch.tensor(0, + dtype=torch.float32, + device=self._device).expand( + batch_size, proposal_len, + self._vocab_size) + proposal_lens_tensor = torch.tensor(0, + dtype=torch.long, + device=self._device).expand( + len(proposal_lens)) + return proposal_tokens, proposal_probs, proposal_lens_tensor + + sampler_output = maybe_sampler_output + proposal_tokens, proposal_probs, *_ = sampler_output_to_torch( + sampler_output, sampler_transposed) + + # Now, reformat the output GPU tensors such that each sequence has + # a proposal. the proposal can be empty, e.g. 
[-1, -1, -1] + + entire_proposal_tokens = proposal_tokens.new_full( + size=(batch_size, *proposal_tokens.shape[1:]), + fill_value=-1, + ) + entire_proposal_tokens[nonzero_proposal_len_indices] = proposal_tokens + entire_proposal_probs = proposal_probs.new_zeros( + batch_size, + *proposal_probs.shape[1:], + ) + entire_proposal_probs[nonzero_proposal_len_indices] = proposal_probs + + proposal_tokens, proposal_probs = ( + entire_proposal_tokens, + entire_proposal_probs, + ) + + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = proposal_len + + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py new file mode 100644 index 0000000..193ef87 --- /dev/null +++ b/vllm/spec_decode/util.py @@ -0,0 +1,268 @@ +import time +from contextlib import contextmanager +from typing import Dict, List, Optional, Sequence, Tuple + +import torch + +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, + PromptLogprobs, SequenceGroupMetadata, + SequenceOutput) + +SeqId = int + + +def get_all_num_logprobs( + seq_group_metadata_list: List[SequenceGroupMetadata]) -> List[int]: + """Given a list of SequenceGroupMetadata, create a list of all num_logprobs. + + If the sampling params do not call for any logprobs, return 0 for that + sequence. + """ + + all_num_logprobs: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + num_logprobs = seq_group_metadata.sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + all_num_logprobs.append(num_logprobs) + + return all_num_logprobs + + +def get_sampled_token_logprobs( + # shape [num_steps, batch_size, vocab_size] + logprob_tensor: torch.Tensor, + sampled_token_ids: torch.Tensor, # shape [num_steps, batch_size] +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the logprobs for the sampled tokens. Returns the ranks and logprobs. + """ + num_steps, batch_size, vocab_size = logprob_tensor.shape + + selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1), + torch.arange(batch_size), + sampled_token_ids, ] + expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( + -1, -1, vocab_size) + sampled_token_ids_ranks = (logprob_tensor > + expanded_selected_logprobs).sum(-1).add_(1) + + return sampled_token_ids_ranks, selected_logprobs + + +def create_logprobs_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], +) -> Dict[int, Logprob]: + """Create a Logprob Dict for a token given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + # vLLM logprobs always include the sampled token. In addition, the user may + # request topk-logprobs (where top-k varies per user up to max_logprobs). 
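+    # The sampled token is entered with its true rank/logprob; each requested
+    # top-k token is then added with its position in the top-k list (+1) as
+    # its rank, skipping entries whose id is None.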
+ logprobs: Dict[int, Logprob] = { + token_id: Logprob( + logprob=token_id_logprob, + rank=token_id_logprob_rank, + ), + } + logprobs.update({ + topk_token_id: Logprob( + logprob=topk_logprob if topk_logprob is not None else 0.0, + rank=topk_index + 1, + ) + for topk_index, (topk_token_id, topk_logprob) \ + in enumerate(zip(topk_token_ids, topk_logprobs)) \ + if topk_token_id is not None + }) + + return logprobs + + +def create_sequence_group_output( + token_id: int, + token_id_logprob_rank: int, + token_id_logprob: float, + seq_id: SeqId, + topk_token_ids: List[Optional[int]], + topk_logprobs: List[Optional[float]], + prompt_logprobs: Optional[PromptLogprobs] = None, +) -> CompletionSequenceGroupOutput: + """Create a SequenceGroupOutput given the sampling results. + + Args: + token_id (int): The sampled token for the sequence. + token_id_logprob_rank (int): The logprob rank of the sampled token. + token_id_logprob (float): The logprob value of the sampled token. + seq_id (int): The sequence id. + topk_token_ids (List[Optional[int]]): The list of top-k token ids. + topk_logprobs (List[Optional[float]]): The list of top-k logprobs. + """ + + logprobs = create_logprobs_output( + token_id, + token_id_logprob_rank, + token_id_logprob, + topk_token_ids, + topk_logprobs, + ) + + return CompletionSequenceGroupOutput( + samples=[ + SequenceOutput(parent_seq_id=seq_id, + output_token=token_id, + logprobs=logprobs) + ], + prompt_logprobs=prompt_logprobs, + ) + + +def split_batch_by_proposal_len( + seq_group_metadata_list: List[SequenceGroupMetadata], + proposal_lens: List[int], +) -> Tuple[Tuple[List[SequenceGroupMetadata], List[int]], Tuple[ + List[SequenceGroupMetadata], List[int]]]: + """Utility function that splits a batch based on whether the proposal len is + zero or not. We should remove this once vLLM supports per-sequence proposal + lens in a batch. + """ + + nonzero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) + zero_lists: Tuple[List[SequenceGroupMetadata], List[int]] = ([], []) + for i, (seq_group, proposal_len) in enumerate( + zip(seq_group_metadata_list, proposal_lens)): + seq_groups, indices = nonzero_lists if proposal_len else zero_lists + seq_groups.append(seq_group) + indices.append(i) + return nonzero_lists, zero_lists + + +def sampler_output_to_torch( + sampler_output_list: Sequence[SamplerOutput], sampler_transposed: bool +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """Utility function which converts a list of SamplerOutput to tensors. + + sampler_transposed here is used as the indicator for whether + we need do additional tensor transpose logic here. 
+ + Returns: + sampled_token_ids: torch.Tensor + shape: [batch_size, len(sampler_output_list)] + + sampled_token_probs: torch.Tensor + shape: [batch_size, len(sampler_output_list), vocab_size] + """ + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_probs = torch.stack( + [ + sampler_output.sampled_token_probs + for sampler_output in sampler_output_list + ], + dim=0, + ) + + # shape: [batch_size, num_sampler_output, vocab_size] + sampled_token_logprobs = torch.stack( + [sampler_output.logprobs for sampler_output in sampler_output_list], + dim=0, + ) + + # shape: [batch_size, num_sampler_output] + sampled_token_ids = torch.stack( + [ + sampler_output.sampled_token_ids.flatten() + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_token_probs = sampled_token_probs.transpose(0, 1) + sampled_token_logprobs = sampled_token_logprobs.transpose(0, 1) + sampled_token_ids = sampled_token_ids.transpose(0, 1) + + if sampler_output_list[0].hidden_states is not None: + # shape: [batch_size, num_sampler_output, hidden_dim] + sampled_hidden_states = torch.stack( + [ + sampler_output.hidden_states + for sampler_output in sampler_output_list + ], + dim=0, + ) + + if sampler_transposed: + sampled_hidden_states = sampled_hidden_states.transpose(0, 1) + else: + sampled_hidden_states = None + + return (sampled_token_ids, sampled_token_probs, sampled_token_logprobs, + sampled_hidden_states) + + +def maybe_mock_device_tensors(sampler_output: SamplerOutput, batch_size: int, + vocab_size: int, device: str) -> None: + """Helper method which mocks out the GPU tensors in SamplerOutput with dummy + values. This will be removed in PR 7/9. + https://docs.google.com/document/d/1rE4pr3IdspRw97XbImY4fS9IWYuJJ3HGtL7AdIKGrw8/edit#heading=h.qijw1sdidrer + """ + values = [ + sampler_output.sampled_token_probs, sampler_output.sampled_token_ids + ] + assert all(v is None for v in values) or not any(v is None for v in values) + if not any(v is None for v in values): + # Do nothing if the tensors are already created (usually in unit tests). + return + + # Softmax to ensure valid probs. + sampler_output.sampled_token_probs = torch.nn.functional.softmax( + torch.rand(batch_size, vocab_size, dtype=torch.float32, device=device), + dim=-1) + + sampler_output.sampled_token_ids = torch.randint(low=10, + high=100, + size=(batch_size, ), + dtype=torch.long, + device=device) + + +@contextmanager +def nvtx_range(msg, *args, **kwargs): + """ + Context manager / decorator that pushes an NVTX range at the beginning + of its scope, and pops it at the end. If extra arguments are given, + they are passed as arguments to msg.format(). + + If running with cuda graphs, you must enable nsys cuda graph profiling. + + Arguments: + msg (string): message to associate with the range + """ + torch.cuda.nvtx.range_push(msg.format(*args, **kwargs)) + try: + yield + finally: + torch.cuda.nvtx.range_pop() + + +class Timer: + """Basic timer context manager for measuring CPU time. 
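+
+    Used in spec_decode_worker.py above as, e.g.:
+
+        with Timer() as scoring_timer:
+            ...
+        elapsed_ms = scoring_timer.elapsed_time_ms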
+ """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.end_time = time.time() + self.elapsed_time_s = self.end_time - self.start_time + self.elapsed_time_ms = self.elapsed_time_s * 1000 diff --git a/vllm/tracing.py b/vllm/tracing.py new file mode 100644 index 0000000..50068d8 --- /dev/null +++ b/vllm/tracing.py @@ -0,0 +1,119 @@ +import os +from typing import Mapping, Optional + +from vllm.logger import init_logger +from vllm.utils import run_once + +TRACE_HEADERS = ["traceparent", "tracestate"] + +logger = init_logger(__name__) + +_is_otel_imported = False +otel_import_error_traceback: Optional[str] = None +try: + from opentelemetry.context.context import Context + from opentelemetry.sdk.environment_variables import ( + OTEL_EXPORTER_OTLP_TRACES_PROTOCOL) + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import BatchSpanProcessor + from opentelemetry.semconv_ai import SpanAttributes as BaseSpanAttributes + from opentelemetry.trace import SpanKind, Tracer, set_tracer_provider + from opentelemetry.trace.propagation.tracecontext import ( + TraceContextTextMapPropagator) + _is_otel_imported = True +except ImportError: + # Capture and format traceback to provide detailed context for the import + # error. Only the string representation of the error is retained to avoid + # memory leaks. + # See https://github.com/vllm-project/vllm/pull/7266#discussion_r1707395458 + import traceback + otel_import_error_traceback = traceback.format_exc() + + class Context: # type: ignore + pass + + class BaseSpanAttributes: # type: ignore + pass + + class SpanKind: # type: ignore + pass + + class Tracer: # type: ignore + pass + + +def is_otel_available() -> bool: + return _is_otel_imported + + +def init_tracer(instrumenting_module_name: str, + otlp_traces_endpoint: str) -> Optional[Tracer]: + if not is_otel_available(): + raise ValueError( + "OpenTelemetry is not available. Unable to initialize " + "a tracer. Ensure OpenTelemetry packages are installed. " + f"Original error:\n{otel_import_error_traceback}") + trace_provider = TracerProvider() + + span_exporter = get_span_exporter(otlp_traces_endpoint) + trace_provider.add_span_processor(BatchSpanProcessor(span_exporter)) + set_tracer_provider(trace_provider) + + tracer = trace_provider.get_tracer(instrumenting_module_name) + return tracer + + +def get_span_exporter(endpoint): + protocol = os.environ.get(OTEL_EXPORTER_OTLP_TRACES_PROTOCOL, "grpc") + if protocol == "grpc": + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import ( + OTLPSpanExporter) + elif protocol == "http/protobuf": + from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter) # type: ignore + else: + raise ValueError( + f"Unsupported OTLP protocol '{protocol}' is configured") + + return OTLPSpanExporter(endpoint=endpoint) + + +def extract_trace_context( + headers: Optional[Mapping[str, str]]) -> Optional[Context]: + if is_otel_available(): + headers = headers or {} + return TraceContextTextMapPropagator().extract(headers) + else: + return None + + +def extract_trace_headers(headers: Mapping[str, str]) -> Mapping[str, str]: + + return {h: headers[h] for h in TRACE_HEADERS if h in headers} + + +class SpanAttributes(BaseSpanAttributes): + # The following span attribute names are added here because they are missing + # from the Semantic Conventions for LLM. 
+ LLM_REQUEST_ID = "gen_ai.request.id" + LLM_REQUEST_N = "gen_ai.request.n" + LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences" + LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue" + LLM_LATENCY_TIME_TO_FIRST_TOKEN = "gen_ai.latency.time_to_first_token" + LLM_LATENCY_E2E = "gen_ai.latency.e2e" + LLM_LATENCY_TIME_IN_SCHEDULER = "gen_ai.latency.time_in_scheduler" + # Time taken in the forward pass for this across all workers + LLM_LATENCY_TIME_IN_MODEL_FORWARD = "gen_ai.latency.time_in_model_forward" + # Time taken in the model execute function. This will include model + # forward, block/sync across workers, cpu-gpu sync time and sampling time. + LLM_LATENCY_TIME_IN_MODEL_EXECUTE = "gen_ai.latency.time_in_model_execute" + + +def contains_trace_headers(headers: Mapping[str, str]) -> bool: + return any(h in headers for h in TRACE_HEADERS) + + +@run_once +def log_tracing_disabled_warning() -> None: + logger.warning( + "Received a request with trace context but tracing is disabled") diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py new file mode 100644 index 0000000..74ca396 --- /dev/null +++ b/vllm/transformers_utils/__init__.py @@ -0,0 +1,17 @@ +from vllm.envs import VLLM_USE_MODELSCOPE + +if VLLM_USE_MODELSCOPE: + # Patch here, before each import happens + import modelscope + from packaging import version + + # patch_hub begins from modelscope>=1.18.1 + if version.parse(modelscope.__version__) <= version.parse('1.18.0'): + raise ImportError( + 'Using vLLM with ModelScope needs modelscope>=1.18.1, please ' + 'install by `pip install modelscope>=1.18.1`') + + from modelscope.utils.hf_util import patch_hub + + # Patch hub to download models from modelscope to speed up. + patch_hub() diff --git a/vllm/transformers_utils/__pycache__/__init__.cpython-310.pyc b/vllm/transformers_utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..8b699dd Binary files /dev/null and b/vllm/transformers_utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/transformers_utils/__pycache__/config.cpython-310.pyc b/vllm/transformers_utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000..852d8f1 Binary files /dev/null and b/vllm/transformers_utils/__pycache__/config.cpython-310.pyc differ diff --git a/vllm/transformers_utils/__pycache__/detokenizer.cpython-310.pyc b/vllm/transformers_utils/__pycache__/detokenizer.cpython-310.pyc new file mode 100644 index 0000000..cb27203 Binary files /dev/null and b/vllm/transformers_utils/__pycache__/detokenizer.cpython-310.pyc differ diff --git a/vllm/transformers_utils/__pycache__/processor.cpython-310.pyc b/vllm/transformers_utils/__pycache__/processor.cpython-310.pyc new file mode 100644 index 0000000..a535b54 Binary files /dev/null and b/vllm/transformers_utils/__pycache__/processor.cpython-310.pyc differ diff --git a/vllm/transformers_utils/__pycache__/tokenizer.cpython-310.pyc b/vllm/transformers_utils/__pycache__/tokenizer.cpython-310.pyc new file mode 100644 index 0000000..60b4473 Binary files /dev/null and b/vllm/transformers_utils/__pycache__/tokenizer.cpython-310.pyc differ diff --git a/vllm/transformers_utils/__pycache__/utils.cpython-310.pyc b/vllm/transformers_utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..77fb60c Binary files /dev/null and b/vllm/transformers_utils/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py new file mode 100644 index 0000000..bdc6068 --- 
/dev/null +++ b/vllm/transformers_utils/config.py @@ -0,0 +1,360 @@ +import enum +import json +from pathlib import Path +from typing import Any, Dict, Optional, Type, Union + +import huggingface_hub +from huggingface_hub import (file_exists, hf_hub_download, + try_to_load_from_cache) +from transformers import GenerationConfig, PretrainedConfig +from transformers.models.auto.image_processing_auto import ( + get_image_processor_config) +from transformers.models.auto.modeling_auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME + +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +# yapf conflicts with isort for this block +# yapf: disable +from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, + EAGLEConfig, ExaoneConfig, + InternVLChatConfig, JAISConfig, + MedusaConfig, MllamaConfig, + MLPSpeculatorConfig, MPTConfig, + NemotronConfig, NVLM_D_Config, + Qwen2VLConfig, RWConfig, + SolarConfig, UltravoxConfig) +# yapf: enable +from vllm.transformers_utils.utils import check_gguf_file + +if VLLM_USE_MODELSCOPE: + from modelscope import AutoConfig +else: + from transformers import AutoConfig + +MISTRAL_CONFIG_NAME = "params.json" + +logger = init_logger(__name__) + +_CONFIG_REGISTRY_OVERRIDE_HF: Dict[str, Type[PretrainedConfig]] = { + "mllama": MllamaConfig +} + +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + "chatglm": ChatGLMConfig, + "dbrx": DbrxConfig, + "mpt": MPTConfig, + "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) + "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) + "jais": JAISConfig, + "mlp_speculator": MLPSpeculatorConfig, + "medusa": MedusaConfig, + "eagle": EAGLEConfig, + "exaone": ExaoneConfig, + "internvl_chat": InternVLChatConfig, + "nemotron": NemotronConfig, + "NVLM_D": NVLM_D_Config, + "solar": SolarConfig, + "ultravox": UltravoxConfig, + "qwen2_vl": Qwen2VLConfig, + **_CONFIG_REGISTRY_OVERRIDE_HF +} + + +class ConfigFormat(str, enum.Enum): + AUTO = "auto" + HF = "hf" + MISTRAL = "mistral" + + +def file_or_path_exists(model: Union[str, Path], config_name, revision, + token) -> bool: + if Path(model).exists(): + return (Path(model) / config_name).is_file() + + # Offline mode support: Check if config file is cached already + cached_filepath = try_to_load_from_cache(repo_id=model, + filename=config_name, + revision=revision) + if isinstance(cached_filepath, str): + # The config file exists in cache- we can continue trying to load + return True + + # NB: file_exists will only check for the existence of the config file on + # hf_hub. This will fail in offline mode. + try: + return file_exists(model, config_name, revision=revision, token=token) + except huggingface_hub.errors.OfflineModeIsEnabled: + # Don't raise in offline mode, all we know is that we don't have this + # file cached. 
+ return False + + +def get_config( + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + config_format: ConfigFormat = ConfigFormat.AUTO, + **kwargs, +) -> PretrainedConfig: + # Separate model folder from file path for GGUF models + + is_gguf = check_gguf_file(model) + if is_gguf: + kwargs["gguf_file"] = Path(model).name + model = Path(model).parent + + if config_format == ConfigFormat.AUTO: + if is_gguf or file_or_path_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")): + config_format = ConfigFormat.HF + elif file_or_path_exists(model, + MISTRAL_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")): + config_format = ConfigFormat.MISTRAL + else: + # If we're in offline mode and found no valid config format, then + # raise an offline mode error to indicate to the user that they + # don't have files cached and may need to go online. + # This is conveniently triggered by calling file_exists(). + file_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")) + + raise ValueError(f"No supported config format found in {model}") + + if config_format == ConfigFormat.HF: + config_dict, _ = PretrainedConfig.get_config_dict( + model, revision=revision, code_revision=code_revision, **kwargs) + + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) + else: + try: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + elif config_format == ConfigFormat.MISTRAL: + config = load_params_config(model, revision) + else: + raise ValueError(f"Unsupported config format: {config_format}") + + # Special architecture mapping check for GGUF models + if is_gguf: + if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: + raise RuntimeError( + f"Can't get gguf config for {config.model_type}.") + model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] + config.update({"architectures": [model_type]}) + + for key, value in [ + ("rope_scaling", rope_scaling), + ("rope_theta", rope_theta), + ]: + if value is not None: + logger.info( + "Updating %s from %r to %r", + key, + getattr(config, key, None), + value, + ) + config.update({key: value}) + + return config +def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None: + """Try to register HF model configuration class to serialize by value + With trust_remote_code, the config class is typically an instance of a + custom class imported from the HF modules cache. The class will not be + importable in spawned workers by default (and won't exist at all on + other nodes), which breaks serialization of the config. 
+ In this function we tell the cloudpickle serialization library to pass + instances of these generated classes by value instead of by reference, + i.e. the class definition is serialized along with its data so that the + class module does not need to be importable on the receiving end. This + registration only works if the modules cache has already been + initialized. + See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs + """ + if not trust_remote_code: + return + + try: + import transformers_modules + except ImportError: + logger.debug("Could not import transformers_modules used for remote" + " code. If remote code is not needed remove" + " `--trust-remote-code`.") + return + + try: + import cloudpickle + cloudpickle.register_pickle_by_value(transformers_modules) + + # ray vendors its own version of cloudpickle + from vllm.executor.ray_utils import ray + if ray: + ray.cloudpickle.register_pickle_by_value(transformers_modules) + + # multiprocessing uses pickle to serialize arguments when using spawn + # Here we get pickle to use cloudpickle to serialize ModelConfig objects + # that contain instances of the custom config class to avoid + # serialization problems if the generated module (and model) has a `.` + # in its name + import multiprocessing + import pickle + + from vllm.config import ModelConfig + + def _reduce_modelconfig(mc: ModelConfig): + return (pickle.loads, (cloudpickle.dumps(mc), )) + + multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig) + + except Exception as e: + logger.warning( + "Unable to register remote classes used by" + " trust_remote_code with by-value serialization. This may" + " lead to a later error. If remote code is not needed" + " remove `--trust-remote-code`", + exc_info=e) + +def load_params_config(model, revision) -> PretrainedConfig: + # This function loads a params.json config which + # should be used when loading models in mistral format + + config_file_name = "params.json" + + config_path = Path(model) / config_file_name + + if not config_path.is_file(): + config_path = Path( + hf_hub_download(model, config_file_name, revision=revision)) + + with open(config_path, "r") as file: + config_dict = json.load(file) + + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + + def recurse_elems(elem: Any): + if isinstance(elem, dict): + config_dict = {} + for key, value in elem.items(): + key = config_mapping.get(key, key) + config_dict[key] = recurse_elems(value) + return PretrainedConfig(**config_dict) + else: + return elem + + config_dict["model_type"] = config_dict.get("model_type", "transformer") + config_dict["hidden_act"] = config_dict.get("activation", "silu") + config_dict["tie_word_embeddings"] = config_dict.get( + "tie_embeddings", False) + config_dict["max_seq_len"] = config_dict.get("max_seq_len", 128_000) + config_dict["max_position_embeddings"] = config_dict.get( + "max_position_embeddings", 128_000) + + if config_dict.get("moe") is not None: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + if config_dict.get("vision_encoder") is not None: + multimodal_config = config_dict.pop("vision_encoder") + + config_dict = { + "text_config": config_dict, + "vision_config": multimodal_config + } + 
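+        # A params.json that declares a vision_encoder is treated as a
+        # Pixtral checkpoint: the original params are nested under text_config
+        # and the vision encoder params become vision_config.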
config_dict["architectures"] = ["PixtralForConditionalGeneration"] + config_dict["model_type"] = "pixtral" + + config = recurse_elems(config_dict) + return config + + +def get_hf_image_processor_config( + model: Union[str, Path], + revision: Optional[str] = None, + **kwargs, +) -> Dict[str, Any]: + # ModelScope does not provide an interface for image_processor + if VLLM_USE_MODELSCOPE: + return dict() + # Separate model folder from file path for GGUF models + if check_gguf_file(model): + model = Path(model).parent + return get_image_processor_config(model, revision=revision, **kwargs) + + +def get_hf_text_config(config: PretrainedConfig): + """Get the "sub" config relevant to llm for multi modal models. + No op for pure text models. + """ + if hasattr(config, "text_config"): + # The code operates under the assumption that text_config should have + # `num_attention_heads` (among others). Assert here to fail early + # if transformers config doesn't align with this assumption. + assert hasattr(config.text_config, "num_attention_heads") + return config.text_config + else: + return config + + +def try_get_generation_config( + model: str, + trust_remote_code: bool, + revision: Optional[str] = None, +) -> Optional[GenerationConfig]: + try: + return GenerationConfig.from_pretrained( + model, + revision=revision, + ) + except OSError: # Not found + try: + config = get_config( + model, + trust_remote_code=trust_remote_code, + revision=revision, + ) + return GenerationConfig.from_model_config(config) + except OSError: # Not found + return None diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py new file mode 100644 index 0000000..8d6385d --- /dev/null +++ b/vllm/transformers_utils/configs/__init__.py @@ -0,0 +1,40 @@ +from vllm.transformers_utils.configs.chatglm import ChatGLMConfig +from vllm.transformers_utils.configs.dbrx import DbrxConfig +from vllm.transformers_utils.configs.eagle import EAGLEConfig +from vllm.transformers_utils.configs.exaone import ExaoneConfig +# RWConfig is for the original tiiuae/falcon-40b(-instruct) and +# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the +# `FalconConfig` class from the official HuggingFace transformers library. 
+from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.internvl import InternVLChatConfig
+from vllm.transformers_utils.configs.jais import JAISConfig
+from vllm.transformers_utils.configs.medusa import MedusaConfig
+from vllm.transformers_utils.configs.mllama import MllamaConfig
+from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
+from vllm.transformers_utils.configs.mpt import MPTConfig
+from vllm.transformers_utils.configs.nemotron import NemotronConfig
+from vllm.transformers_utils.configs.nvlm_d import NVLM_D_Config
+from vllm.transformers_utils.configs.qwen2vl import (Qwen2VLConfig,
+                                                     Qwen2VLVisionConfig)
+from vllm.transformers_utils.configs.solar import SolarConfig
+from vllm.transformers_utils.configs.ultravox import UltravoxConfig
+
+__all__ = [
+    "ChatGLMConfig",
+    "DbrxConfig",
+    "MPTConfig",
+    "RWConfig",
+    "InternVLChatConfig",
+    "JAISConfig",
+    "MedusaConfig",
+    "EAGLEConfig",
+    "ExaoneConfig",
+    "MllamaConfig",
+    "MLPSpeculatorConfig",
+    "NemotronConfig",
+    "NVLM_D_Config",
+    "SolarConfig",
+    "UltravoxConfig",
+    "Qwen2VLConfig",
+    "Qwen2VLVisionConfig",
+]
diff --git a/vllm/transformers_utils/configs/__pycache__/__init__.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..400b5e3
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/__init__.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/arctic.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/arctic.cpython-310.pyc
new file mode 100644
index 0000000..84f5429
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/arctic.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-310.pyc
new file mode 100644
index 0000000..5b934d3
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/chatglm.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-310.pyc
new file mode 100644
index 0000000..2f8b67a
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/dbrx.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/eagle.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/eagle.cpython-310.pyc
new file mode 100644
index 0000000..39a4a50
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/eagle.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/exaone.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/exaone.cpython-310.pyc
new file mode 100644
index 0000000..161adbc
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/exaone.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/falcon.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/falcon.cpython-310.pyc
new file mode 100644
index 0000000..aab95b5
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/falcon.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/internvl.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/internvl.cpython-310.pyc
new file mode 100644
index 0000000..02d777a
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/internvl.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/jais.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/jais.cpython-310.pyc
new file mode 100644
index 0000000..851452d
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/jais.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/medusa.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/medusa.cpython-310.pyc
new file mode 100644
index 0000000..dd42292
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/medusa.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/mllama.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/mllama.cpython-310.pyc
new file mode 100644
index 0000000..6fcf925
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/mllama.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/mlp_speculator.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/mlp_speculator.cpython-310.pyc
new file mode 100644
index 0000000..a807a6d
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/mlp_speculator.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/mpt.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/mpt.cpython-310.pyc
new file mode 100644
index 0000000..18a7758
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/mpt.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/nemotron.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/nemotron.cpython-310.pyc
new file mode 100644
index 0000000..d8b8e04
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/nemotron.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/nvlm_d.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/nvlm_d.cpython-310.pyc
new file mode 100644
index 0000000..3e46783
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/nvlm_d.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/qwen2vl.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/qwen2vl.cpython-310.pyc
new file mode 100644
index 0000000..6d53853
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/qwen2vl.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/solar.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/solar.cpython-310.pyc
new file mode 100644
index 0000000..72f652c
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/solar.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/__pycache__/ultravox.cpython-310.pyc b/vllm/transformers_utils/configs/__pycache__/ultravox.cpython-310.pyc
new file mode 100644
index 0000000..2385a07
Binary files /dev/null and b/vllm/transformers_utils/configs/__pycache__/ultravox.cpython-310.pyc differ
diff --git a/vllm/transformers_utils/configs/arctic.py b/vllm/transformers_utils/configs/arctic.py
new file mode 100644
index 0000000..7780bf5
--- /dev/null
+++ b/vllm/transformers_utils/configs/arctic.py
@@ -0,0 +1,204 @@
+# yapf: disable
+# ruff: noqa: E501
+# coding=utf-8
+# Copied from
+# https://huggingface.co/Snowflake/snowflake-arctic-instruct/blob/main/configuration_arctic.py
+""" Arctic model configuration"""
+
+from dataclasses import asdict, dataclass
+from typing import Any, Dict
+
+from transformers.configuration_utils import PretrainedConfig
+from
transformers.utils import logging + +logger = logging.get_logger(__name__) + +ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "arctic": "https://huggingface.co/Snowflake/snowflake-arctic-instruct/tree/main/config.json", +} + + +@dataclass +class ArcticLoraConfig: + lora_r: int = 64 + lora_alpha: float = 16 + shard_base_weights: bool = False + + +@dataclass +class ArcticQuantizationConfig: + q_bits: int = 8 + rounding: str = "nearest" + mantissa_bits: int = 3 + group_size: int = 128 + + +class ArcticConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ArcticModel`]. It is used to instantiate an + Arctic model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the #TODO(rsamdani): add what model has the default config.. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Arctic model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`ArcticModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to `4096*32`): + The maximum sequence length that this model might ever be used with. Arctic's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 1000000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `4096`. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + num_experts_per_tok (`int`, *optional*, defaults to 2): + The number of experts to root per-token, can be also interpreted as the `top-p` routing + parameter + num_local_experts (`int`, *optional*, defaults to 8): + Number of experts per Sparse MLP layer. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + ```python + >>> from transformers import ArcticModel, ArcticConfig + + >>> # Initializing a Arctic 7B style configuration TODO(rsamdani): verify which model does the default configuration correspond to. + >>> configuration = ArcticConfig() + + >>> # Initializing a model from the Arctic 7B style configuration + >>> model = ArcticModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "arctic" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=1, + num_local_experts=8, + router_aux_loss_coef=0.001, + moe_layer_frequency=2, + parallel_attn_mlp_res=False, + moe_train_capacity_factor=1, + moe_eval_capacity_factor=1, + enable_expert_tensor_parallelism=False, + moe_min_capacity=0, + moe_token_dropping=True, + quantization=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_aux_loss_coef = router_aux_loss_coef + self.moe_layer_frequency = moe_layer_frequency + self.moe_train_capacity_factor = moe_train_capacity_factor + self.moe_eval_capacity_factor = moe_eval_capacity_factor + self.enable_expert_tensor_parallelism = enable_expert_tensor_parallelism + self.moe_min_capacity = moe_min_capacity + self.moe_token_dropping = moe_token_dropping + self.parallel_attn_mlp_res = parallel_attn_mlp_res + if isinstance(quantization, dict): + self.quantization = ArcticQuantizationConfig(**quantization) + else: + self.quantization = quantization + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, 
+ eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": + result = super().from_dict(config_dict, **kwargs) + config = result[0] if isinstance(result, tuple) else result + if isinstance(config.quantization, dict): + config.quantization = ArcticQuantizationConfig(**config.quantization) + return result + + def to_dict(self) -> Dict[str, Any]: + ret = super().to_dict() + if isinstance(ret["quantization"], ArcticQuantizationConfig): + ret["quantization"] = asdict(ret["quantization"]) + return ret diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py new file mode 100644 index 0000000..49d2b8d --- /dev/null +++ b/vllm/transformers_utils/configs/chatglm.py @@ -0,0 +1,70 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + # It is to be compatible with long lora. 
+ self.max_position_embeddings = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/dbrx.py b/vllm/transformers_utils/configs/dbrx.py new file mode 100644 index 0000000..0dc9664 --- /dev/null +++ b/vllm/transformers_utils/configs/dbrx.py @@ -0,0 +1,278 @@ +# yapf: disable +# ruff: noqa: E501 +# coding=utf-8 +# Copied from +# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py +"""Dbrx configuration.""" + +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. + clip_qkv (`float`, *optional*, defaults to None): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (float): The base frequency for rope. + """ + + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["attn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. 
This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of + the activation function along with any additional keyword arguments. + ffn_hidden_size (int, optional): The hidden size of the feedforward network. + moe_num_experts (int, optional): The number of experts in the mixture of experts layer. + moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. + moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. + uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. + This should only be used for benchmarking purposes. + """ + + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1, + uniform_expert_assignment: bool = False, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["ffn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all " + "configurations of models and can yield errors.", config_dict["model_type"], cls.model_type) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxConfig(PretrainedConfig): + """Configuration class for Dbrx. + + [`DbrxModel`]. It is used to instantiate a Dbrx model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + d_model (`int`, *optional*, defaults to 6144): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + max_seq_len (`int`, *optional*, defaults to 32768): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 100352): + Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DbrxModel`]. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + ffn_config (`dict`, *optional*): + A dictionary used to configure the model's FFN module. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + + Example: + ```python + >>> from transformers import DbrxConfig, DbrxModel + + >>> # Initializing a Dbrx configuration + >>> configuration = DbrxConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DbrxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dbrx" + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + "max_position_embeddings": "max_seq_len", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + router_aux_loss_coef: float = 0.05, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + 
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError( + "tie_word_embeddings is not supported for Dbrx models." + ) + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py new file mode 100644 index 0000000..b357a78 --- /dev/null +++ b/vllm/transformers_utils/configs/eagle.py @@ -0,0 +1,49 @@ +import os +from typing import Optional, Union + +from transformers import AutoConfig, PretrainedConfig + + +class EAGLEConfig(PretrainedConfig): + model_type = "eagle" + + def __init__(self, + model: Union[PretrainedConfig, dict, None] = None, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + model_config = None if model is None else (AutoConfig.for_model( + **model) if isinstance(model, dict) else model) + + for k, v in kwargs.items(): + if k != "architectures" and k != "model_type" and hasattr( + model_config, k): + setattr(model_config, k, v) + + self.model = model_config + + if self.model is None: + self.truncated_vocab_size = None + else: + self.truncated_vocab_size = self.model.vocab_size if \ + truncated_vocab_size is None else truncated_vocab_size + + if "architectures" not in kwargs: + kwargs["architectures"] = ["EAGLEModel"] + + super().__init__(**kwargs) + + if self.model is not None: + for k, v in self.model.to_dict().items(): + if not hasattr(self, k): + setattr(self, k, v) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "EAGLEConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + return cls.from_dict(config_dict, **kwargs) diff --git a/vllm/transformers_utils/configs/exaone.py b/vllm/transformers_utils/configs/exaone.py new file mode 100644 index 0000000..805b8ad --- /dev/null +++ b/vllm/transformers_utils/configs/exaone.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copied from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/configuration_exaone.py +# Copyright 2021 The LG AI Research EXAONE Lab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Exaone model configuration""" + +from typing import Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {} + + +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class: + `~transformers.ExaoneModel`. It is used to instantiate a GPT Lingvo model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Exaone + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` + and can be used to control the model outputs. 
Read the documentation from : + class:`~transformers.PretrainedConfig` for more information. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50257): + Vocabulary size of the GPT Lingvo model. Defines the number of + different tokens that can be represented by the :obj:`inputs_ids` + passed when calling :class:`~transformers.ExaoneModel`. Vocabulary + size of the model. + Defines the different tokens that can be represented by the + `inputs_ids` passed to the forward method of :class: + `~transformers.EXAONEModel`. + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1 the model will use + Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed by meanpooling + all the original heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + rotary_pct (`float`, *optional*, defaults to 0.25): + percentage of hidden dimensions to allocate to rotary embeddings + intermediate_size (:obj:`int`, `optional`, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in + the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, + defaults to :obj:`"gelu_new"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, :obj:`"gelu"`, :obj:`"relu"`, + :obj:`"selu"` and :obj:`"gelu_new"` are supported. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the + embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling + :class:`~transformers.EXAONEModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + Only relevant if ``config.is_decoder=True``. + gradient_checkpointing (:obj:`bool`, `optional`, + defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense + of slower backward pass. 
+ Example:: + + >>> from transformers import ExoneModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = ExoneModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rotary_pct=0.25, + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rotary_pct = rotary_pct + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_logit_cap = kwargs.pop("use_logit_cap", False) + self.ln_no_scale = kwargs.pop("ln_no_scale", False) + self.use_gated = kwargs.pop("use_gated", False) + self.use_emb_norm = kwargs.pop("use_emb_norm", False) + self.use_rotary_pos = kwargs.pop("use_rotary_pos", False) + self.rotary_type = kwargs.pop("rotary_type", None) + self.scaling_factor = kwargs.pop("scaling_factor", 1) + self.use_absolute_pos = kwargs.pop("use_absolute_pos", True) + self.use_extra_logit = kwargs.pop("use_extra_logit", True) + self.rotary_expand_length = kwargs.pop("rotary_expand_length", None) + self.rotary_base = kwargs.pop("rotary_base", 10000.0) + self.use_qkv_fuse = kwargs.pop("use_qkv_fuse", False) + self.rescale_before_lm_head = kwargs.pop("rescale_before_lm_head", + (rotary_pct == 0.25)) + if self.use_rotary_pos: + self.use_absolute_pos = False diff --git a/vllm/transformers_utils/configs/falcon.py b/vllm/transformers_utils/configs/falcon.py new file mode 100644 index 0000000..65fdd29 --- /dev/null +++ b/vllm/transformers_utils/configs/falcon.py @@ -0,0 +1,89 @@ +# Adapted from +# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py +# Copyright 2023 The vLLM team. +# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon configuration""" +from transformers.configuration_utils import PretrainedConfig + + +class RWConfig(PretrainedConfig): + model_type = "falcon" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "num_kv_heads": "n_head_kv", + } + + def __init__( + self, + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + num_ln_in_parallel_attn=None, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + multi_query=True, + n_head_kv=None, + alibi=False, + bias=False, + parallel_attn=False, + new_decoder_architecture=False, + **kwargs, + ) -> None: + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.multi_query = multi_query + self.n_head_kv = 1 if n_head_kv is None else n_head_kv + self.alibi = alibi + self.bias = bias + self.parallel_attn = parallel_attn + self.new_decoder_architecture = new_decoder_architecture + + if self.hidden_size == 8192: + # Hack for falcon-40b + self.new_decoder_architecture = True + self.num_ln_in_parallel_attn = num_ln_in_parallel_attn + + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) + + @property + def head_dim(self): + return self.hidden_size // self.n_head + + @property + def rotary(self): + return not self.alibi diff --git a/vllm/transformers_utils/configs/internvl.py b/vllm/transformers_utils/configs/internvl.py new file mode 100644 index 0000000..ac24923 --- /dev/null +++ b/vllm/transformers_utils/configs/internvl.py @@ -0,0 +1,51 @@ +# Adapted from +# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py +# -------------------------------------------------------- +# InternVL +# Copyright (c) 2024 OpenGVLab +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- +from transformers.configuration_utils import PretrainedConfig + + +class InternVLChatConfig(PretrainedConfig): + model_type = 'internvl_chat' + is_composition = True + + def __init__(self, + vision_config=None, + llm_config=None, + use_backbone_lora=0, + use_llm_lora=0, + select_layer=-1, + force_image_size=None, + downsample_ratio=0.5, + template=None, + dynamic_image_size=False, + use_thumbnail=False, + ps_version='v1', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs): + super().__init__(**kwargs) + + if vision_config is None: + vision_config = {} + + if llm_config is None: + llm_config = {} + + self.vision_config = PretrainedConfig(**vision_config) + self.text_config = 
PretrainedConfig(**llm_config) + + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.select_layer = select_layer + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch + self.max_dynamic_patch = max_dynamic_patch diff --git a/vllm/transformers_utils/configs/jais.py b/vllm/transformers_utils/configs/jais.py new file mode 100644 index 0000000..b06a946 --- /dev/null +++ b/vllm/transformers_utils/configs/jais.py @@ -0,0 +1,236 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2023 Cerebras Systems. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""JAIS configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class JAISConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a + [`JAISModel`]. It is used to instantiate a JAIS model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used + to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the JAIS model. Defines the number of different + tokens that can be represented by the + `inputs_ids` passed when calling [`JAISModel`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used + with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the + Transformer encoder. + n_inner (`int`, *optional*, defaults to None): + Dimensionality of the inner feed-forward layers. `None` will set + it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu"`): + Activation function, to be selected in the list + `["relu", "silu", "gelu", "tanh", "gelu_new", "swiglu"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in + the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. 
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + scale_attn_by_inverse_layer_idx (`bool`, *optional*, + defaults to `False`): + Whether to additionally scale attention weights by + `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention + (dot-product) + and upcast attention dot-product/softmax to float() when training + with mixed precision. + position_embedding_type (`str`, *optional*, defaults to `"learned"`): + Positional embedding can be either `"alibi"` or `"learned"`. + mup_width_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale learning rate and initializers. Calculated + as (`d_model,0 / d_model`), where + `d_model` is the model's width and `d_model,0` is the proxy + model's width. + mup_embeddings_scale (`float`, *optional*, defaults to 1.0): + muP parameter to scale token and position embeddings. + mup_output_alpha (`float`, *optional*, defaults to 1.0): + muP parameter to scale output logits + (`output_logits_scale = mup_output_alpha * mup_width_scale`). + mup_scale_qk_dot_by_d (`bool`, *optional*, defaults to `False`): + Scale attention weights by dividing by hidden_size instead of + sqrt(hidden_size). Need to set scale_attn_weights to `True` as + well. + alibi_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for ALiBi + embeddings. Currently only supports linear + scaling strategy. Can specify either the scaling `factor` (must be + a float greater than 1) for fixed scaling + or `train_seq_len` for dynamic scaling on input samples with + sequence length > `train_seq_len`. The expected + formats are `{"type": strategy name, "factor": scaling factor}` or + `{"type": strategy name, + "train_seq_len": training sequence length}`. + architectures (`List`, *optional*, defaults to ['JAISLMHeadModel']): + architecture names for Jais. 
+ + Example: + + ```python + >>> from transformers import JAISConfig, JAISModel + + >>> # Initializing a JAIS configuration + >>> configuration = JAISConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = JAISModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "jais" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1024, + n_embd=768, + n_layer=12, + n_head=12, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + position_embedding_type="learned", + mup_width_scale=1.0, + mup_embeddings_scale=1.0, + mup_output_alpha=1.0, + mup_scale_qk_dot_by_d=False, + alibi_scaling=None, + architectures=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.position_embedding_type = position_embedding_type + self.mup_width_scale = mup_width_scale + self.mup_embeddings_scale = mup_embeddings_scale + self.mup_output_alpha = mup_output_alpha + self.mup_scale_qk_dot_by_d = mup_scale_qk_dot_by_d + + self.alibi_scaling = alibi_scaling + self._alibi_scaling_validation() + if architectures is None: + architectures = ["JAISLMHeadModel"] + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + architectures=architectures, + **kwargs, + ) + + def _alibi_scaling_validation(self): + """ + Validate the `alibi_scaling` configuration. 
+ """ + if self.alibi_scaling is None: + return + + if (not isinstance(self.alibi_scaling, dict) + or len(self.alibi_scaling) != 2): + raise ValueError( + "`alibi_scaling` must be a dictionary with two fields," + "`type` and `factor` or `type` and `train_seq_len`, " + f"got {self.alibi_scaling}") + alibi_scaling_type = self.alibi_scaling.get("type", None) + alibi_scaling_factor = self.alibi_scaling.get("factor", None) + alibi_dynamic_scaling = self.alibi_scaling.get("train_seq_len", None) + if alibi_scaling_type is None or alibi_scaling_type != "linear": + raise ValueError(f"`alibi_scaling`'s type field must be 'linear'," + f"got {alibi_scaling_type}") + if (alibi_scaling_factor is not None + and not isinstance(alibi_scaling_factor, float) + or (alibi_scaling_factor is not None + and alibi_scaling_factor <= 1.0)): + raise ValueError( + f"`alibi_scaling`'s factor field must be a float > 1.0," + f"got {alibi_scaling_factor}") + if (alibi_dynamic_scaling is not None + and not isinstance(alibi_dynamic_scaling, int) + or (alibi_dynamic_scaling is not None + and alibi_dynamic_scaling <= 1)): + raise ValueError( + f"`alibi_scaling`'s `train_seq_len` field must be an" + f"integer > 1, got {alibi_dynamic_scaling}") diff --git a/vllm/transformers_utils/configs/medusa.py b/vllm/transformers_utils/configs/medusa.py new file mode 100644 index 0000000..d71a083 --- /dev/null +++ b/vllm/transformers_utils/configs/medusa.py @@ -0,0 +1,60 @@ +import os +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class MedusaConfig(PretrainedConfig): + model_type = "medusa" + + def __init__(self, + hidden_size: int = 4096, + vocab_size: int = 32001, + num_heads: int = 5, + num_hidden_layers: int = 1, + max_paths: int = 64, + topk: int = 10, + truncated_vocab_size: Optional[int] = None, + **kwargs): + + self.hidden_size = hidden_size + self.vocab_size = vocab_size + self.num_heads = num_heads + self.num_hidden_layers = num_hidden_layers + self.max_paths = max_paths + self.topk = topk + self.max_seq_len = int(2**20) + self.truncated_vocab_size = vocab_size if truncated_vocab_size is None\ + else truncated_vocab_size + if "architectures" not in kwargs: + kwargs["architectures"] = ["MedusaModel"] + + super().__init__(**kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + **kwargs, + ) -> "MedusaConfig": + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + for k in list(config_dict.keys()): + if 'num' in k: + if 'heads' in k: + config_dict["num_heads"] = config_dict.pop(k) + elif 'layers' in k: + config_dict["num_hidden_layers"] = config_dict.pop(k) + return cls.from_dict(config_dict, **kwargs) + + @property + def num_attention_heads(self): + return 0 + + @property + def num_lookahead_tokens(self): + return self.num_heads + + @num_lookahead_tokens.setter + def num_lookahead_tokens(self, num_lookahead_tokens: int): + self.num_heads = num_lookahead_tokens diff --git a/vllm/transformers_utils/configs/mllama.py b/vllm/transformers_utils/configs/mllama.py new file mode 100644 index 0000000..49e766d --- /dev/null +++ b/vllm/transformers_utils/configs/mllama.py @@ -0,0 +1,28 @@ +from transformers.models.mllama import configuration_mllama as mllama_hf_config + + +class MllamaTextConfig(mllama_hf_config.MllamaTextConfig): + ''' + Use this class to override is_encoder_decoder: + - transformers regards mllama as is_encoder_decoder=False + - vllm needs is_encoder_decoder=True to enable cross-attention + 
''' + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + self.is_encoder_decoder = True + + +class MllamaConfig(mllama_hf_config.MllamaConfig): + + def __init__( + self, + text_config=None, + **kwargs, + ): + if isinstance(text_config, dict): + text_config = MllamaTextConfig(**text_config) + super().__init__(text_config=text_config, **kwargs) diff --git a/vllm/transformers_utils/configs/mlp_speculator.py b/vllm/transformers_utils/configs/mlp_speculator.py new file mode 100644 index 0000000..946af4e --- /dev/null +++ b/vllm/transformers_utils/configs/mlp_speculator.py @@ -0,0 +1,65 @@ +from typing import List, Optional + +from transformers import PretrainedConfig + + +class MLPSpeculatorConfig(PretrainedConfig): + model_type = "mlp_speculator" + + attribute_map = { + "hidden_size": "emb_dim", + } + + def __init__(self, + vocab_size: int = 32000, + emb_dim: int = 4096, + inner_dim: int = 0, + n_predict: int = 3, + top_k_tokens_per_head: Optional[List[int]] = None, + n_candidates: int = 5, + tie_weights: bool = False, + scale_input: bool = False, + **kwargs): + """ + Initialize an MLPSpeculatorConfig + + Args: + vocab_size: int + the model vocab size + emb_dim: int + the model embedding dimension + inner_dim: int + the inner dimension of the model. If 0, will be the emb_dim. + n_predict: int + the number of lookaheads for the speculator + top_k_tokens_per_head: List[int] + Number of tokens to consider from each head when forming the + candidate tree. + For each candidate branch in the tree, head n produces topk[n] + additional sub-branches. + NOTE: This parameter is currently unused. + n_candidates: int + number of child candidates to create per sequence + tie_weights: bool + If true, use a single set of weights for every model + head/stage after the first. The initial projection + from the base model may have a different size, so that + stays separate. + scale_input: bool + if True, will scale the initial hidden states from + the base model. 
+ """ + if top_k_tokens_per_head is None: + top_k_tokens_per_head = [5, 4, 3] + assert len(top_k_tokens_per_head) == n_predict + self.vocab_size = vocab_size + self.emb_dim = emb_dim + self.inner_dim = inner_dim + self.n_predict = n_predict + self.top_k_tokens_per_head = top_k_tokens_per_head + self.n_candidates = n_candidates + self.num_lookahead_tokens = n_predict + self.tie_weights = tie_weights + self.scale_input = scale_input + + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py new file mode 100644 index 0000000..497db0a --- /dev/null +++ b/vllm/transformers_utils/configs/mpt.py @@ -0,0 +1,178 @@ +# coding=utf-8 +# Copied from +# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py +"""A HuggingFace-style model configuration.""" +import warnings +from typing import Any, Dict, Optional, Union + +from transformers import PretrainedConfig + +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8 +} +ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'} +init_config_defaults: Dict = { + 'name': 'kaiming_normal_', + 'fan_mode': 'fan_in', + 'init_nonlinearity': 'relu', + 'init_div_is_residual': True, + 'emb_init_std': None, + 'emb_init_uniform_lim': None, + 'init_std': None, + 'init_gain': 0.0 +} + + +class MPTConfig(PretrainedConfig): + model_type = 'mpt' + attribute_map = { + 'num_attention_heads': 'n_heads', + 'hidden_size': 'd_model', + 'num_hidden_layers': 'n_layers', + } + + # pylint: disable=dangerous-default-value + def __init__(self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + expansion_ratio: int = 4, + max_seq_len: int = 2048, + vocab_size: int = 50368, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + learned_pos_emb: bool = True, + attn_config: Dict = attn_config_defaults, + ffn_config: Dict = ffn_config_defaults, + init_device: str = 'cpu', + logit_scale: Optional[Union[float, str]] = None, + no_bias: bool = False, + embedding_fraction: float = 1.0, + norm_type: str = 'low_precision_layernorm', + use_cache: bool = False, + init_config: Dict = init_config_defaults, + fc_type: str = 'torch', + verbose: Optional[int] = None, + **kwargs: Any): + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.attn_config = attn_config + self.ffn_config = ffn_config + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.use_cache = use_cache + self.init_config = init_config + self.fc_type = fc_type + if verbose is not None: + warnings.warn(DeprecationWarning( + 'verbose argument for MPTConfig is now ignored and ' + 'will be removed. 
Use python_log_level instead.'), + stacklevel=2) + if 'name' in kwargs: + del kwargs['name'] + if 'loss_fn' in kwargs: + del kwargs['loss_fn'] + if self.attn_config.get('alibi', False): + self.learned_pos_emb = False + warnings.warn( + f'alibi is turned on, setting `learned_pos_emb` ' + f'to {self.learned_pos_emb}`', + stacklevel=2) + super().__init__(**kwargs) + self._validate_config() + + def _set_config_defaults( + self, config: Dict[str, Any], + config_defaults: Dict[str, Any]) -> Dict[str, Any]: + for (k, v) in config_defaults.items(): + if k not in config: + config[k] = v + return config + + def _validate_config(self) -> None: + self.attn_config = self._set_config_defaults(self.attn_config, + attn_config_defaults) + self.ffn_config = self._set_config_defaults(self.ffn_config, + ffn_config_defaults) + self.init_config = self._set_config_defaults(self.init_config, + init_config_defaults) + if self.d_model % self.n_heads != 0: + raise ValueError('d_model must be divisible by n_heads') + if any(( + prob < 0 or prob > 1 for prob in + [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop] + )): + raise ValueError( + "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are " + "probabilities and must be between 0 and 1") + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError( + f"Unknown attn_impl={self.attn_config['attn_impl']}") + if self.attn_config['prefix_lm'] and self.attn_config[ + 'attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError( + 'prefix_lm only implemented with torch and triton attention.') + if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [ + 'torch', 'triton' + ]: + raise NotImplementedError( + 'alibi only implemented with torch and triton attention.') + if self.attn_config['attn_uses_sequence_id'] and self.attn_config[ + 'attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError( + 'attn_uses_sequence_id only implemented with torch ' + 'and triton attention.') + if self.embedding_fraction > 1 or self.embedding_fraction <= 0: + raise ValueError( + 'model.embedding_fraction must be between 0 (exclusive) ' + 'and 1 (inclusive)!') + if isinstance(self.logit_scale, + str) and self.logit_scale != 'inv_sqrt_d_model': + raise ValueError( + f"self.logit_scale={self.logit_scale!r} is not recognized as " + "an option; use numeric value or 'inv_sqrt_d_model'.") + if self.init_config.get('name', None) is None: + raise ValueError( + f"self.init_config={self.init_config!r} 'name' needs to be set." + ) + if not self.learned_pos_emb and (not self.attn_config['alibi']): + warnings.warn( + 'Positional information not being provided to the model.', + stacklevel=2) + if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': + try: + # pylint: disable=import-outside-toplevel + import transformer_engine.pytorch as te + del te + except Exception as exc: + raise ImportError( + 'TransformerEngine import fail. `fc_type: te` requires ' + 'TransformerEngine be installed. 
' + 'The required version of transformer_engine also requires ' + 'FlashAttention v1.0.6 is installed:\n' + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' + ) from exc + if self.ffn_config['ffn_type'] == 'mptmlp': + self.ffn_config['fc_type'] = self.fc_type + elif self.ffn_config['ffn_type'] == 'te_ln_mlp': + self.ffn_config['bias'] = not self.no_bias diff --git a/vllm/transformers_utils/configs/nemotron.py b/vllm/transformers_utils/configs/nemotron.py new file mode 100644 index 0000000..139e6b3 --- /dev/null +++ b/vllm/transformers_utils/configs/nemotron.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. team. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Nemotron model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NemotronConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`NemotronModel`]. It is used to instantiate an Nemotron model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Nemotron-8B. + + Configuration objects inherit from [`PretrainedConfig`] and can be + used to control the model outputs. Read the documentation from + [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Nemotron model. Defines the number of + different tokens that can be represented by the + `inputs_ids` passed when calling [`NemotronModel`] + hidden_size (`int`, *optional*, defaults to 6144): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the + Transformer decoder. + head_dim (`int`, *optional*): + Projection weights dimension in multi-head attention. Set to + hidden_size // num_attention_heads if None + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention + (MQA) otherwise GQA is used. When converting a multi-head + checkpoint to a GQA checkpoint, each group key and value + head should be constructed by meanpooling all the original + heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). 
If it + is not specified, will default to `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"relu2"`): + The non-linear activation function (function or string) in the + decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used + with. + initializer_range (`float`, *optional*, defaults to 0.0134): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 3): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output + projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj and down_proj layers in the MLP + layers. + + ```python + >>> from transformers import NemotronModel, NemotronConfig + >>> # Initializing a Nemotron nemotron-15b style configuration + >>> configuration = NemotronConfig() + >>> # Initializing a model from the nemotron-15b style configuration + >>> model = NemotronModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "nemotron" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=6144, + intermediate_size=24576, + num_hidden_layers=32, + num_attention_heads=48, + head_dim=None, + num_key_value_heads=None, + hidden_act="relu2", + max_position_embeddings=4096, + initializer_range=0.0134, + norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=2, + eos_token_id=3, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.5, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + head_dim = head_dim or kwargs.get("kv_channels", None) + self.head_dim = head_dim if head_dim is not None else ( + hidden_size // num_attention_heads) + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_eps = norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + # 
for backward compatibility + partial_rotary_factor = kwargs.get("rope_percent", None) or kwargs.get( + "rope_percentage", None) or partial_rotary_factor + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, + dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, " + f"`type` and `factor`, got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", "dynamic" + ]: + raise ValueError( + "`rope_scaling`'s type field must be one of ['linear', " + f"'dynamic'], got {rope_scaling_type}") + if rope_scaling_factor is None or not isinstance( + rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError( + "`rope_scaling`'s factor field must be a float > 1, got " + f"{rope_scaling_factor}") diff --git a/vllm/transformers_utils/configs/nvlm_d.py b/vllm/transformers_utils/configs/nvlm_d.py new file mode 100644 index 0000000..8007176 --- /dev/null +++ b/vllm/transformers_utils/configs/nvlm_d.py @@ -0,0 +1,12 @@ +# Adapted from +# https://huggingface.co/nvidia/NVLM-D-72B/blob/main/configuration_nvlm_d.py +# -------------------------------------------------------- +# NVLM-D +# Copyright (c) 2024 NVIDIA +# Licensed under Apache 2.0 License [see LICENSE for details] +# -------------------------------------------------------- +from .internvl import InternVLChatConfig + + +class NVLM_D_Config(InternVLChatConfig): + model_type = 'NVLM_D' diff --git a/vllm/transformers_utils/configs/qwen2vl.py b/vllm/transformers_utils/configs/qwen2vl.py new file mode 100644 index 0000000..92dd962 --- /dev/null +++ b/vllm/transformers_utils/configs/qwen2vl.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Qwen2VL model configuration""" + +import os +from typing import Union + +from transformers import PretrainedConfig + + +class Qwen2VLVisionConfig(PretrainedConfig): + model_type = "qwen2_vl" + + def __init__( + self, + depth=32, + embed_dim=1280, + hidden_size=3584, + hidden_act="quick_gelu", + mlp_ratio=4, + num_heads=16, + in_channels=3, + patch_size=14, + spatial_merge_size=2, + temporal_patch_size=2, + **kwargs, + ): + super().__init__(**kwargs) + + self.depth = depth + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.hidden_act = hidden_act + self.mlp_ratio = mlp_ratio + self.num_heads = num_heads + self.in_channels = in_channels + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, + os.PathLike], + **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs) + + if config_dict.get("model_type") == "qwen2_vl": + config_dict = config_dict["vision_config"] + + return cls.from_dict(config_dict, **kwargs) + + +class Qwen2VLConfig(PretrainedConfig): + + def __init__( + self, + vocab_size=152064, + hidden_size=8192, + intermediate_size=29568, + num_hidden_layers=80, + num_attention_heads=64, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-05, + use_cache=True, + tie_word_embeddings=False, + rope_theta=1000000.0, + use_sliding_window=False, + sliding_window=4096, + max_window_layers=80, + attention_dropout=0.0, + vision_config=None, + rope_scaling=None, + **kwargs, + ): + if isinstance(vision_config, dict): + self.vision_config = Qwen2VLVisionConfig(**vision_config) + elif vision_config is None: + self.vision_config = Qwen2VLVisionConfig() + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.use_sliding_window = use_sliding_window + self.sliding_window = sliding_window + self.max_window_layers = max_window_layers + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.rope_scaling = rope_scaling + + # NOTE: the following section from original transformers config + # for Qwen2-VL is commented out to address rope config loading issue + # + # if self.rope_scaling is not None and "type" in self.rope_scaling: + # if self.rope_scaling["type"] == "mrope": + # self.rope_scaling["type"] = "default" + # self.rope_scaling["rope_type"] = self.rope_scaling["type"] + # rope_config_validation(self) + + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) diff --git a/vllm/transformers_utils/configs/solar.py b/vllm/transformers_utils/configs/solar.py new file mode 100644 index 0000000..d5113bf --- /dev/null +++ b/vllm/transformers_utils/configs/solar.py @@ -0,0 +1,245 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Solar model configuration""" + +from transformers import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class SolarConfig(PretrainedConfig): + r""" + This is the configuration class to store + the configuration of a [`SolarModel`]. + It is used to instantiate an LLaMA model + according to the specified arguments, + defining the model architecture. + Instantiating a configuration with the + defaults will yield a similar + configuration to that of the LLaMA-7B. + Configuration objects inherit from [`PretrainedConfig`] + and can be used to control the model outputs. + Read the documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. + Defines the number of different tokens + that can be represented by the `inputs_ids` + passed when calling [`SolarModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer + in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that + should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, + the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) + otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed + by meanpooling all the original heads within that group. + For more details checkout [this paper] + (https://arxiv.org/pdf/2305.13245.pdf). + If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) + in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Solar 1 supports up to 2048 tokens, + Solar 2 up to 4096, CodeSolar up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of + the truncated_normal_initializer for initializing + all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. 
+ use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return + the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank + used during pretraining. + Please refer to [this + document](https://huggingface.co/docs/ + transformers/main/ + perf_train_gpu_many#tensor-parallelism) + to understand more about it. This value is + necessary to ensure exact reproducibility + of the pretraining results. + Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for + the RoPE embeddings. + Currently supports two scaling + strategies: linear and dynamic. + Their scaling factor must be a float greater than 1. + The expected format is + `{"type": strategy name, "factor": scaling factor}`. + When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/ + dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking + API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value + and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj + layers in the MLP layers. + sliding_window (`int`, *optional*, defaults to 2047): + Sliding window attention window size. If not specified, + will default to `2047`. 
+ ```python + >>> from transformers import SolarModel, SolarConfig + >>> # Initializing a Solar-pro style configuration + >>> configuration = SolarConfig() + >>> # Initializing a model from the Solar-pro style configuration + >>> model = SolarModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "solar" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + sliding_window=2047, + bskcn_1=None, + bskcn_2=None, + bskcn_3=None, + bskcn_4=None, + bskcn_tv=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + self.sliding_window = sliding_window + self.bskcn_1 = bskcn_1 if bskcn_1 is not None else [12, 20, 32, 44] + self.bskcn_2 = bskcn_2 if bskcn_2 is not None else [20, 32] + self.bskcn_3 = bskcn_3 if bskcn_3 is not None else [16, 24, 36, 48] + self.bskcn_4 = bskcn_4 if bskcn_4 is not None else [28, 40] + self.bskcn_tv = bskcn_tv if bskcn_tv is not None else [0.9, 0.8] + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. 
+ """ + if self.rope_scaling is None: + return + + if (not isinstance(self.rope_scaling, dict) + or len(self.rope_scaling) != 2): + raise ValueError( + "`rope_scaling` must be a dictionary with two fields," + " `type` and `factor`, " + f"got {self.rope_scaling}") + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in [ + "linear", + "dynamic", + ]: + raise ValueError(f"`rope_scaling`'s type field must be one of " + f"['linear', 'dynamic'], got {rope_scaling_type}") + if (rope_scaling_factor is None + or not isinstance(rope_scaling_factor, float) + or rope_scaling_factor <= 1.0): + raise ValueError( + f"`rope_scaling`'s factor field must be a float > 1," + f" got {rope_scaling_factor}") diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py new file mode 100644 index 0000000..f724bf7 --- /dev/null +++ b/vllm/transformers_utils/configs/ultravox.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py +from typing import Any, Dict, Optional + +import transformers + + +class UltravoxConfig(transformers.PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`UltravoxForConditionalGeneration`]. It is used to instantiate an + Ultravox model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + Args: + audio_config (`Union[AutoConfig, dict]`, *optional*): + Custom audio config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` + or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + audio_token_index (`int`, *optional*, defaults to 32000): + The audio token index to encode the audio prompt. + stack_factor (`int`, *optional*, defaults to 8): + Audio downsampling factor for the multimodal projector. + norm_init (`float`, *optional*, defaults to 0.4): + The initialization value for the layer normalization. + projector_act (`str`, *optional*, defaults to `"swiglu"`): + The activation function used by the multimodal projector. + text_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the text model. + audio_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the audio model. 
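A minimal sketch of how the default sub-configs are resolved (illustrative; it assumes no `text_model_id`/`audio_model_id` is supplied, so the `CONFIG_MAPPING` fallback in the constructor below is taken):

```python
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

# Hypothetical usage with no pretrained sub-model ids.
cfg = UltravoxConfig()

# Without text_model_id/audio_model_id, the text and audio configs fall back
# to the transformers CONFIG_MAPPING defaults ("llama" and "whisper").
assert cfg.text_config.model_type == "llama"
assert cfg.audio_config.model_type == "whisper"
# vocab_size (and initializer_range) are copied from the text backbone.
assert cfg.vocab_size == cfg.text_config.vocab_size
```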
+ """ + + model_type = "ultravox" + is_composition = False + + def __init__( + self, + audio_config: Optional[Dict[str, Any]] = None, + text_config: Optional[Dict[str, Any]] = None, + audio_model_id: Optional[str] = None, + text_model_id: Optional[str] = None, + ignore_index: int = -100, + audio_token_index: int = 32000, + hidden_size: int = 4096, + stack_factor: int = 8, + norm_init: float = 0.4, + projector_act: str = "swiglu", + text_model_lora_config: Optional[Dict[str, Any]] = None, + audio_model_lora_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.ignore_index = ignore_index + + self.audio_model_id = audio_model_id + self.text_model_id = text_model_id + self.audio_token_index = audio_token_index + + self.hidden_size = hidden_size + self.stack_factor = stack_factor + self.norm_init = norm_init + self.projector_act = projector_act + + if text_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.text_config = get_config(text_model_id, + trust_remote_code=False) + else: + text_config = text_config or {} + self.text_config = transformers.CONFIG_MAPPING[text_config.get( + "model_type", "llama")](**text_config) + + if audio_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(audio_model_id, + trust_remote_code=False) + else: + audio_config = audio_config or {} + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( + "model_type", "whisper")](**audio_config) + + self.text_model_lora_config = text_model_lora_config or {} + self.audio_model_lora_config = audio_model_lora_config or {} + + self.vocab_size = self.text_config.vocab_size + + self.initializer_range = self.text_config.initializer_range + + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py new file mode 100644 index 0000000..2b418f3 --- /dev/null +++ b/vllm/transformers_utils/detokenizer.py @@ -0,0 +1,327 @@ +from typing import Dict, List, Optional, Tuple + +from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams, + Sequence, SequenceGroup) + +from .tokenizer import AnyTokenizer +from .tokenizer_group import BaseTokenizerGroup + + +class Detokenizer: + """Provides methods to decode the output of a model into text.""" + + def __init__(self, tokenizer_group: BaseTokenizerGroup): + self.tokenizer_group = tokenizer_group + + def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer: + """Returns the HF tokenizer to use for a given sequence.""" + return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request) + + def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, + prompt_logprobs: List[Optional[Dict[ + int, Logprob]]], + position_offset: int) -> None: + """Decodes the logprobs for the prompt of a sequence group. + + Args: + seq_group: The sequence group to decode. + prompt_logprobs: The logprobs to decode. + position_offset: Offset of the first index of the logprobs + relative to the start of the sequence (for chunked prefill). + + Returns: + The prompt logprobs with the decoded tokens. + """ + prms = seq_group.sampling_params + assert prms is not None + + # We can pick any sequence for the prompt. + seq = seq_group.get_seqs()[0] + # Only prompt, without the generated token. 
+ all_token_ids = seq.get_token_ids() + prompt_token_ids = all_token_ids[:-1] + tokenizer = self.get_tokenizer_for_seq(seq) + prefix_offset = 0 + read_offset = 0 + next_iter_prefix_offset = 0 + next_iter_read_offset = 0 + next_iter_tokens: List[str] = [] + prev_tokens = None + + for token_position_in_logprob, prompt_logprobs_for_token in enumerate( + prompt_logprobs): + + # Absolute token position equals the index in the logprobs + # list plus the offset of the entire logprobs list relative + # to the start of the sequence. + token_position = token_position_in_logprob + position_offset + if not prompt_logprobs_for_token: + continue + for token_id, sample_logprob in prompt_logprobs_for_token.items(): + if (sample_logprob.decoded_token is None + and token_id != VLLM_INVALID_TOKEN_ID): + prompt_token_ids_with_token = ( + prompt_token_ids[:token_position] + [token_id]) + (new_tokens, new_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=prompt_token_ids_with_token, + prev_tokens=prev_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. + spaces_between_special_tokens, + ) + + sample_logprob.decoded_token = new_text + + # Use the offsets & prev tokens corresponding to + # real tokens to ensure detokenization is consistent + # actual with prompt. + if token_id == all_token_ids[token_position]: + next_iter_prefix_offset = new_prefix_offset + next_iter_read_offset = new_read_offset + next_iter_tokens = new_tokens + + # Advance to the next token position. + prefix_offset = next_iter_prefix_offset + read_offset = next_iter_read_offset + if prev_tokens is None: + prev_tokens = next_iter_tokens + else: + prev_tokens.extend(next_iter_tokens) + + def decode_sequence_inplace(self, seq: Sequence, + prms: SamplingParams) -> int: + """Decodes the new token for a sequence. In-place operation. + + Args: + seq: The sequence to decode. + prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. + """ + all_input_ids = seq.get_token_ids() + token_id_generated_this_iteration = all_input_ids[-1] + tokenizer = self.get_tokenizer_for_seq(seq) + + # Convert prompt token IDs to tokens if necessary. + # Do it here so that we don't have to repeat this + # computation for each logprob. + if seq.tokens is None: + (seq.tokens, seq.prefix_offset, + seq.read_offset) = convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=all_input_ids[:-1], + skip_special_tokens=prms.skip_special_tokens, + ) + + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=all_input_ids, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + + # Decode logprobs + logprobs = seq.output_logprobs[-1] + if logprobs: + previous_tokens = all_input_ids[:-1] + for token_id, sample_logprob in logprobs.items(): + # If the token was generated this iteration, + # use the provided text. 
+ if token_id == token_id_generated_this_iteration: + sample_logprob.decoded_token = new_decoded_token_text + continue + + if (sample_logprob.decoded_token is None + and token_id != VLLM_INVALID_TOKEN_ID): + all_input_ids_with_logprob = previous_tokens + [token_id] + (_, new_text, _, _) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=all_input_ids_with_logprob, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. + spaces_between_special_tokens, + ) + sample_logprob.decoded_token = new_text + + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_decoded_token_text + + return len(new_decoded_token_text) + + +def _replace_none_with_empty(tokens: List[Optional[str]]): + for i, token in enumerate(tokens): + if token is None: + tokens[i] = "" + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: AnyTokenizer, + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. + sub_texts: List[str] = [] + current_sub_text: List[str] = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: AnyTokenizer, + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. 
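+    # (Illustration: with INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5, only
+    #  the last 7 prompt ids, or fewer for very short prompts, are materialized
+    #  as token strings; the prefix_offset/read_offset returned below index
+    #  into this short suffix rather than into the full prompt.)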
+ new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + # This is required to guard against out-of-vocab prompt token ids + _replace_none_with_empty(new_tokens) # type: ignore[arg-type] + return new_tokens, prefix_offset, read_offset + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: AnyTokenizer, + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. + prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. + """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None + + # If the new token id is out of bounds, return an empty string. + if 0 <= new_token_id < len(tokenizer): + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] + else: + new_tokens = [""] + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. 
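+    # (For example, a SentencePiece-style token such as "▁world" typically
+    #  decodes to "world" on its own but to " world" when other tokens precede
+    #  it; decoding both the prefix window and the full window with the same
+    #  context and stripping the common prefix below recovers only the newly
+    #  generated text.)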
+ if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py new file mode 100644 index 0000000..98663f7 --- /dev/null +++ b/vllm/transformers_utils/processor.py @@ -0,0 +1,94 @@ +from typing import Any, cast + + +def get_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load a processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor + from transformers.processing_utils import ProcessorMixin + + try: + processor = AutoProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the processor. If the processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(ProcessorMixin, processor) + + +def get_image_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load an image processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoImageProcessor + from transformers.image_processing_utils import BaseImageProcessor + + try: + processor = AutoImageProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the image processor. 
If the image processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return cast(BaseImageProcessor, processor) + + +def get_video_processor( + processor_name: str, + *args: Any, + trust_remote_code: bool = False, + **kwargs: Any, +): + """Load a video processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers.image_processing_utils import BaseImageProcessor + + processor = get_processor( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs, + ) + + return cast(BaseImageProcessor, processor.video_processor) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py new file mode 100644 index 0000000..94af238 --- /dev/null +++ b/vllm/transformers_utils/tokenizer.py @@ -0,0 +1,193 @@ +import os +import warnings +from pathlib import Path +from types import MethodType +from typing import Optional, Union + +import huggingface_hub +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizers import MistralTokenizer +from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import make_async + +logger = init_logger(__name__) + +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, + MistralTokenizer] + + +def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. 
This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_len = len(tokenizer) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + +def patch_padding_side(tokenizer: PreTrainedTokenizer) -> None: + """Patch _pad method to accept `padding_side` for older tokenizers.""" + orig_pad = tokenizer._pad + + def _pad( + self: PreTrainedTokenizer, + *args, + padding_side: Optional[str] = None, + **kwargs, + ): + if padding_side is not None and padding_side != self.padding_side: + msg = ("`padding_side` argument is not supported by " + f"{type(tokenizer).__name__} and will be ignored.") + warnings.warn(msg, stacklevel=2) + + return orig_pad(*args, **kwargs) + + tokenizer._pad = MethodType(_pad, tokenizer) + + +def get_tokenizer( + tokenizer_name: Union[str, Path], + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + **kwargs, +) -> AnyTokenizer: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. 
+ ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + tokenizer_name = tokenizer_path + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "truncation_side" not in kwargs: + kwargs["truncation_side"] = "left" + + # Separate model folder from file path for GGUF models + is_gguf = check_gguf_file(tokenizer_name) + if is_gguf: + kwargs["gguf_file"] = Path(tokenizer_name).name + tokenizer_name = Path(tokenizer_name).parent + + # if tokenizer is from official mistral org + is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" + if is_from_mistral_org and tokenizer_mode != "mistral": + warnings.warn( + 'It is strongly recommended to run mistral models with ' + '`--tokenizer_mode "mistral"` to ensure correct ' + 'encoding and decoding.', + FutureWarning, + stacklevel=2) + if tokenizer_mode == "mistral": + tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), + revision=revision) + else: + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, + # suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e)): + err_msg = ("Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + # NOTE: We can remove this after https://github.com/THUDM/ChatGLM3/issues/1324 + if type(tokenizer).__name__ in ("ChatGLMTokenizer", + "ChatGLM4Tokenizer"): + assert isinstance(tokenizer, PreTrainedTokenizer) + patch_padding_side(tokenizer) + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead.") + tokenizer = get_cached_tokenizer(tokenizer) + + return tokenizer + + +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[AnyTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_path, *args, **kwargs) + except Exception as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + "No tokenizer found in %s, using base model tokenizer instead. 
" + "(Exception: %s)", lora_request.lora_path, e) + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) diff --git a/vllm/transformers_utils/tokenizer_group/__init__.py b/vllm/transformers_utils/tokenizer_group/__init__.py new file mode 100644 index 0000000..9a41492 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/__init__.py @@ -0,0 +1,52 @@ +from typing import Optional, Type + +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + TokenizerPoolConfig) +from vllm.executor.ray_utils import ray + +from .base_tokenizer_group import AnyTokenizer, BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup + +if ray: + from .ray_tokenizer_group import RayTokenizerGroupPool +else: + RayTokenizerGroupPool = None # type: ignore + + +def init_tokenizer_from_configs(model_config: ModelConfig, + scheduler_config: SchedulerConfig, + parallel_config: ParallelConfig, + enable_lora: bool): + init_kwargs = dict(tokenizer_id=model_config.tokenizer, + enable_lora=enable_lora, + max_num_seqs=scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=model_config.tokenizer_mode, + trust_remote_code=model_config.trust_remote_code, + revision=model_config.tokenizer_revision) + + return get_tokenizer_group(parallel_config.tokenizer_pool_config, + **init_kwargs) + + +def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> BaseTokenizerGroup: + tokenizer_cls: Type[BaseTokenizerGroup] + if tokenizer_pool_config is None: + tokenizer_cls = TokenizerGroup + elif isinstance(tokenizer_pool_config.pool_type, type) and issubclass( + tokenizer_pool_config.pool_type, BaseTokenizerGroup): + tokenizer_cls = tokenizer_pool_config.pool_type + elif tokenizer_pool_config.pool_type == "ray": + if RayTokenizerGroupPool is None: + raise ImportError( + "RayTokenizerGroupPool is not available. 
Please install " + "the ray package to use the Ray tokenizer group pool.") + tokenizer_cls = RayTokenizerGroupPool + else: + raise ValueError( + f"Unknown pool type: {tokenizer_pool_config.pool_type}") + return tokenizer_cls.from_config(tokenizer_pool_config, **init_kwargs) + + +__all__ = ["AnyTokenizer", "get_tokenizer_group", "BaseTokenizerGroup"] diff --git a/vllm/transformers_utils/tokenizer_group/__pycache__/__init__.cpython-310.pyc b/vllm/transformers_utils/tokenizer_group/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..65433db Binary files /dev/null and b/vllm/transformers_utils/tokenizer_group/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/transformers_utils/tokenizer_group/__pycache__/base_tokenizer_group.cpython-310.pyc b/vllm/transformers_utils/tokenizer_group/__pycache__/base_tokenizer_group.cpython-310.pyc new file mode 100644 index 0000000..63750e8 Binary files /dev/null and b/vllm/transformers_utils/tokenizer_group/__pycache__/base_tokenizer_group.cpython-310.pyc differ diff --git a/vllm/transformers_utils/tokenizer_group/__pycache__/ray_tokenizer_group.cpython-310.pyc b/vllm/transformers_utils/tokenizer_group/__pycache__/ray_tokenizer_group.cpython-310.pyc new file mode 100644 index 0000000..43c7d5f Binary files /dev/null and b/vllm/transformers_utils/tokenizer_group/__pycache__/ray_tokenizer_group.cpython-310.pyc differ diff --git a/vllm/transformers_utils/tokenizer_group/__pycache__/tokenizer_group.cpython-310.pyc b/vllm/transformers_utils/tokenizer_group/__pycache__/tokenizer_group.cpython-310.pyc new file mode 100644 index 0000000..4554869 Binary files /dev/null and b/vllm/transformers_utils/tokenizer_group/__pycache__/tokenizer_group.cpython-310.pyc differ diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py new file mode 100644 index 0000000..8f78ef6 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py @@ -0,0 +1,66 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import AnyTokenizer + + +class BaseTokenizerGroup(ABC): + """A group of tokenizers that can be used for LoRA adapters.""" + + @classmethod + @abstractmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "BaseTokenizerGroup": + pass + + @abstractmethod + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + pass + + @abstractmethod + def get_max_input_len( + self, + lora_request: Optional[LoRARequest] = None, + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + pass + + @abstractmethod + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group.""" + pass + + @abstractmethod + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + """Get a tokenizer for a LoRA request.""" + pass + + @abstractmethod + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + """Get a 
tokenizer for a LoRA request.""" + pass + + def check_health(self): + """Raise exception if the tokenizer group is unhealthy.""" + return diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py new file mode 100644 index 0000000..9a999a0 --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py @@ -0,0 +1,240 @@ +import asyncio +import os +from typing import List, Optional + +try: + from ray.exceptions import ActorDiedError # type: ignore +except ImportError: + # For older versions of Ray + from ray.exceptions import RayActorError as ActorDiedError # type: ignore +from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy + +from vllm.config import TokenizerPoolConfig +from vllm.executor.ray_utils import ray +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import AnyTokenizer + +from .base_tokenizer_group import BaseTokenizerGroup +from .tokenizer_group import TokenizerGroup + +logger = init_logger(__name__) + + +class RayTokenizerGroupPool(BaseTokenizerGroup): + """A Ray-based pool of TokenizerGroups for async tokenization.""" + + # Class to use for workers making up the pool. + _worker_cls = TokenizerGroup + + @classmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "RayTokenizerGroupPool": + if not tokenizer_pool_config: + raise ValueError("tokenizer_pool_config must not be None.") + ray_actor_options = (tokenizer_pool_config.extra_config or { + "num_cpus": 0 + }) + ray_actor_options.setdefault( + "scheduling_strategy", + NodeAffinitySchedulingStrategy( + node_id=ray.get_runtime_context().get_node_id(), soft=True)) + + # Carry over the env vars to the actors. + # This is necessary for API keys and such. + ray_actor_options.setdefault("runtime_env", {}) + _carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"]) + + init_kwargs["num_actors"] = tokenizer_pool_config.pool_size + init_kwargs["ray_actor_options"] = ray_actor_options + + return cls(**init_kwargs) + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], num_actors: int, + ray_actor_options: dict, **tokenizer_config): + # Store a local copy of the TokenizerGroup for quick access + # to underlying HF tokenizers. + self._tokenizer_config = { + "tokenizer_id": tokenizer_id, + "enable_lora": enable_lora, + "max_num_seqs": max_num_seqs, + "max_input_length": max_input_length, + **tokenizer_config + } + self._local_tokenizer_group = self._worker_cls( + **self._tokenizer_config, ) + + self._ray_tokenizer_group_cls = ray.remote( + self._worker_cls).options(**ray_actor_options) # type: ignore + self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)] + self._idle_actors: Optional[asyncio.Queue] = None + + # If set, actor is unhealthy. Will reraise on the next + # check_health call. 
+ self._exception: Optional[ActorDiedError] = None + + def _init_actor(self) -> ray.ObjectRef: + return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config) + + @property + def pool_size(self) -> int: + return len(self.tokenizer_actors) + + def ping(self): + return ray.get([ + actor.ping.remote() # type: ignore + for actor in self.tokenizer_actors + ]) + + def _ensure_queue_initialized(self): + if self._idle_actors is None: + self._idle_actors = asyncio.Queue() + for actor in self.tokenizer_actors: + self._idle_actors.put_nowait(actor) + + def _finalize_encode(self, actor: ray.ObjectRef, + original_actor: ray.ObjectRef, actor_is_alive: bool): + assert self._idle_actors is not None + # Cleanup the dead actor. + if not actor_is_alive or original_actor is not actor: + self.tokenizer_actors.remove(original_actor) + if actor_is_alive: + # Put the actor back in the queue. + # This is done in a finally block to ensure that the actor is + # always put back in the queue, even if an exception/cancellation + # is raised. + self._idle_actors.put_nowait(actor) + # Add back the new actor. + if original_actor is not actor: + self.tokenizer_actors.append(actor) + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + The actor is then put back in the queue for future use. + This is blocking. + """ + self.check_health() + self._ensure_queue_initialized() + assert self._idle_actors is not None + + if self._idle_actors.empty(): + raise RuntimeError("No idle actors available.") + actor = self._idle_actors.get_nowait() + actor_is_alive = True + original_actor = actor + try: + ret = ray.get( + actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request)) + except ActorDiedError as e: + # If the actor is dead, we first try to reinitialize it. + logger.warning("%s died with ActorDiedError, reinitializing.", + actor, + exc_info=e) + actor = self._init_actor() + try: + ret = ray.get( + actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request)) + except ActorDiedError as e: + logger.error( + "%s died for second time in a row, marking " + "RayTokenizerGroupPool as unhealthy.", actor) + actor_is_alive = False + if not self._exception: + self._exception = e + self.check_health() + finally: + self._finalize_encode(actor, original_actor, actor_is_alive) + return ret + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + """Encode a prompt using the tokenizer group. + + We pick an idle actor and use it to encode the prompt. + If there are no idle actors, we wait until one becomes + available. + The actor is then put back in the queue for future use. + This is non-blocking. + """ + self.check_health() + self._ensure_queue_initialized() + assert self._idle_actors is not None + + actor = await self._idle_actors.get() + actor_is_alive = True + original_actor = actor + try: + ret = await actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + except ActorDiedError as e: + # If the actor is dead, we first try to reinitialize it. 
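+            # A second ActorDiedError on the retried call marks the whole
+            # pool unhealthy via self._exception.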
+ logger.warning("%s died with ActorDiedError, reinitializing.", + actor, + exc_info=e) + actor = self._init_actor() + try: + ret = await actor.encode.remote(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + except ActorDiedError as e: + logger.error( + "%s died for second time in a row, marking " + "RayTokenizerGroupPool as unhealthy.", actor) + actor_is_alive = False + if not self._exception: + self._exception = e + self.check_health() + finally: + self._finalize_encode(actor, original_actor, actor_is_alive) + return ret + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self._local_tokenizer_group.get_max_input_len(lora_request) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return self._local_tokenizer_group.get_lora_tokenizer(lora_request) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + return await self._local_tokenizer_group.get_lora_tokenizer_async( + lora_request) + + def check_health(self): + if self._exception: + raise RuntimeError( + "TokenizerGroupPool is unhealthy.") from self._exception + + +def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None: + """Copy over all current process environment variables to the runtime_env. + + The variables in runtime_env will take precedence over the current process + environment variables. + + runtime_env will be modified in place.""" + env_vars = os.environ.copy() + runtime_env.setdefault("env_vars", {}) + env_vars.update(runtime_env["env_vars"]) + runtime_env["env_vars"] = env_vars diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py new file mode 100644 index 0000000..e516eea --- /dev/null +++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py @@ -0,0 +1,99 @@ +from typing import List, Optional + +from vllm.config import TokenizerPoolConfig +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + get_lora_tokenizer, + get_lora_tokenizer_async, + get_tokenizer) +from vllm.utils import LRUCache + +from .base_tokenizer_group import BaseTokenizerGroup + + +class TokenizerGroup(BaseTokenizerGroup): + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + self.lora_tokenizers = LRUCache[AnyTokenizer]( + capacity=max_num_seqs if enable_lora else 0) + + @classmethod + def from_config(cls, tokenizer_pool_config: Optional[TokenizerPoolConfig], + **init_kwargs) -> "TokenizerGroup": + return cls(**init_kwargs) + + def ping(self) -> bool: + """Check if the tokenizer group is alive.""" + return True + + def get_max_input_len(self, + lora_request: Optional[LoRARequest] = None + ) -> Optional[int]: + """Get the maximum input length for the LoRA request.""" + return self.max_input_length + + def _raise_if_input_too_long(self, + encoded_tokens: List[int], + lora_request: Optional[LoRARequest] = None): + input_length = len(encoded_tokens) + if lora_request: + max_input_length = 
(lora_request.long_lora_max_len + or self.max_input_length) + else: + max_input_length = self.max_input_length + if max_input_length is not None and input_length > max_input_length: + raise ValueError("Input too long.", input_length, max_input_length) + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + ret = tokenizer.encode(prompt) + self._raise_if_input_too_long(ret, lora_request) + return ret + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers[lora_request.lora_int_id] + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers[lora_request.lora_int_id] diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py new file mode 100644 index 0000000..5f437d4 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -0,0 +1,3 @@ +from .mistral import MistralTokenizer + +__all__ = ["MistralTokenizer"] diff --git a/vllm/transformers_utils/tokenizers/__pycache__/__init__.cpython-310.pyc b/vllm/transformers_utils/tokenizers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..e4bba79 Binary files /dev/null and b/vllm/transformers_utils/tokenizers/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/transformers_utils/tokenizers/__pycache__/mistral.cpython-310.pyc b/vllm/transformers_utils/tokenizers/__pycache__/mistral.cpython-310.pyc new file mode 100644 index 0000000..27bceff Binary files /dev/null and b/vllm/transformers_utils/tokenizers/__pycache__/mistral.cpython-310.pyc differ diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py new file mode 100644 index 0000000..aae10d3 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -0,0 +1,229 @@ +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from huggingface_hub import HfApi, hf_hub_download +# yapf: disable +from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import ( + MistralTokenizer as PublicMistralTokenizer) +# yapf: enable +from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer) +from mistral_common.tokens.tokenizers.tekken import 
(SpecialTokenPolicy, + Tekkenizer) + +if TYPE_CHECKING: + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam + + +@dataclass +class Encoding: + input_ids: List[int] + + +def find_tokenizer_file(files: List[str]): + file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$") + + matched_files = [file for file in files if file_pattern.match(file)] + if len(matched_files) > 1: + raise OSError(f"Found {len(matched_files)} files matching the " + f"pattern: {file_pattern}. Make sure only one Mistral " + f"tokenizer is present in {files}.") + elif len(matched_files) == 0: + raise OSError(f"Found {len(matched_files)} files matching the " + f"pattern: {file_pattern}. Make sure that a Mistral " + f"tokenizer is present in {files}.") + + return matched_files[0] + + +class MistralTokenizer: + + def __init__(self, tokenizer: PublicMistralTokenizer) -> None: + self.mistral = tokenizer + self.instruct = tokenizer.instruct_tokenizer + + tokenizer_ = tokenizer.instruct_tokenizer.tokenizer + if isinstance(tokenizer_, Tekkenizer): + # Make sure special tokens will not raise + tokenizer_.special_token_policy = SpecialTokenPolicy.IGNORE + + self._vocab = { + token: idx + for idx, token in enumerate(tokenizer_.vocab()) + } + elif isinstance(tokenizer_, SentencePieceTokenizer): + self._vocab = { + token: idx + for idx, token in enumerate(tokenizer_.vocab()) + } + else: + raise TypeError(f"Unsupported tokenizer: {type(tokenizer_)}") + + self.tokenizer = tokenizer_ + + @classmethod + def from_pretrained(cls, + path_or_repo_id: str, + *, + revision: Optional[str] = None) -> "MistralTokenizer": + if not Path(path_or_repo_id).exists(): + assert len(path_or_repo_id.split("/")) == 2, ( + "You have either provided a non-existent path: " + "{path_or_repo_id} or an invalid HF Hub repo id.") + tokenizer_file = cls._download_mistral_tokenizer_from_hf( + path_or_repo_id, revision) + elif Path(path_or_repo_id).is_dir(): + tokenizer_file_name = find_tokenizer_file( + os.listdir(path_or_repo_id)) + tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) + else: + assert Path( + path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" + + mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) + return cls(mistral_tokenizer) + + @staticmethod + def _download_mistral_tokenizer_from_hf(tokenizer_name: str, + revision: Optional[str]) -> str: + api = HfApi() + repo_info = api.model_info(tokenizer_name) + files = [s.rfilename for s in repo_info.siblings] + + filename = find_tokenizer_file(files) + + tokenizer_file = hf_hub_download(tokenizer_name, + filename=filename, + revision=revision) + return tokenizer_file + + # the following attributes are set to fit VLLM's design + @property + def all_special_tokens_extended(self) -> List[str]: + return [] + + @property + def all_special_tokens(self) -> List[str]: + return [] + + @property + def all_special_ids(self) -> List[int]: + return [] + + @property + def bos_token_id(self) -> int: + return self.tokenizer.bos_id + + @property + def eos_token_id(self) -> int: + return self.tokenizer.eos_id + + @property + def is_fast(self) -> bool: + return True + + @property + def vocab_size(self) -> int: + return len(self._vocab) + + def __len__(self) -> int: + return self.vocab_size + + def __call__( + self, + prompt: str, + add_special_tokens: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + ): + # Mistral Tokenizers should not add special tokens + input_ids = self.encode(prompt) + + if truncation: + input_ids = 
input_ids[:max_length] + + return Encoding(input_ids=input_ids) + + def get_vocab(self) -> Dict[str, int]: + return self._vocab + + def get_added_vocab(self) -> Dict[str, int]: + # Mistral tokenizers have no added vocabulary + return {} + + def encode(self, prompt: str) -> List[int]: + # `encode` should only be used for prompt completion + # it should never be used for chat_completion. + # For chat completion use `apply_chat_template` + return self.tokenizer.encode(prompt, bos=True, eos=False) + + def apply_chat_template(self, + messages: List["ChatCompletionMessageParam"], + tools: Optional[Dict[str, Any]] = None, + **kwargs) -> List[int]: + + request = ChatCompletionRequest(messages=messages, + tools=tools) # type: ignore[type-var] + encoded = self.mistral.encode_chat_completion(request) + + # encode-decode to get clean prompt + return encoded.tokens + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + if isinstance(self.tokenizer, Tekkenizer): + tokens = [ + t for t in tokens + if t not in self.tokenizer._all_special_tokens + ] + + if any(isinstance(t, bytes) for t in tokens): + # we need to encode and decode all tokens again + shift = self.tokenizer.num_special_tokens + byte_tokens = [ + t.encode("utf-8") if not isinstance(t, bytes) else t + for t in tokens + ] + ids = [ + self.tokenizer._tekken_token2id_nospecial[t] + shift + for t in byte_tokens + ] + decoded = self.tokenizer.decode(ids) + else: + decoded = "".join(tokens) + else: + decoded = self.tokenizer.decode(tokens) # type: ignore[arg-type] + + return decoded + + def decode(self, ids: Union[List[int], int]) -> str: + if isinstance(ids, int): + ids = [ids] + return self.tokenizer.decode(ids) + + def convert_ids_to_tokens( + self, + ids: List[int], + skip_special_tokens: bool = True, + ) -> List[str]: + # TODO(Patrick) - potentially allow special tokens to not be skipped + assert ( + skip_special_tokens + ), "Skipping special tokens is not supported for Mistral tokenizers." 
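+        # Both supported backends (Tekkenizer and SentencePieceTokenizer)
+        # expose id_to_piece(); the isinstance assert below guards that call.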
+ + assert isinstance(self.tokenizer, + (Tekkenizer, SentencePieceTokenizer)), type( + self.tokenizer) + + tokens = [self.tokenizer.id_to_piece(id) for id in ids] + + if any(t.strip() == "�" for t in tokens): + # if any stripped decoded token is undefined + # because it's invalid unicode then pass bytes + # See: https://github.com/vllm-project/vllm/pull/8640 + tokens = [self.tokenizer.id_to_byte_piece(id) for id in ids] + + return tokens diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py new file mode 100644 index 0000000..7a9041b --- /dev/null +++ b/vllm/transformers_utils/utils.py @@ -0,0 +1,16 @@ +from os import PathLike +from pathlib import Path +from typing import Union + + +def check_gguf_file(model: Union[str, PathLike]) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" diff --git a/vllm/triton_utils/__init__.py b/vllm/triton_utils/__init__.py new file mode 100644 index 0000000..4e19581 --- /dev/null +++ b/vllm/triton_utils/__init__.py @@ -0,0 +1,9 @@ +from vllm.triton_utils.importing import HAS_TRITON + +__all__ = ["HAS_TRITON"] + +#from vllm.triton_utils.custom_cache_manager import ( +# maybe_set_triton_cache_manager) +#from vllm.triton_utils.libentry import libentry + +__all__ += ["maybe_set_triton_cache_manager", "libentry"] diff --git a/vllm/triton_utils/__pycache__/__init__.cpython-310.pyc b/vllm/triton_utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..9be72bc Binary files /dev/null and b/vllm/triton_utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-310.pyc b/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-310.pyc new file mode 100644 index 0000000..63d5cd3 Binary files /dev/null and b/vllm/triton_utils/__pycache__/custom_cache_manager.cpython-310.pyc differ diff --git a/vllm/triton_utils/__pycache__/importing.cpython-310.pyc b/vllm/triton_utils/__pycache__/importing.cpython-310.pyc new file mode 100644 index 0000000..81c44b1 Binary files /dev/null and b/vllm/triton_utils/__pycache__/importing.cpython-310.pyc differ diff --git a/vllm/triton_utils/__pycache__/libentry.cpython-310.pyc b/vllm/triton_utils/__pycache__/libentry.cpython-310.pyc new file mode 100644 index 0000000..5e74d2b Binary files /dev/null and b/vllm/triton_utils/__pycache__/libentry.cpython-310.pyc differ diff --git a/vllm/triton_utils/custom_cache_manager.py b/vllm/triton_utils/custom_cache_manager.py new file mode 100644 index 0000000..17039d7 --- /dev/null +++ b/vllm/triton_utils/custom_cache_manager.py @@ -0,0 +1,53 @@ +import os + +from triton.runtime.cache import (FileCacheManager, default_cache_dir, + default_dump_dir, default_override_dir) + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def maybe_set_triton_cache_manager() -> None: + """Set environment variable to tell Triton to use a + custom cache manager""" + cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None) + if cache_manger is None: + manager = "vllm.triton_utils.custom_cache_manager:CustomCacheManager" + logger.info("Setting Triton cache manager to: %s", manager) + os.environ["TRITON_CACHE_MANAGER"] = manager + + +class CustomCacheManager(FileCacheManager): + """Re-implements Triton's cache manager, ensuring that a + unique cache directory is created for each process. 
This is + needed to avoid collisions when running with tp>1 and + using multi-processing as the distributed backend. + + Note this issue was fixed by triton-lang/triton/pull/4295, + but the fix is not yet included in triton==v3.0.0. However, + it should be included in the subsequent version. + """ + + def __init__(self, key, override=False, dump=False): + self.key = key + self.lock_path = None + if dump: + self.cache_dir = default_dump_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + elif override: + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) + else: + # create cache directory if it doesn't exist + self.cache_dir = os.getenv("TRITON_CACHE_DIR", + "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = f"{self.cache_dir}_{os.getpid()}" + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py new file mode 100644 index 0000000..760be67 --- /dev/null +++ b/vllm/triton_utils/importing.py @@ -0,0 +1,11 @@ +from importlib.util import find_spec + +from vllm.logger import init_logger + +logger = init_logger(__name__) + +HAS_TRITON = False + +if not HAS_TRITON: + logger.info("Triton not installed; certain GPU-related functions" + " will not be available.") diff --git a/vllm/triton_utils/libentry.py b/vllm/triton_utils/libentry.py new file mode 100644 index 0000000..4335c7a --- /dev/null +++ b/vllm/triton_utils/libentry.py @@ -0,0 +1,167 @@ +# Copied From https://github.com/FlagOpen/FlagGems + +import inspect + +import triton + + +class LibEntry(triton.KernelInterface): + + def __init__( + self, + fn, + ): + self.fn = fn + self.arg_names = fn.arg_names + self.divisibility = 16 + self.kernel_cache = dict() + fn = self.fn + while not isinstance(fn, triton.runtime.JITFunction): + fn = fn.fn + self.jit_function: triton.runtime.JITFunction = fn + self.specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and not p.do_not_specialize + ] + self.do_not_specialize_indices = [ + p.num for p in self.jit_function.params + if not p.is_constexpr and p.do_not_specialize + ] + + def key(self, spec_args, dns_args, const_args): + spec_key = [(arg.dtype, arg.data_ptr() % + self.divisibility == 0) if hasattr(arg, "data_ptr") else + (type(arg), arg) for arg in spec_args] + dns_key = [ + arg.dtype if hasattr( + arg, "data_ptr") else type(arg) if not isinstance(arg, int) + else "i32" if arg >= -(2**31) and arg <= 2**31 - + 1 else "u64" if arg >= 2**63 and arg <= 2**64 - 1 else "i64" + for arg in dns_args + ] + # const args passed by position + return tuple(spec_key + dns_key + const_args) + + def run(self, *args, **kwargs): + grid = kwargs["grid"] + # collect all the arguments + spec_args = [] # specialize arguments + dns_args = [] # do not specialize arguments + const_args = [] # constexpr arguments + k_args = [] # kernel arguments + for i, arg in enumerate(args): + if i in self.specialize_indices: + k_args.append(arg) + spec_args.append(arg) + elif i in self.do_not_specialize_indices: + k_args.append(arg) + dns_args.append(arg) + else: + const_args.append(arg) + for p in self.jit_function.params[len(args):]: + if p.name in kwargs: + val = kwargs[p.name] 
+ elif p.default is inspect._empty: + continue + else: + val = p.default + + if p.is_constexpr: + const_args.append(val) + elif p.do_not_specialize: + dns_args.append(val) + k_args.append(val) + else: + spec_args.append(val) + k_args.append(val) + + entry_key = self.key(spec_args, dns_args, const_args) + + if entry_key not in self.kernel_cache: + # compile the kernel also completes the related computations + kernel = self.fn.run(*args, **kwargs) + fn = self.fn + # collect constexpr arguments for grid computation + constexprs = {} + while not isinstance(fn, triton.runtime.JITFunction): + if isinstance(fn, triton.runtime.Autotuner): + config = fn.best_config + constexprs["num_warps"] = config.num_warps + constexprs["num_stages"] = config.num_stages + constexprs["num_ctas"] = config.num_ctas + constexprs = {**constexprs, **config.kwargs} + elif isinstance(fn, triton.runtime.Heuristics): + for v, heur in fn.values.items(): + constexprs[v] = heur({ + **dict(zip(fn.arg_names, args)), + **kwargs, + **constexprs, + }) + else: + raise RuntimeError("Invalid Runtime Function") + fn = fn.fn + # In vLLM, certain kernels like fused_moe_kernel get the + # best_config(as kwargs) from a configuration json file, rather + # than using Autotuner & Heuristics. Therefore, all their constexprs + # (tl.constexpr) are assigned values through the following loop. + for p in self.jit_function.params: + if p.is_constexpr and p.name not in constexprs: + constexprs[p.name] = p.default #default=inspect._empty + self.kernel_cache[entry_key] = (kernel, constexprs) + else: + # load kernel from cache directly + kernel, constexprs = self.kernel_cache[entry_key] + + if callable(grid): + # collect all arguments to the grid fn,ie: + # 1. args, + # 2. kwargs, + # 3. all all other captured arguments in CompiledKernel from + # Autotunner & Heuristics when kwargs & captured args conflict, + # captured args have higher priority + # 4. We must filter out captured args with default value firstly + constexprs = { + k: v + for k, v in constexprs.items() if v is not inspect._empty + } + meta = { + **dict(zip(self.arg_names, args)), + **kwargs, + **constexprs, + } + grid = grid(meta) + if isinstance(grid, tuple): + grid = grid + (1, 1) + elif isinstance(grid, list): + grid = grid + [1, 1] + kernel[grid[0:3]](*k_args) + # maintaining the same return type as the JITFunction.run + return kernel + + +def libentry(): + """ + Decorator for triton library entries. + Motivation: + The runtime overhead of Triton kernels is the reason for the lower + performance of small kernels, particularly evident with smaller models. + Using this decorator can reduce Triton runtime overhead. + How: + The `run` function of JITFunction needs to accomplish: + - Parameter binding using inspect + - KernelArg type wrapping + - Cache key calculation + When dealing with small size, these steps can become bottlenecks in + Triton runtime. Libentry simplifies these steps to reduce runtime + overhead, thereby improving the runtime expenses of small kernels. 
+ NOTE: + When Triton is upgraded to version 3.0.0, libentry can be removed, + see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245 + + + """ + + def decorator(fn): + return LibEntry(fn) + + return decorator diff --git a/vllm/usage/__init__.py b/vllm/usage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/usage/__pycache__/__init__.cpython-310.pyc b/vllm/usage/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000..9320f07 Binary files /dev/null and b/vllm/usage/__pycache__/__init__.cpython-310.pyc differ diff --git a/vllm/usage/__pycache__/usage_lib.cpython-310.pyc b/vllm/usage/__pycache__/usage_lib.cpython-310.pyc new file mode 100644 index 0000000..4c0bd29 Binary files /dev/null and b/vllm/usage/__pycache__/usage_lib.cpython-310.pyc differ diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py new file mode 100644 index 0000000..9ae46ff --- /dev/null +++ b/vllm/usage/usage_lib.py @@ -0,0 +1,223 @@ +import datetime +import json +import logging +import os +import platform +import time +from enum import Enum +from pathlib import Path +from threading import Thread +from typing import Any, Dict, Optional, Union +from uuid import uuid4 + +import cpuinfo +import psutil +import requests +import torch + +import vllm.envs as envs +from vllm.connections import global_http_connection +from vllm.platforms import current_platform +from vllm.version import __version__ as VLLM_VERSION + +_config_home = envs.VLLM_CONFIG_ROOT +_USAGE_STATS_JSON_PATH = os.path.join(_config_home, "usage_stats.json") +_USAGE_STATS_DO_NOT_TRACK_PATH = os.path.join(_config_home, "do_not_track") +_USAGE_STATS_ENABLED = None +_USAGE_STATS_SERVER = envs.VLLM_USAGE_STATS_SERVER + +_GLOBAL_RUNTIME_DATA: Dict[str, Union[str, int, bool]] = {} + + +def set_runtime_usage_data(key: str, value: Union[str, int, bool]) -> None: + """Set global usage data that will be sent with every usage heartbeat.""" + _GLOBAL_RUNTIME_DATA[key] = value + + +def is_usage_stats_enabled(): + """Determine whether or not we can send usage stats to the server. + The logic is as follows: + - By default, it should be enabled. 
+ - Three environment variables can disable it: + - VLLM_DO_NOT_TRACK=1 + - DO_NOT_TRACK=1 + - VLLM_NO_USAGE_STATS=1 + - A file in the home directory can disable it if it exists: + - $HOME/.config/vllm/do_not_track + """ + global _USAGE_STATS_ENABLED + if _USAGE_STATS_ENABLED is None: + do_not_track = envs.VLLM_DO_NOT_TRACK + no_usage_stats = envs.VLLM_NO_USAGE_STATS + do_not_track_file = os.path.exists(_USAGE_STATS_DO_NOT_TRACK_PATH) + + _USAGE_STATS_ENABLED = not (do_not_track or no_usage_stats + or do_not_track_file) + return _USAGE_STATS_ENABLED + + +def _get_current_timestamp_ns() -> int: + return int(datetime.datetime.now(datetime.timezone.utc).timestamp() * 1e9) + + +def _detect_cloud_provider() -> str: + # Try detecting through vendor file + vendor_files = [ + "/sys/class/dmi/id/product_version", "/sys/class/dmi/id/bios_vendor", + "/sys/class/dmi/id/product_name", + "/sys/class/dmi/id/chassis_asset_tag", "/sys/class/dmi/id/sys_vendor" + ] + # Mapping of identifiable strings to cloud providers + cloud_identifiers = { + "amazon": "AWS", + "microsoft corporation": "AZURE", + "google": "GCP", + "oraclecloud": "OCI", + } + + for vendor_file in vendor_files: + path = Path(vendor_file) + if path.is_file(): + file_content = path.read_text().lower() + for identifier, provider in cloud_identifiers.items(): + if identifier in file_content: + return provider + + # Try detecting through environment variables + env_to_cloud_provider = { + "RUNPOD_DC_ID": "RUNPOD", + } + for env_var, provider in env_to_cloud_provider.items(): + if os.environ.get(env_var): + return provider + + return "UNKNOWN" + + +class UsageContext(str, Enum): + UNKNOWN_CONTEXT = "UNKNOWN_CONTEXT" + LLM_CLASS = "LLM_CLASS" + API_SERVER = "API_SERVER" + OPENAI_API_SERVER = "OPENAI_API_SERVER" + OPENAI_BATCH_RUNNER = "OPENAI_BATCH_RUNNER" + ENGINE_CONTEXT = "ENGINE_CONTEXT" + + +class UsageMessage: + """Collect platform information and send it to the usage stats server.""" + + def __init__(self) -> None: + # NOTE: vLLM's server _only_ support flat KV pair. + # Do not use nested fields. 
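+        # All fields below are therefore kept as simple scalars; the uuid ties
+        # the one-off report and the periodic heartbeats together.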
+ + self.uuid = str(uuid4()) + + # Environment Information + self.provider: Optional[str] = None + self.num_cpu: Optional[int] = None + self.cpu_type: Optional[str] = None + self.cpu_family_model_stepping: Optional[str] = None + self.total_memory: Optional[int] = None + self.architecture: Optional[str] = None + self.platform: Optional[str] = None + self.gpu_count: Optional[int] = None + self.gpu_type: Optional[str] = None + self.gpu_memory_per_device: Optional[int] = None + + # vLLM Information + self.model_architecture: Optional[str] = None + self.vllm_version: Optional[str] = None + self.context: Optional[str] = None + + # Metadata + self.log_time: Optional[int] = None + self.source: Optional[str] = None + + def report_usage(self, + model_architecture: str, + usage_context: UsageContext, + extra_kvs: Optional[Dict[str, Any]] = None) -> None: + t = Thread(target=self._report_usage_worker, + args=(model_architecture, usage_context, extra_kvs or {}), + daemon=True) + t.start() + + def _report_usage_worker(self, model_architecture: str, + usage_context: UsageContext, + extra_kvs: Dict[str, Any]) -> None: + self._report_usage_once(model_architecture, usage_context, extra_kvs) + self._report_continous_usage() + + def _report_usage_once(self, model_architecture: str, + usage_context: UsageContext, + extra_kvs: Dict[str, Any]) -> None: + # Platform information + if current_platform.is_cuda_alike(): + device_property = torch.cuda.get_device_properties(0) + self.gpu_count = torch.cuda.device_count() + self.gpu_type = device_property.name + self.gpu_memory_per_device = device_property.total_memory + self.provider = _detect_cloud_provider() + self.architecture = platform.machine() + self.platform = platform.platform() + self.total_memory = psutil.virtual_memory().total + + info = cpuinfo.get_cpu_info() + self.num_cpu = info.get("count", None) + self.cpu_type = info.get("brand_raw", "") + self.cpu_family_model_stepping = ",".join([ + str(info.get("family", "")), + str(info.get("model", "")), + str(info.get("stepping", "")) + ]) + + # vLLM information + self.context = usage_context.value + self.vllm_version = VLLM_VERSION + self.model_architecture = model_architecture + + # Metadata + self.log_time = _get_current_timestamp_ns() + self.source = envs.VLLM_USAGE_SOURCE + + data = vars(self) + if extra_kvs: + data.update(extra_kvs) + + self._write_to_file(data) + self._send_to_server(data) + + def _report_continous_usage(self): + """Report usage every 10 minutes. + + This helps us to collect more data points for uptime of vLLM usages. + This function can also help send over performance metrics over time. 
+ """ + while True: + time.sleep(600) + data = { + "uuid": self.uuid, + "log_time": _get_current_timestamp_ns(), + } + data.update(_GLOBAL_RUNTIME_DATA) + + self._write_to_file(data) + self._send_to_server(data) + + def _send_to_server(self, data: Dict[str, Any]) -> None: + try: + global_http_client = global_http_connection.get_sync_client() + global_http_client.post(_USAGE_STATS_SERVER, json=data) + except requests.exceptions.RequestException: + # silently ignore unless we are using debug log + logging.debug("Failed to send usage data to server") + + def _write_to_file(self, data: Dict[str, Any]) -> None: + os.makedirs(os.path.dirname(_USAGE_STATS_JSON_PATH), exist_ok=True) + Path(_USAGE_STATS_JSON_PATH).touch(exist_ok=True) + with open(_USAGE_STATS_JSON_PATH, "a") as f: + json.dump(data, f) + f.write("\n") + + +usage_message = UsageMessage() diff --git a/vllm/utils.py b/vllm/utils.py new file mode 100644 index 0000000..edbe2d0 --- /dev/null +++ b/vllm/utils.py @@ -0,0 +1,1445 @@ +import argparse +import asyncio +import contextlib +import datetime +import enum +import gc +import inspect +import ipaddress +import os +import random +import socket +import subprocess +import sys +import tempfile +import threading +import uuid +import warnings +import weakref +from asyncio import FIRST_COMPLETED, ensure_future +from functools import lru_cache, partial, wraps +from platform import uname +from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, + Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, + Type, TypeVar, Union, overload) +from uuid import uuid4 + +import numpy as np +import numpy.typing as npt +import psutil +import torch +import torch.types +import yaml +from packaging.version import Version +from typing_extensions import ParamSpec, TypeIs, assert_never + +import vllm.envs as envs +from vllm.logger import enable_trace_function_call, init_logger +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +# Exception strings for non-implemented encoder/decoder scenarios + +# Reminder: Please update docs/source/serving/compatibility_matrix.rst +# If the feature combo become valid + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ + "Chunked prefill for encoder/decoder models " + \ + "is not currently supported." 
+ +STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( + "Models with logits_soft_cap " + "require FlashInfer backend, which is " + "currently not supported for encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not " + "currently supported with " + "encoder/decoder models.") + +STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with " + "encoder/decoder models.") + +# Efficiently import all enc/dec error strings +# rather than having to import all of the above +STR_NOT_IMPL_ENC_DEC_ERR_STRS = { + "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA, + "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL": + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP, + "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA, + "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, + "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, + "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, + "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, + "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, + "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU +} + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + +GB_bytes = 1_000_000_000 +"""The number of bytes in one gigabyte (GB).""" + +GiB_bytes = 1 << 30 +"""The number of bytes in one gibibyte (GiB).""" + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, +} + +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + +P = ParamSpec('P') +K = TypeVar("K") +T = TypeVar("T") +U = TypeVar("U") + + +class _Sentinel: + ... 
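+# _Sentinel exists only as a unique marker type: ALL_PINNED_SENTINEL below is
+# what LRUCache.remove_oldest() gets back when every cached key is pinned, in
+# which case it raises instead of evicting.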
+ + +ALL_PINNED_SENTINEL = _Sentinel() + + +class Device(enum.Enum): + GPU = enum.auto() + CPU = enum.auto() + + +class Counter: + + def __init__(self, start: int = 0) -> None: + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 + + +class LRUCache(Generic[T]): + + def __init__(self, capacity: int): + self.cache: OrderedDict[Hashable, T] = OrderedDict() + self.pinned_items: Set[Hashable] = set() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> T: + value = self.cache[key] # Raise KeyError if not exists + self.cache.move_to_end(key) + return value + + def __setitem__(self, key: Hashable, value: T) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: + value: Optional[T] + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, value: T) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def pin(self, key: Hashable) -> None: + """ + Pins a key in the cache preventing it from being + evicted in the LRU order. + """ + if key not in self.cache: + raise ValueError(f"Cannot pin key: {key} not in cache.") + self.pinned_items.add(key) + + def _unpin(self, key: Hashable) -> None: + self.pinned_items.remove(key) + + def _on_remove(self, key: Hashable, value: Optional[T]): + pass + + def remove_oldest(self, remove_pinned=False): + if not self.cache: + return + + if not remove_pinned: + # pop the oldest item in the cache that is not pinned + lru_key = next( + (key for key in self.cache if key not in self.pinned_items), + ALL_PINNED_SENTINEL) + if lru_key is ALL_PINNED_SENTINEL: + raise RuntimeError("All items are pinned, " + "cannot remove oldest from the cache.") + else: + lru_key = next(iter(self.cache)) + self.pop(lru_key) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: + run_on_remove = key in self.cache + value: Optional[T] = self.cache.pop(key, default_value) + # remove from pinned items + if key in self.pinned_items: + self._unpin(key) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest(remove_pinned=True) + self.cache.clear() + + +class PyObjectCache: + """Used to cache python objects to avoid object allocations + across scheduler iterations. + """ + + def __init__(self, obj_builder): + self._obj_builder = obj_builder + self._index = 0 + + self._obj_cache = [] + for _ in range(128): + self._obj_cache.append(self._obj_builder()) + + def _grow_cache(self): + # Double the size of the cache + num_objs = len(self._obj_cache) + for _ in range(num_objs): + self._obj_cache.append(self._obj_builder()) + + def get_object(self): + """Returns a pre-allocated cached object. If there is not enough + objects, then the cache size will double. 
+ """ + if self._index >= len(self._obj_cache): + self._grow_cache() + assert self._index < len(self._obj_cache) + + obj = self._obj_cache[self._index] + self._index += 1 + + return obj + + def reset(self): + """Makes all cached-objects available for the next scheduler iteration. + """ + self._index = 0 + + +def is_hip() -> bool: + return torch.version.hip is not None + + +@lru_cache(maxsize=None) +def is_cpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "cpu" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_openvino() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "openvino" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_neuron() -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + + +@lru_cache(maxsize=None) +def is_xpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + is_xpu_flag = "xpu" in version("vllm") + except PackageNotFoundError: + return False + # vllm is not build with xpu + if not is_xpu_flag: + return False + try: + import intel_extension_for_pytorch as ipex # noqa: F401 + _import_ipex = True + except ImportError as e: + logger.warning("Import Error for IPEX: %s", e.msg) + _import_ipex = False + # ipex dependency is not ready + if not _import_ipex: + logger.warning("not found ipex lib") + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache(maxsize=None) +def get_max_shared_memory_bytes(gpu: int = 0) -> int: + """Returns the maximum shared memory per thread block in bytes.""" + from vllm import _custom_ops as ops + max_shared_mem = ( + ops.get_max_shared_memory_per_block_device_attribute(gpu)) + # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py + # will fail + assert max_shared_mem > 0, "max_shared_mem can not be zero" + return int(max_shared_mem) + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +def seed_everything(seed: int) -> None: + """ + Set the seed of each random module. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + + if current_platform.is_cuda_alike(): + torch.cuda.manual_seed_all(seed) + + if is_xpu(): + torch.xpu.manual_seed_all(seed) + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +@lru_cache(maxsize=None) +def get_vllm_instance_id() -> str: + """ + If the environment variable VLLM_INSTANCE_ID is set, return it. + Otherwise, return a random UUID. + Instance id represents an instance of the VLLM. All processes in the same + instance should have the same instance id. + """ + return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" + + +@lru_cache(maxsize=None) +def in_wsl() -> bool: + # Reference: https://github.com/microsoft/WSL/issues/4071 + return "microsoft" in " ".join(uname()).lower() + + +def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. 
+ """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper + + +async def iterate_with_cancellation( + iterator: AsyncGenerator[T, None], + is_cancelled: Callable[[], Awaitable[bool]], +) -> AsyncGenerator[T, None]: + """Convert async iterator into one that polls the provided function + at least once per second to check for client cancellation. + """ + + # Can use anext() in python >= 3.10 + awaits = [ensure_future(iterator.__anext__())] + while True: + done, pending = await asyncio.wait(awaits, timeout=1) + if await is_cancelled(): + with contextlib.suppress(BaseException): + awaits[0].cancel() + await iterator.aclose() + raise asyncio.CancelledError("client cancelled") + if done: + try: + item = await awaits[0] + awaits[0] = ensure_future(iterator.__anext__()) + yield item + except StopAsyncIteration: + # we are done + return + + +async def merge_async_iterators( + *iterators: AsyncGenerator[T, None], + is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None, +) -> AsyncGenerator[Tuple[int, T], None]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + + It also optionally polls a provided function at least once per second + to check for client cancellation. + """ + + # Can use anext() in python >= 3.10 + awaits = { + ensure_future(pair[1].__anext__()): pair + for pair in enumerate(iterators) + } + timeout = None if is_cancelled is None else 1 + try: + while awaits: + done, pending = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED, + timeout=timeout) + if is_cancelled is not None and await is_cancelled(): + raise asyncio.CancelledError("client cancelled") + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[ensure_future(it.__anext__())] = pair + yield i, item + except StopAsyncIteration: + pass + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + f.cancel() + await it.aclose() + + +async def collect_from_async_generator( + iterator: AsyncGenerator[T, None]) -> List[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items + + +def get_ip() -> str: + host_ip = envs.VLLM_HOST_IP + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." 
+ "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", + stacklevel=2) + return "0.0.0.0" + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def get_distributed_init_method(ip: str, port: int) -> str: + # Brackets are not permitted in ipv4 addresses, + # see https://github.com/python/cpython/issues/103848 + return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" + + +def get_open_zmq_ipc_path() -> str: + base_rpc_path = envs.VLLM_RPC_BASE_PATH + return f"ipc://{base_rpc_path}/{uuid4()}" + + +def get_open_port() -> int: + port = envs.VLLM_PORT + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", + port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def find_process_using_port(port: int) -> Optional[psutil.Process]: + for conn in psutil.net_connections(): + if conn.laddr.port == port: + try: + return psutil.Process(conn.pid) + except psutil.NoSuchProcess: + return None + return None + + +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s " + "from '%s' to '%s'", k, os.environ[k], v) + os.environ[k] = v + + +def chunk_list(lst: List[T], chunk_size: int): + """Yield successive chunk_size chunks from lst.""" + for i in range(0, len(lst), chunk_size): + yield lst[i:i + chunk_size] + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. 
+ # | E4M3 | E5M2 + #-----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + + key_caches: List[torch.Tensor] = [] + value_caches: List[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + + seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, 
block_size) + value_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +@lru_cache +def print_warning_once(msg: str) -> None: + # Set the stacklevel to 2 to print the caller's line info + logger.warning(msg, stacklevel=2) + + +@lru_cache(maxsize=None) +def is_pin_memory_available() -> bool: + + if in_wsl(): + # Pinning memory in WSL is not supported. + # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications + print_warning_once("Using 'pin_memory=False' as WSL is detected. " + "This may slow down the performance.") + return False + elif is_xpu(): + print_warning_once("Pin memory is not supported on XPU.") + return False + elif is_neuron(): + print_warning_once("Pin memory is not supported on Neuron.") + return False + elif is_cpu() or is_openvino(): + return False + return True + + +class DeviceMemoryProfiler: + + def __init__(self, device: Optional[torch.types.Device] = None): + self.device = device + + def current_memory_usage(self) -> float: + # Return the memory usage in bytes. + if current_platform.is_cuda_alike(): + torch.cuda.reset_peak_memory_stats(self.device) + mem = torch.cuda.max_memory_allocated(self.device) + elif is_xpu(): + torch.xpu.reset_peak_memory_stats(self.device) # type: ignore + mem = torch.xpu.max_memory_allocated(self.device) # type: ignore + return mem + + def __enter__(self): + self.initial_memory = self.current_memory_usage() + # This allows us to call methods of the context manager if needed + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.final_memory = self.current_memory_usage() + self.consumed_memory = self.final_memory - self.initial_memory + + # Force garbage collection + gc.collect() + + +def make_ndarray_with_pad( + x: List[List[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: Optional[int] = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: List[List[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. 
+ """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: Union[str, torch.device], + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# `collections` helpers +def is_list_of( + value: object, + typ: Type[T], + *, + check: Literal["first", "all"] = "first", +) -> TypeIs[List[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + +JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], + Tuple["JSONTree[T]", ...], T] +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: Dict[str, JSONTree[T]], +) -> Dict[str, JSONTree[U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: List[JSONTree[T]], +) -> List[JSONTree[U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: Tuple[JSONTree[T], ...], +) -> Tuple[JSONTree[U], ...]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: JSONTree[T], +) -> JSONTree[U]: + ... + + +def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: + if isinstance(value, dict): + return {k: json_map_leaves(func, v) for k, v in value.items()} + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + +def flatten_2d_lists(lists: List[List[T]]) -> List[T]: + """Flatten a list of lists to a single list.""" + return [item for sublist in lists for item in sublist] + +# TODO: This function can be removed if transformer_modules classes are +# serialized by value when communicating between processes +def init_cached_hf_modules() -> None: + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() + + +@lru_cache(maxsize=None) +def find_library(lib_name: str) -> str: + """ + Find the library file in the system. + `lib_name` is full filename, with both prefix and suffix. + This function resolves `lib_name` to the full path of the library. + """ + # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa + # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard + # `/sbin/ldconfig` should exist in all Linux systems. 
+ # `/sbin/ldconfig` searches the library in the system + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] + # `LD_LIBRARY_PATH` searches the library in the user-defined paths + env_ld_library_path = envs.LD_LIBRARY_PATH + if not locs and env_ld_library_path: + locs = [ + os.path.join(dir, lib_name) + for dir in env_ld_library_path.split(":") + if os.path.exists(os.path.join(dir, lib_name)) + ] + if not locs: + raise ValueError(f"Cannot find {lib_name} in the system.") + return locs[0] + + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. + """ + so_file = envs.VLLM_NCCL_SO_PATH + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +def enable_trace_function_call_for_thread() -> None: + """Set up function tracing for the current thread, + if enabled via the VLLM_TRACE_FUNCTION environment variable + """ + + if envs.VLLM_TRACE_FUNCTION: + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + + +# `functools` helpers +def identity(value: T) -> T: + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless( + cuda_visible_devices: Optional[str] = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. 
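+    # The public wrapper below passes envs.CUDA_VISIBLE_DEVICES as this
+    # argument, so the cached count is recomputed whenever that variable
+    # changes.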
+ + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + if not torch.cuda._is_compiled(): + return 0 + if is_hip(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = torch.cuda._device_count_amdsmi() if (hasattr( + torch.cuda, "_device_count_amdsmi")) else -1 + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def cuda_is_initialized() -> bool: + """Check if CUDA is initialized.""" + if not torch.cuda._is_compiled(): + return False + return torch.cuda.is_initialized() + + +def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: + """Make an instance method that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> None: + if inst := ref(): + unbound(inst, *args, **kwargs) + + return weak_bound + + +#From: https://stackoverflow.com/a/4104188/2749989 +def run_once(f: Callable[P, None]) -> Callable[P, None]: + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + return wrapper + + +class FlexibleArgumentParser(argparse.ArgumentParser): + """ArgumentParser that allows both underscore and dash in names.""" + + def parse_args(self, args=None, namespace=None): + if args is None: + args = sys.argv[1:] + + if '--config' in args: + args = FlexibleArgumentParser._pull_args_from_config(args) + + # Convert underscores to dashes and vice versa in argument names + processed_args = [] + for arg in args: + if arg.startswith('--'): + if '=' in arg: + key, value = arg.split('=', 1) + key = '--' + key[len('--'):].replace('_', '-') + processed_args.append(f'{key}={value}') + else: + processed_args.append('--' + + arg[len('--'):].replace('_', '-')) + else: + processed_args.append(arg) + + return super().parse_args(processed_args, namespace) + + @staticmethod + def _pull_args_from_config(args: List[str]) -> List[str]: + """Method to pull arguments specified in the config file + into the command-line args variable. + + The arguments in config file will be inserted between + the argument list. 
+ + example: + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + ```python + $: vllm {serve,chat,complete} "facebook/opt-12B" \ + --config config.yaml -tp 2 + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--config', 'config.yaml', + '-tp', '2' + ] + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', + '-tp', '2' + ] + ``` + + Please note how the config args are inserted after the sub command. + this way the order of priorities is maintained when these are args + parsed by super(). + """ + assert args.count( + '--config') <= 1, "More than one config file specified!" + + index = args.index('--config') + if index == len(args) - 1: + raise ValueError("No config file specified! \ + Please check your command-line arguments.") + + file_path = args[index + 1] + + config_args = FlexibleArgumentParser._load_config_file(file_path) + + # 0th index is for {serve,chat,complete} + # followed by model_tag (only for serve) + # followed by config args + # followed by rest of cli args. + # maintaining this order will enforce the precedence + # of cli > config > defaults + if args[0] == "serve": + if index == 1: + raise ValueError( + "No model_tag specified! Please check your command-line" + " arguments.") + args = [args[0]] + [ + args[1] + ] + config_args + args[2:index] + args[index + 2:] + else: + args = [args[0]] + config_args + args[1:index] + args[index + 2:] + + return args + + @staticmethod + def _load_config_file(file_path: str) -> List[str]: + """Loads a yaml file and returns the key value pairs as a + flattened list with argparse like pattern + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + returns: + processed_args: list[str] = [ + '--port': '12323', + '--tensor-parallel-size': '4' + ] + + """ + + extension: str = file_path.split('.')[-1] + if extension not in ('yaml', 'yml'): + raise ValueError( + "Config file must be of a yaml/yml type.\ + %s supplied", extension) + + # only expecting a flat dictionary of atomic types + processed_args: List[str] = [] + + config: Dict[str, Union[int, str]] = {} + try: + with open(file_path, 'r') as config_file: + config = yaml.safe_load(config_file) + except Exception as ex: + logger.error( + "Unable to read the config file at %s. \ + Make sure path is correct", file_path) + raise ex + + for key, value in config.items(): + processed_args.append('--' + key) + processed_args.append(str(value)) + + return processed_args + + +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, + **kwargs): + """Utility function to run async task in a lock""" + async with lock: + return await task(*args, **kwargs) + + +def supports_kw( + callable: Callable[..., object], + kw_name: str, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. 
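+
+    As an illustrative example, for `def f(a, *, b): ...`,
+    supports_kw(f, "b", requires_kw_only=True) is True, while
+    supports_kw(f, "a", requires_kw_only=True) is False.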
+    """
+    params = inspect.signature(callable).parameters
+    if not params:
+        return False
+
+    param_val = params.get(kw_name)
+
+    # Types where the kwarg may be valid, i.e., explicitly defined & nonvariadic
+    passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY,
+                             inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                             inspect.Parameter.KEYWORD_ONLY))
+
+    if param_val:
+        is_sig_param = param_val.kind in passable_kw_types
+        # We want kwargs only, but this is passable as a positional arg
+        if (requires_kw_only and is_sig_param
+                and param_val.kind != inspect.Parameter.KEYWORD_ONLY):
+            return False
+        if ((requires_kw_only
+             and param_val.kind == inspect.Parameter.KEYWORD_ONLY)
+                or (not requires_kw_only and is_sig_param)):
+            return True
+
+    # If we're okay with var-kwargs, it's supported as long as
+    # the kw_name isn't something like *args, **kwargs
+    if allow_var_kwargs:
+        # Get the last param; type is ignored here because params is a proxy
+        # mapping, but it wraps an ordered dict, and they appear in order.
+        # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
+        last_param = params[next(reversed(params))]  # type: ignore
+        return (last_param.kind == inspect.Parameter.VAR_KEYWORD
+                and last_param.name != kw_name)
+    return False
+
+
+def resolve_mm_processor_kwargs(
+    init_kwargs: Optional[Dict[str, Any]],
+    inference_kwargs: Optional[Dict[str, Any]],
+    callable: Callable[..., object],
+    allow_var_kwargs: bool = False,
+) -> Dict[str, Any]:
+    """Applies filtering to eliminate invalid mm_processor_kwargs, i.e.,
+    those that are not explicit keywords of the given callable (if one is
+    given; otherwise no filtering is done), then merges the kwarg dicts,
+    giving priority to inference_kwargs if there are any collisions.
+
+    In the case that no kwarg overrides are provided, returns an empty
+    dict so that it can still be kwarg expanded into the callable later on.
+
+    If allow_var_kwargs=True, allows for things that can be expanded into
+    kwargs as long as they aren't a naming collision with var_kwargs or
+    potential positional arguments.
+    """
+    # Filter inference time multimodal processor kwargs provided
+    runtime_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable,
+        overrides=inference_kwargs,
+        allow_var_kwargs=allow_var_kwargs)
+
+    # Filter init time multimodal processor kwargs provided
+    init_mm_kwargs = get_allowed_kwarg_only_overrides(
+        callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs)
+
+    # Merge the final processor kwargs, prioritizing inference
+    # time values over the initialization time values.
+    mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs}
+    return mm_processor_kwargs
+
+
+def get_allowed_kwarg_only_overrides(
+    callable: Callable[..., object],
+    overrides: Optional[Dict[str, Any]],
+    allow_var_kwargs: bool = False,
+) -> Dict[str, Any]:
+    """
+    Given a callable which has one or more keyword only params and a dict
+    mapping param names to values, drop values that cannot be kwarg
+    expanded to overwrite one or more keyword-only args. This is used in a
+    few places to handle custom processor overrides for multimodal models,
+    e.g., for profiling when processor options provided by the user
+    may affect the number of mm tokens per instance.
+
+    Args:
+        callable: Callable which takes 0 or more keyword only arguments.
+                  If None is provided, all override names are allowed.
+        overrides: Potential overrides to be used when invoking the callable.
+        allow_var_kwargs: Allows overrides that are expandable for var kwargs.
+
+    Returns:
+        Dictionary containing the kwargs that may be used to overwrite one
+        or more keyword-only arguments when invoking the callable.
+    """
+    if not overrides:
+        return {}
+
+    # Drop any mm_processor_kwargs provided by the user that
+    # are not kwargs, unless they can fit in the var_kwargs param
+    filtered_overrides = {
+        kwarg_name: val
+        for kwarg_name, val in overrides.items()
+        if supports_kw(callable,
+                       kwarg_name,
+                       requires_kw_only=True,
+                       allow_var_kwargs=allow_var_kwargs)
+    }
+
+    # If anything is dropped, log a warning
+    dropped_keys = overrides.keys() - filtered_overrides.keys()
+    if dropped_keys:
+        logger.warning(
+            "The following intended overrides are not keyword-only args "
+            "and will be dropped: %s", dropped_keys)
+
+    return filtered_overrides
+
+
+# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0.
+# In particular, the FakeScalarType is not supported for earlier versions of
+# PyTorch which breaks dynamo for any ops registered using ScalarType.
+def supports_dynamo() -> bool:
+    base_torch_version = Version(Version(torch.__version__).base_version)
+    return base_torch_version >= Version("2.4.0")
+
+
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
+class AtomicCounter:
+    """An atomic, thread-safe counter"""
+
+    def __init__(self, initial=0):
+        """Initialize a new atomic counter to given initial value"""
+        self._value = initial
+        self._lock = threading.Lock()
+
+    def inc(self, num=1):
+        """Atomically increment the counter by num and return the new value"""
+        with self._lock:
+            self._value += num
+            return self._value
+
+    def dec(self, num=1):
+        """Atomically decrement the counter by num and return the new value"""
+        with self._lock:
+            self._value -= num
+            return self._value
+
+    @property
+    def value(self):
+        return self._value
diff --git a/vllm/version.py b/vllm/version.py
new file mode 100644
index 0000000..135a818
--- /dev/null
+++ b/vllm/version.py
@@ -0,0 +1,2 @@
+__version__ = "0.6.3"
+__version_tuple__ = (0, 6, 3)
diff --git a/vllm/worker/__init__.py b/vllm/worker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm/worker/__pycache__/__init__.cpython-310.pyc b/vllm/worker/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000..9f2330d
Binary files /dev/null and b/vllm/worker/__pycache__/__init__.cpython-310.pyc differ
diff --git a/vllm/worker/__pycache__/cache_engine.cpython-310.pyc b/vllm/worker/__pycache__/cache_engine.cpython-310.pyc
new file mode 100644
index 0000000..ba2d296
Binary files /dev/null and b/vllm/worker/__pycache__/cache_engine.cpython-310.pyc differ
diff --git a/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-310.pyc
new file mode 100644
index 0000000..29b091e
Binary files /dev/null and b/vllm/worker/__pycache__/cpu_enc_dec_model_runner.cpython-310.pyc differ
diff --git a/vllm/worker/__pycache__/cpu_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/cpu_model_runner.cpython-310.pyc
new file mode 100644
index 0000000..44e5a90
Binary files /dev/null and b/vllm/worker/__pycache__/cpu_model_runner.cpython-310.pyc differ
diff --git a/vllm/worker/__pycache__/cpu_worker.cpython-310.pyc b/vllm/worker/__pycache__/cpu_worker.cpython-310.pyc
new file mode 100644
index 0000000..c85339b
Binary files /dev/null and b/vllm/worker/__pycache__/cpu_worker.cpython-310.pyc
differ diff --git a/vllm/worker/__pycache__/embedding_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/embedding_model_runner.cpython-310.pyc new file mode 100644 index 0000000..3803bc6 Binary files /dev/null and b/vllm/worker/__pycache__/embedding_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/enc_dec_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/enc_dec_model_runner.cpython-310.pyc new file mode 100644 index 0000000..b125020 Binary files /dev/null and b/vllm/worker/__pycache__/enc_dec_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/model_runner.cpython-310.pyc b/vllm/worker/__pycache__/model_runner.cpython-310.pyc new file mode 100644 index 0000000..a8cef01 Binary files /dev/null and b/vllm/worker/__pycache__/model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/model_runner_base.cpython-310.pyc b/vllm/worker/__pycache__/model_runner_base.cpython-310.pyc new file mode 100644 index 0000000..3706030 Binary files /dev/null and b/vllm/worker/__pycache__/model_runner_base.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/multi_step_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/multi_step_model_runner.cpython-310.pyc new file mode 100644 index 0000000..6a57eb8 Binary files /dev/null and b/vllm/worker/__pycache__/multi_step_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-310.pyc b/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-310.pyc new file mode 100644 index 0000000..1ef4b45 Binary files /dev/null and b/vllm/worker/__pycache__/multi_step_tpu_worker.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/multi_step_worker.cpython-310.pyc b/vllm/worker/__pycache__/multi_step_worker.cpython-310.pyc new file mode 100644 index 0000000..c55ecd0 Binary files /dev/null and b/vllm/worker/__pycache__/multi_step_worker.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/neuron_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/neuron_model_runner.cpython-310.pyc new file mode 100644 index 0000000..3740c52 Binary files /dev/null and b/vllm/worker/__pycache__/neuron_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/neuron_worker.cpython-310.pyc b/vllm/worker/__pycache__/neuron_worker.cpython-310.pyc new file mode 100644 index 0000000..5fac9ad Binary files /dev/null and b/vllm/worker/__pycache__/neuron_worker.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/openvino_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/openvino_model_runner.cpython-310.pyc new file mode 100644 index 0000000..6237b1c Binary files /dev/null and b/vllm/worker/__pycache__/openvino_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/openvino_worker.cpython-310.pyc b/vllm/worker/__pycache__/openvino_worker.cpython-310.pyc new file mode 100644 index 0000000..f6fa42e Binary files /dev/null and b/vllm/worker/__pycache__/openvino_worker.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/tpu_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/tpu_model_runner.cpython-310.pyc new file mode 100644 index 0000000..3a70092 Binary files /dev/null and b/vllm/worker/__pycache__/tpu_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/tpu_worker.cpython-310.pyc b/vllm/worker/__pycache__/tpu_worker.cpython-310.pyc new file mode 100644 index 0000000..e76dea5 Binary files /dev/null and b/vllm/worker/__pycache__/tpu_worker.cpython-310.pyc differ diff --git 
a/vllm/worker/__pycache__/utils.cpython-310.pyc b/vllm/worker/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000..a32a86a Binary files /dev/null and b/vllm/worker/__pycache__/utils.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/worker.cpython-310.pyc b/vllm/worker/__pycache__/worker.cpython-310.pyc new file mode 100644 index 0000000..2ea3809 Binary files /dev/null and b/vllm/worker/__pycache__/worker.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/worker_base.cpython-310.pyc b/vllm/worker/__pycache__/worker_base.cpython-310.pyc new file mode 100644 index 0000000..b3bdfef Binary files /dev/null and b/vllm/worker/__pycache__/worker_base.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/xpu_model_runner.cpython-310.pyc b/vllm/worker/__pycache__/xpu_model_runner.cpython-310.pyc new file mode 100644 index 0000000..8130605 Binary files /dev/null and b/vllm/worker/__pycache__/xpu_model_runner.cpython-310.pyc differ diff --git a/vllm/worker/__pycache__/xpu_worker.cpython-310.pyc b/vllm/worker/__pycache__/xpu_worker.cpython-310.pyc new file mode 100644 index 0000000..7b6f418 Binary files /dev/null and b/vllm/worker/__pycache__/xpu_worker.cpython-310.pyc differ diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py new file mode 100644 index 0000000..090f95e --- /dev/null +++ b/vllm/worker/cache_engine.py @@ -0,0 +1,120 @@ +"""CacheEngine class for managing the KV cache.""" +from typing import List + +import torch + +from vllm.attention import get_attn_backend +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig +from vllm.logger import init_logger +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size, + is_pin_memory_available) + +logger = init_logger(__name__) + + +class CacheEngine: + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU and CPU KV + caches. It also provides methods for performing KV cache operations, such + as swapping and copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ) -> None: + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + self.device_config = device_config + + self.head_size = model_config.get_head_size() + # Models like Jamba, have mixed typed layers, E.g Mamba + self.num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + if self.num_gpu_blocks: + self.num_gpu_blocks //= parallel_config.pipeline_parallel_size + self.num_cpu_blocks = cache_config.num_cpu_blocks + if self.num_cpu_blocks: + self.num_cpu_blocks //= parallel_config.pipeline_parallel_size + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Get attention backend. + self.attn_backend = get_attn_backend(self.head_size, + model_config.get_sliding_window(), + model_config.dtype, + cache_config.cache_dtype, + self.block_size, + model_config.is_attention_free) + + # Initialize the cache. 
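+        # Both caches hold one tensor per attention layer, with the shape
+        # determined by the attention backend in _allocate_kv_cache.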
+ self.gpu_cache = self._allocate_kv_cache( + self.num_gpu_blocks, self.device_config.device_type) + self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu") + + def _allocate_kv_cache( + self, + num_blocks: int, + device: str, + ) -> List[torch.Tensor]: + """Allocates KV cache on the specified device.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size) + pin_memory = is_pin_memory_available() if device == "cpu" else False + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_attention_layers): + # null block in CpuGpuBlockAllocator requires at least that + # block to be zeroed-out. + # We zero-out everything for simplicity. + kv_cache.append( + torch.zeros(kv_cache_shape, + dtype=self.dtype, + pin_memory=pin_memory, + device=device)) + return kv_cache + + def swap_in(self, src_to_dst: torch.Tensor) -> None: + for i in range(self.num_attention_layers): + self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], + src_to_dst) + + def swap_out(self, src_to_dst: torch.Tensor) -> None: + for i in range(self.num_attention_layers): + self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], + src_to_dst) + + def copy(self, src_to_dsts: torch.Tensor) -> None: + self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_attention_layers = model_config.get_num_attention_layers( + parallel_config) + + key_cache_block = cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_attention_layers * (key_cache_block + value_cache_block) + if cache_config.cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + dtype_size = get_dtype_size(dtype) + return dtype_size * total diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py new file mode 100644 index 0000000..8ebbf6d --- /dev/null +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -0,0 +1,311 @@ +import dataclasses +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast + +import torch + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MultiModalInputs +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import (CPUModelRunner, + ModelInputForCPUBuilder, + ModelInputForCPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. 
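+
+    Extends the decoder-only CPU model input with encoder token and
+    position tensors.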
+ """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInputForCPU": + return cast( + EncoderDecoderModelInputForCPU, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class CPUEncoderDecoderModelRunner(CPUModelRunner): + _model_input_cls: Type[EncoderDecoderModelInputForCPU] = ( + EncoderDecoderModelInputForCPU) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, + Any]) -> EncoderDecoderModelInputForCPU: + return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInputForCPU: + model_input = super().prepare_model_input(seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + model_input = cast(EncoderDecoderModelInputForCPU, model_input) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input) + return dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInputForCPU, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. 
+ + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. + + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. + cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. 
+ cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max_len_of_block_table, + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + cross_slot_mapping_tensor, + cross_block_tables, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) + + @torch.no_grad() + def execute_model( + self, + model_input: EncoderDecoderModelInputForCPU, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "encoder_input_ids": + model_input.encoder_input_tokens, + "encoder_positions": + model_input.encoder_input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + model_input.attn_metadata, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + "intermediate_tensors": + intermediate_tensors, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. 
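+        # num_steps is restricted to 1 above, so a single SamplerOutput is
+        # returned (wrapped in a list).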
+ output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py new file mode 100644 index 0000000..795511a --- /dev/null +++ b/vllm/worker/cpu_model_runner.py @@ -0,0 +1,549 @@ +import dataclasses +import weakref +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs) +from vllm.sequence import (IntermediateTensors, SequenceData, + SequenceGroupMetadata) +from vllm.utils import make_tensor_with_pad +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 + + +@dataclass(frozen=True) +class ModelInputForCPU(ModelRunnerInputBase): + """ + Base class contains metadata needed for the base model forward pass on CPU + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForCPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None + ) -> "ModelInputForCPU": + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU): + """ + Used by the ModelRunner. 
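+    Adds sampling metadata on top of the base CPU model input.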
+ """ + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForCPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): + + def __init__(self, + runner: "CPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def build(self) -> ModelInputForCPU: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = None + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + # query_lens is not needed if chunked prefill is not + # supported. Since CPU worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + def _compute_multi_modal_input(self, seq_data: SequenceData, mm_data, + computed_len: int, + mm_processor_kwargs: Dict[str, Any]): + mm_kwargs = self.multi_modal_input_mapper(mm_data, mm_processor_kwargs) + + # special processing for mrope position deltas. + mrope_positions = None + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + token_ids = seq_data.get_token_ids() + + mrope_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. 
+ spatial_merge_size, + context_len=computed_len, + ) + seq_data.mrope_position_delta = mrope_position_delta + return mm_kwargs, mrope_positions + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] + + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + mrope_positions = None + if (mm_data := seq_group_metadata.multi_modal_data): + mm_kwargs, mrope_positions = self._compute_multi_modal_input( + seq_data, mm_data, computed_len, + seq_group_metadata.mm_processor_kwargs) + multi_modal_inputs_list.append(mm_kwargs) + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + if mrope_positions: + for idx in range(3): + input_mrope_positions[idx].extend(mrope_positions[idx]) + else: + input_positions.extend(list(range(computed_len, seq_len))) + + # Compute the slot mapping. + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
+ start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions + or input_mrope_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + seq_lens=seq_lens, + seq_lens_tensor=torch.tensor([]), + max_decode_seq_len=0, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([]), + slot_mapping=slot_mapping, + ) + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + input_mrope_positions: List[List[int]] = [[] for _ in range(3)] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + if seq_data.mrope_position_delta is not None: + context_len = seq_data.get_num_computed_tokens() + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + input_mrope_positions[idx].extend(next_pos[idx]) + else: + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + if any(input_mrope_positions): + input_positions = None # type: ignore + else: + input_mrope_positions = None # type: ignore + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions + or input_mrope_positions, + dtype=torch.long, + device=self.device) + slot_mapping = 
torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + block_tables = make_tensor_with_pad( + block_tables, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + +class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]): + _model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = ( + ModelInputForCPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + # Currently, CPU worker doesn't support chunked prefill. + assert self.scheduler_config.chunked_prefill_enabled is False + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Multi-modal data support + self.mm_registry = MULTIMODAL_REGISTRY + self.multi_modal_input_mapper = self.mm_registry \ + .create_input_mapper(self.model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. 
+ mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + + def load_model(self) -> None: + self.model = get_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForCPUWithSamplingMetadata: + return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501 + tensor_dict, + attn_backend=self.attn_backend, + ) + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForCPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForCPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "CPU worker does not support multi-step execution.") + + model_executable = self.model + execute_model_kwargs = { + "input_ids": + model_input.input_tokens, + "positions": + model_input.input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + model_input.attn_metadata, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + "intermediate_tensors": + intermediate_tensors, + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + # Sample the next token. 
+ output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py new file mode 100644 index 0000000..b845628 --- /dev/null +++ b/vllm/worker/cpu_worker.py @@ -0,0 +1,371 @@ +"""A CPU worker class.""" +from typing import Dict, List, Optional, Tuple, Type + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE +from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner +from vllm.worker.cpu_model_runner import CPUModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase, WorkerInput) + +logger = init_logger(__name__) + + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Get attention backend. + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Initialize the cache. 
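+        # One tensor per layer. As noted above, this CPU cache is what the
+        # scheduler treats as the "GPU" cache.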
+ self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + + +class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a CPU socket. + + Each worker is associated with a single CPU socket. The worker is + responsible for maintaining the KV cache and executing the model on the + CPU. In case of distributed inference, each worker is assigned a partition + of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Setup OpenMP threads affinity. 
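# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the per-block size computation
# from `CPUCacheEngine.get_cache_block_size` above, restated as a standalone
# function. The model dimensions in the example call are hypothetical.
import torch

def kv_cache_block_bytes(block_size: int, num_kv_heads: int, head_size: int,
                         num_layers: int, dtype: torch.dtype) -> int:
    # Each block holds `block_size` tokens of keys plus the same number of
    # values, for every layer.
    key_block = block_size * num_kv_heads * head_size
    value_block = key_block
    total_elems = num_layers * (key_block + value_block)
    return total_elems * torch.tensor([], dtype=dtype).element_size()

# e.g. a 7B-class model with 16-token blocks in bf16:
print(kv_cache_block_bytes(16, 32, 128, 32, torch.bfloat16))  # 8388608 bytes
# (end of sketch)
# ---------------------------------------------------------------------------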
+ omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND + if omp_cpuids == "all": + self.local_omp_cpuid = "all" + else: + self.local_omp_cpuid = omp_cpuids.split("|")[rank] + + ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner + if self._is_encoder_decoder_model(): + ModelRunnerClass = CPUEncoderDecoderModelRunner + self.model_runner: CPUModelRunner = ModelRunnerClass( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, + is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CPUCacheEngine] + self.cpu_cache: List[List[torch.Tensor]] + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def _is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model + + def init_device(self) -> None: + if self.local_omp_cpuid != "all": + ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + if ret: + logger.info(ret) + + self.init_distributed_environment() + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. 
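# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the block budget in
# `determine_num_available_blocks` above follows from the configured CPU KV
# cache space. The 4 GiB budget and per-block size are hypothetical example
# values (the real budget comes from VLLM_CPU_KVCACHE_SPACE via the config).
cpu_kvcache_space_bytes = 4 * 1024 ** 3        # e.g. 4 GiB
cache_block_size = 8_388_608                   # bytes per block (hypothetical)
num_cpu_blocks = max(cpu_kvcache_space_bytes // cache_block_size, 0)
# The scheduler only understands "GPU" blocks, so the CPU budget is reported
# as num_gpu_blocks and the swappable CPU pool is reported as empty.
num_gpu_blocks, num_cpu_blocks = num_cpu_blocks, 0
print(num_gpu_blocks)                          # 512 blocks for this example
# (end of sketch)
# ---------------------------------------------------------------------------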
+ num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = [ + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.cpu_cache = [ + self.cache_engine[ve].cpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + self.model_runner.block_size = self.cache_engine[0].block_size + + assert all( + self.cpu_cache[ve] is not None + for ve in range(self.parallel_config.pipeline_parallel_size)) + + # Populate the cache to warmup the memory + for ve in range(self.parallel_config.pipeline_parallel_size): + for layer_cache in self.cpu_cache[ve]: + layer_cache.fill_(0) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.cpu_cache + + def execute_worker( + self, + worker_input: WorkerInput, + ) -> None: + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[worker_input.virtual_engine].copy( + worker_input.blocks_to_copy) + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + assert execute_model_req is not None + virtual_engine = execute_model_req.virtual_engine + num_seq_groups: int = len(execute_model_req.seq_group_metadata_list) + blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device="cpu", + dtype=torch.int64).view(-1, 2) + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block. 
+ """ + return CPUCacheEngine.get_cache_block_size( + self.cache_config.block_size, self.cache_config.cache_dtype, + self.model_config, self.parallel_config) diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py new file mode 100644 index 0000000..a7f5b2d --- /dev/null +++ b/vllm/worker/embedding_model_runner.py @@ -0,0 +1,208 @@ +import dataclasses +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.distributed import get_pp_group +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.multimodal import MultiModalInputs +from vllm.pooling_params import PoolingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, + SequenceGroupMetadata) +from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPU, + ModelInputForGPUBuilder) + +logger = init_logger(__name__) + + +@dataclasses.dataclass(frozen=True) +class ModelInputForGPUWithPoolingMetadata(ModelInputForGPU): + """ + Used by the EmbeddingModelRunner. + """ + pooling_metadata: Optional["PoolingMetadata"] = None + + +class EmbeddingModelRunner( + GPUModelRunnerBase[ModelInputForGPUWithPoolingMetadata]): + _model_input_cls: Type[ModelInputForGPUWithPoolingMetadata] = ( + ModelInputForGPUWithPoolingMetadata) + _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, + ): + super().__init__(model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + observability_config=observability_config) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPUWithPoolingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError( + "EmbeddingModelRunner does not support multi-step execution.") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + # Currently cuda graph is only supported by the decode phase. 
+ assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + virtual_engine = model_input.virtual_engine + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(num_layers) + ] + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start = torch.cuda.Event(enable_timing=True) + model_forward_end = torch.cuda.Event(enable_timing=True) + model_forward_start.record() + + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device)) + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.record() + + # Only perform pooling in the last pipeline stage. + if not get_pp_group().is_last_rank: + if (self.is_driver_worker + and hidden_or_intermediate_states is not None + and isinstance(hidden_or_intermediate_states, + IntermediateTensors) + and self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + hidden_or_intermediate_states.tensors["model_forward_time"] = ( + torch.tensor(model_forward_time + orig_model_forward_time)) + return hidden_or_intermediate_states + + # Only perform pooling in the driver worker. + if not self.is_driver_worker: + return [] + + return [ + self.model.pooler(hidden_states=hidden_or_intermediate_states, + pooling_metadata=model_input.pooling_metadata) + ] + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForGPUWithPoolingMetadata: + return ModelInputForGPUWithPoolingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForGPUWithPoolingMetadata: + assert seq_group_metadata_list is not None + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Prepare PoolingMetadata. 
+ assert model_input.seq_lens is not None + pooling_metadata = self._prepare_pooling(seq_group_metadata_list, + model_input.seq_lens) + + return dataclasses.replace(model_input, + pooling_metadata=pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py new file mode 100644 index 0000000..d58a7a1 --- /dev/null +++ b/vllm/worker/enc_dec_model_runner.py @@ -0,0 +1,546 @@ +import dataclasses +import itertools +from typing import Any, Dict, List, Optional, Tuple, Type, cast + +import torch +import torch.distributed + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.attention.backends.utils import PAD_SLOT_ID +from vllm.attention.selector import (_Backend, get_env_variable_attn_backend, + get_global_forced_attn_backend, + global_force_attn_backend) +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, + MultiModalRegistry) +from vllm.sampling_params import SamplingParams +from vllm.sequence import (IntermediateTensors, PoolerOutput, + SequenceGroupMetadata) +from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata, + _get_graph_batch_size) +from vllm.worker.model_runner_base import ( + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict) +from vllm.worker.utils import assert_enc_dec_mr_supported_scenario + +logger = init_logger(__name__) + + +@dataclasses.dataclass(frozen=True) +class EncoderDecoderModelInput(ModelInputForGPUWithSamplingMetadata): + """ + Used by the EncoderDecoderModelRunner. 
+ """ + encoder_input_tokens: Optional[torch.Tensor] = None + encoder_input_positions: Optional[torch.Tensor] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "encoder_input_tokens": self.encoder_input_tokens, + "encoder_input_positions": self.encoder_input_positions, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "EncoderDecoderModelInput": + return cast( + EncoderDecoderModelInput, + super().from_broadcasted_tensor_dict(tensor_dict, attn_backend)) + + +class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]): + _model_input_cls: Type[EncoderDecoderModelInput] = ( + EncoderDecoderModelInput) + _builder_cls: Type[ModelInputForGPUBuilder] = (ModelInputForGPUBuilder) + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + ''' + EncoderDecoderModelRunner constructor. + + `lora_config` and `prompt_adapter_config` are + unused (since these features are not yet supported for encoder/decoder + models) but these arguments are present here for compatibility with + the base-class constructor. + ''' + + self._maybe_force_supported_attention_backend() + + super().__init__( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + lora_config=None, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + ) + + # Crash for unsupported encoder/scenarios + assert_enc_dec_mr_supported_scenario(self) + + def _maybe_force_supported_attention_backend(self): + ''' + Force vLLM to use the XFormers attention backend, + which is currently the only supported option. 
+ ''' + + def raise_backend_err(): + # The user has specified an attention backend override + # which is invalid for encoder/decoder models + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_BACKEND) + + maybe_env_var_forced_backend = get_env_variable_attn_backend() + maybe_global_forced_backend = get_global_forced_attn_backend() + is_forced_by_global = maybe_global_forced_backend is not None + is_forced_by_env_var = maybe_env_var_forced_backend is not None + + if not (is_forced_by_global or is_forced_by_env_var): + # The user has not already specified an attention backend + # override + pass + # logger.info("EncoderDecoderModelRunner requires " + # "XFormers backend; overriding backend " + # "auto-selection and forcing XFormers.") + # global_force_attn_backend(_Backend.XFORMERS) + elif is_forced_by_global: + # Backend override enforced by global variable takes + # precedence over vLLM backend environment variable. + if maybe_global_forced_backend != _Backend.XFORMERS: + raise_backend_err() + elif is_forced_by_env_var: + # Backend override enforced by vLLM backend + # environment variable + if maybe_env_var_forced_backend != _Backend.XFORMERS: + raise_backend_err() + + def _list_to_int32_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.int32, device=self.device) + + def _list_to_long_tensor( + self, + _list: List[int], + ) -> torch.Tensor: + return torch.tensor(_list, dtype=torch.long, device=self.device) + + def _empty_int32_tensor(self) -> torch.Tensor: + return self._list_to_int32_tensor([]) + + def _empty_long_tensor(self) -> torch.Tensor: + return self._list_to_long_tensor([]) + + @torch.inference_mode() + def execute_model( + self, + model_input: EncoderDecoderModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[PoolerOutput]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in " + "EncoderDecoderModelRunner") + + if (model_input.attn_metadata is not None + and model_input.attn_metadata.prefill_metadata is None + and model_input.attn_metadata.decode_metadata.use_cuda_graph): + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[ + model_input.virtual_engine][graph_batch_size] + else: + model_executable = self.model + + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. 
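# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the precedence rule implemented
# by `_maybe_force_supported_attention_backend` above, restated with plain
# strings standing in for the _Backend enum. Note that in this patch the
# fallback forcing of XFormers is disabled (the global_force_attn_backend
# call above is commented out), so the no-override case is a no-op.
from typing import Optional

def check_backend_override(global_forced: Optional[str],
                           env_forced: Optional[str]) -> None:
    if global_forced is not None:
        # A programmatic global override takes precedence over the env var.
        if global_forced != "XFORMERS":
            raise NotImplementedError("unsupported backend for enc/dec")
    elif env_forced is not None:
        if env_forced != "XFORMERS":
            raise NotImplementedError("unsupported backend for enc/dec")

check_backend_override(None, None)                 # ok: nothing forced
check_backend_override("XFORMERS", "FLASH_ATTN")   # ok: global override wins
# (end of sketch)
# ---------------------------------------------------------------------------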
+ output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + + return [output] + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> EncoderDecoderModelInput: + return EncoderDecoderModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> EncoderDecoderModelInput: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + Since chunked prefill is not supported for encoder/decoder models, + `input_tokens` is assumed to be either entirely prefill tokens or + entirely decode tokens. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + ( + attn_metadata, + encoder_input_tokens_tensor, + encoder_input_positions_tensor, + ) = (self._prepare_encoder_model_input_tensors(seq_group_metadata_list, + model_input)) + # Inject attn_metadata encoder/cross-attention fields & + # encoder input tokens/positions into model_input. + # Frozen dataclass fields cannot be modified, so use + # dataclasses.replace to construct a new model input + # instance. + model_input = dataclasses.replace( + model_input, + attn_metadata=attn_metadata, + encoder_input_tokens=encoder_input_tokens_tensor, + encoder_input_positions=encoder_input_positions_tensor, + ) + + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + self.pin_memory, + generators=generators) + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. 
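# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how the dummy-sequence loop
# that follows splits the token budget across sequences during profiling --
# each group gets the floor of the average, and the remainder is spread one
# token at a time over the first groups. The budget numbers are hypothetical.
max_num_batched_tokens, max_num_seqs = 2048, 6
seq_lens = [
    max_num_batched_tokens // max_num_seqs
    + (group_id < max_num_batched_tokens % max_num_seqs)
    for group_id in range(max_num_seqs)
]
print(seq_lens)                           # [342, 342, 341, 341, 341, 341]
assert sum(seq_lens) == max_num_batched_tokens
# (end of sketch)
# ---------------------------------------------------------------------------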
+ seqs: List[SequenceGroupMetadata] = [] + + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + logger.info("Starting profile run for multi-modal models.") + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + decoder_seq_data, decoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=False) + encoder_seq_data, encoder_dummy_multi_modal_data \ + = self.input_registry.dummy_data_for_profiling( + self.model_config, + seq_len, + self.mm_registry, + is_encoder_data=True) + + # Having more tokens is over-conservative but otherwise fine + assert len(decoder_seq_data.prompt_token_ids) >= seq_len, ( + f"Expected at least {seq_len} dummy tokens for profiling, " + f"but got: {len(decoder_seq_data.prompt_token_ids)}") + + assert decoder_dummy_multi_modal_data is None or \ + encoder_dummy_multi_modal_data is None, ( + "Multi-modal data can't be provided in both encoder and decoder" + ) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: decoder_seq_data}, + sampling_params=sampling_params, + block_tables=None, + encoder_seq_data=encoder_seq_data, + cross_block_table=None, + multi_modal_data=decoder_dummy_multi_modal_data + or encoder_dummy_multi_modal_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(num_layers) + ] + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.cuda.synchronize() + return + + def _prepare_encoder_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + model_input: EncoderDecoderModelInput, + ) -> Tuple[AttentionMetadata, Optional[torch.Tensor], + Optional[torch.Tensor]]: + """Helper method to prepare the encoder- and cross-attn-related + model inputs based on a given sequence group. These additional inputs + are used to augment an already-computed `EncoderDecoderModelInput` + data structure which already has decoder-related model inputs + populated. + + Sets the following attn_metadata fields: + * `num_encoder_tokens` + * `encoder_seq_lens` + * `encoder_seq_lens_tensor` + * `max_encoder_seq_len` + * `cross_slot_mapping` + * `cross_block_tables` + + Constructs a new model inputs data structure, based on + (1) the existing fields in the `model_inputs` argument, + and (2) the following additional fields which are + computed (or in the case of `attn_metadata`, updated) + by this function: + * attn_metadata + * encoder_input_tokens + * encoder_input_positions + + Arguments: + + * seq_group_metadata_list: list of sequence groups for which to + compute inputs + * model_inputs: model inputs data structure with decoder-oriented + fields already computed. 
+ + Return: + + * Updated model inputs data structure + """ + + if len(seq_group_metadata_list) == 0: + return (model_input.attn_metadata, None, None) + + # Since we are not supporting chunked prefill either the entire + # batch is prefill or it is decode + is_prompt = seq_group_metadata_list[0].is_prompt + + # Build encoder inputs + encoder_seq_lens: List[int] = [] + if is_prompt: + # Prefill phase. + cross_block_tables = self._empty_int32_tensor().view( + len(seq_group_metadata_list), -1) + + # Extract input tokens/positions, cross-attention slot-mapping, + # & seq len from each sequence group metadata + ( + encoder_input_tokens, + encoder_input_positions, + cross_slot_mapping, + ) = ( + [], + [], + [], + ) + for seq_group_metadata in seq_group_metadata_list: + # Build seq lens + seq_len = seq_group_metadata.encoder_seq_data.get_len() + token_ids = seq_group_metadata.encoder_seq_data.get_token_ids() + encoder_seq_lens.append(seq_len) + + # Build slot mapping + is_profile_run = (seq_group_metadata.block_tables is None) + if is_profile_run: + # During memory profiling, the block tables are not + # initialized yet. In this case, we just use a dummy + # slot mapping. + # In embeddings, the block tables are {seq_id: None}. + cross_slot_mapping.extend([PAD_SLOT_ID] * seq_len) + else: + for i in range(0, seq_len): + block_number = seq_group_metadata.cross_block_table[ + i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + cross_slot_mapping.append(slot) + + # Build encoder input tokens + encoder_input_tokens.extend(token_ids) + encoder_input_positions.extend(list(range(0, seq_len))) + + # Convert tokens/positions & cross-attention + # slot-mapping to encoder input tensors + encoder_input_tokens_tensor = self._list_to_long_tensor( + encoder_input_tokens) + encoder_input_positions_tensor = self._list_to_long_tensor( + encoder_input_positions) + cross_slot_mapping_tensor = self._list_to_long_tensor( + cross_slot_mapping) + + else: + # Decode phase. + encoder_input_tokens_tensor = self._empty_long_tensor() + encoder_input_positions_tensor = self._empty_long_tensor() + cross_slot_mapping_tensor = self._empty_long_tensor() + # Extract cross-attention block tables & + # seq len from each sequence group metadata. + # Cross-attention block tables are empty + # during vLLM memory profiling. + cross_block_tables = [] + for seq_group_metadata in seq_group_metadata_list: + for _ in range(len(seq_group_metadata.seq_data)): + encoder_seq_lens.append( + seq_group_metadata.encoder_seq_data.get_len()) + cross_block_table = seq_group_metadata.cross_block_table + cross_block_tables.append([] if ( + cross_block_table is None) else cross_block_table) + + if (model_input.attn_metadata is not None + and model_input.attn_metadata.use_cuda_graph): + # We will be using CUDA graph replay for this decode. + max_len_of_block_table = self.get_max_block_per_batch() + batch_size = len(encoder_seq_lens) + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + cuda_graph_pad_size = graph_batch_size - batch_size + # extend the cross_block_tables and encoder_seq_lens to match + # the graph_batch_size. 
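# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the slot-mapping arithmetic
# used for the cross-attention KV cache in the prefill branch above -- token i
# lives at offset (i % block_size) inside the (i // block_size)-th entry of
# the sequence's cross block table. The block numbers here are made up.
block_size = 4
cross_block_table = [7, 2]      # physical blocks assigned to this sequence
seq_len = 6
cross_slot_mapping = []
for i in range(seq_len):
    block_number = cross_block_table[i // block_size]
    block_offset = i % block_size
    cross_slot_mapping.append(block_number * block_size + block_offset)
print(cross_slot_mapping)       # [28, 29, 30, 31, 8, 9]
# (end of sketch)
# ---------------------------------------------------------------------------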
+ cross_block_tables.extend([[] + for _ in range(cuda_graph_pad_size) + ]) + encoder_seq_lens.extend( + itertools.repeat(1, cuda_graph_pad_size)) + + else: + max_len_of_block_table = max( + len(block_table) for block_table in cross_block_tables) + + cross_block_tables = make_tensor_with_pad( + cross_block_tables, + max_len=max_len_of_block_table, + pad=0, + dtype=torch.int32, + device=self.device, + ) + + # Compute encoder sequence lengths & encoder + # sequence starting offset tensors + max_encoder_seq_len = max(encoder_seq_lens, default=0) + encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens) + encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + + 1, + dtype=torch.int32, + device=self.device) + torch.cumsum(encoder_seq_lens_tensor, + dim=0, + dtype=encoder_seq_start_loc.dtype, + out=encoder_seq_start_loc[1:]) + + # Update attention metadata with encoder-oriented attributes + attn_metadata = model_input.attn_metadata + assert attn_metadata is not None + ( + attn_metadata.num_encoder_tokens, + attn_metadata.encoder_seq_lens, + attn_metadata.encoder_seq_lens_tensor, + attn_metadata.max_encoder_seq_len, + attn_metadata.cross_slot_mapping, + attn_metadata.cross_block_tables, + attn_metadata.encoder_seq_start_loc, + ) = ( + sum(encoder_seq_lens), + encoder_seq_lens, + encoder_seq_lens_tensor, + max_encoder_seq_len, + cross_slot_mapping_tensor, + cross_block_tables, + encoder_seq_start_loc, + ) + + return (attn_metadata, encoder_input_tokens_tensor, + encoder_input_positions_tensor) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py new file mode 100644 index 0000000..ee4442f --- /dev/null +++ b/vllm/worker/model_runner.py @@ -0,0 +1,1932 @@ +import dataclasses +import gc +import inspect +import itertools +import time +import warnings +import weakref +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn + +import vllm.envs as envs +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState +from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.levels import CompilationLevel +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.core.scheduler import SchedulerOutputs +from vllm.distributed import get_pp_group +from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor import SamplingMetadata, SamplingMetadataCache +from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.model_executor.models.utils import set_cpu_offload_max_bytes +from vllm.multimodal import (MULTIMODAL_REGISTRY, 
BatchedTensorInputs, + MultiModalInputs, MultiModalRegistry) +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import (DeviceMemoryProfiler, PyObjectCache, async_tensor_h2d, + flatten_2d_lists, is_hip, is_pin_memory_available, + supports_dynamo) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict, dump_input_when_exception) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +LORA_WARMUP_RANK = 8 +_BATCH_SIZE_ALIGNMENT = 8 +# all the token sizes that **can** be captured by cudagraph. +# they can be arbitrarily large. +# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. +# the actual sizes to capture will be determined by the model, +# depending on the model's max_num_seqs. +# NOTE: _get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) +] +_NUM_WARMUP_ITERS = 2 + +TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") + +# For now, bump up cache limits for recompilations during CUDA graph warmups. +# torch._dynamo.config.cache_size_limit = 128 +# torch._dynamo.config.accumulated_cache_size_limit = 128 + + +@dataclass(frozen=True) +class ModelInputForGPU(ModelRunnerInputBase): + """ + This base class contains metadata needed for the base model forward pass + but not metadata for possible additional steps, e.g., sampling. Model + runners that run additional steps should subclass this method to add + additional fields. 
+ """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + lora_mapping: Optional["LoRAMapping"] = None + lora_requests: Optional[Set[LoRARequest]] = None + attn_metadata: Optional["AttentionMetadata"] = None + prompt_adapter_mapping: Optional[PromptAdapterMapping] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None + finished_requests_ids: Optional[List[str]] = None + virtual_engine: int = 0 + async_callback: Optional[Callable] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForGPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForGPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU): + """ + Used by the ModelRunner. + """ + sampling_metadata: Optional["SamplingMetadata"] = None + # Used for speculative decoding. We do not broadcast it because it is only + # used by the driver worker. 
+ is_prompt: Optional[bool] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + "lora_requests": self.lora_requests, + "lora_mapping": self.lora_mapping, + "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, + "virtual_engine": self.virtual_engine, + "request_ids_to_seq_ids": self.request_ids_to_seq_ids, + "finished_requests_ids": self.finished_requests_ids, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForGPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]): + """Build ModelInputForGPU from SequenceGroupMetadata.""" + + # Note: ideally we would be using a dataclass(kw_only=True) + # here, so that this can be subclassed easily, + # but kw_only is not supported in python<3.10. + class InterDataForSeqGroup: + """Intermediate data for the current sequence group.""" + + def simple_reinit(self): + self.input_tokens[0].clear() # type: ignore + self.input_positions[0].clear() # type: ignore + self.mrope_input_positions = None # type: ignore + self.seq_lens[0] = 0 # type: ignore + self.orig_seq_lens[0] = 0 # type: ignore + self.query_lens[0] = 0 # type: ignore + self.context_lens[0] = 0 # type: ignore + self.curr_sliding_window_blocks[0] = 0 # type: ignore + self.lora_index_mapping.clear() # type: ignore + self.lora_prompt_mapping.clear() # type: ignore + self.lora_requests.clear() # type: ignore + self.prompt_adapter_index_mapping.clear() # type: ignore + self.prompt_adapter_prompt_mapping.clear() # type: ignore + + def __init__( + self, + *, + # From sequence group metadata. + request_id: str, + seq_ids: List[int], + is_prompt: bool, + block_tables: Optional[Dict[int, List[int]]], + computed_block_nums: List[int], + n_seqs: int = 0, + + # Input tokens and positions. + input_tokens: Optional[List[List[int]]] = None, + input_positions: Optional[List[List[int]]] = None, + mrope_input_positions: Optional[List[List[List[int]]]] = None, + + # The sequence length (may be capped to the sliding window). + seq_lens: Optional[List[int]] = None, + # The original sequence length (before applying sliding window). + # This is used to compute slot mapping. + orig_seq_lens: Optional[List[int]] = None, + # The query length. + query_lens: Optional[List[int]] = None, + # The number of tokens that are already computed. + context_lens: Optional[List[int]] = None, + # The current sliding window block. + curr_sliding_window_blocks: Optional[List[int]] = None, + + # LoRA inputs. + lora_index_mapping: Optional[List[List[int]]] = None, + lora_prompt_mapping: Optional[List[List[int]]] = None, + lora_requests: Optional[Set[LoRARequest]] = None, + + # Prompt adapter inputs. 
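# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): a simplified stand-in for the
# reuse pattern that `simple_reinit` above supports -- intermediate per-group
# objects are pooled by a factory, handed out each scheduler step, and the
# pool is reset so the same objects are re-initialized in place instead of
# being reallocated. This is an analogue, not vLLM's PyObjectCache.
class TinyObjectPool:
    def __init__(self, factory):
        self._factory = factory
        self._objs: list = []
        self._next = 0

    def get_object(self):
        if self._next == len(self._objs):
            self._objs.append(self._factory())
        obj = self._objs[self._next]
        self._next += 1
        return obj

    def reset(self):
        # Previously built objects stay in the pool and are reused.
        self._next = 0

pool = TinyObjectPool(dict)
a = pool.get_object()
pool.reset()
b = pool.get_object()
assert a is b        # the same object is recycled across steps
# (end of sketch)
# ---------------------------------------------------------------------------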
+ prompt_adapter_index_mapping: Optional[List[int]] = None, + prompt_adapter_prompt_mapping: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + + # Multi-modal inputs. + multi_modal_inputs: Optional[MultiModalInputs] = None, + + # Whether the prefix cache is hit (prefill only). + prefix_cache_hit: bool = False, + reinit: bool = False, + reinit_use_defaults: bool = False, + encoder_seq_len: int = 0, + ): + if reinit: + assert len(self.seq_ids) == len(seq_ids) # type: ignore + for i, seq_id in enumerate(seq_ids): + self.seq_ids[i] = seq_id # type: ignore + else: + self.seq_ids = seq_ids + + self.request_id = request_id + self.is_prompt = is_prompt + self.block_tables = block_tables + self.computed_block_nums = computed_block_nums + self.n_seqs = n_seqs + self.encoder_seq_len = encoder_seq_len + + if reinit: + if len(self.seq_ids) == 1 and reinit_use_defaults: + self.simple_reinit() + else: + if input_tokens: + self.input_tokens = input_tokens + else: + for seq_id in range(len(self.seq_ids)): + self.input_tokens[seq_id].clear() + + if input_positions: + self.input_positions = input_positions + else: + for seq_id in range(len(self.seq_ids)): + self.input_positions[seq_id].clear() + + self.mrope_input_positions = None + + if seq_lens: + self.seq_lens = seq_lens + else: + for seq_id in range(len(self.seq_ids)): + self.seq_lens[seq_id] = 0 + + if orig_seq_lens: + self.orig_seq_lens = orig_seq_lens + else: + for seq_id in range(len(self.seq_ids)): + self.orig_seq_lens[seq_id] = 0 + + if query_lens: + self.query_lens = query_lens + else: + for seq_id in range(len(self.seq_ids)): + self.query_lens[seq_id] = 0 + + if context_lens: + self.context_lens = context_lens + else: + for seq_id in range(len(self.seq_ids)): + self.context_lens[seq_id] = 0 + + if curr_sliding_window_blocks: + self.curr_sliding_window_blocks = \ + curr_sliding_window_blocks + else: + for seq_id in range(len(self.seq_ids)): + self.curr_sliding_window_blocks[seq_id] = 0 + + if lora_index_mapping: + self.lora_index_mapping = lora_index_mapping + else: + self.lora_index_mapping.clear() + + if lora_prompt_mapping: + self.lora_prompt_mapping = lora_prompt_mapping + else: + self.lora_prompt_mapping.clear() + + if lora_requests: + self.lora_requests = lora_requests + else: + self.lora_requests.clear() + + if prompt_adapter_index_mapping: + self.prompt_adapter_index_mapping = \ + prompt_adapter_index_mapping + else: + self.prompt_adapter_index_mapping.clear() + + if prompt_adapter_prompt_mapping: + self.prompt_adapter_prompt_mapping = \ + prompt_adapter_prompt_mapping + else: + self.prompt_adapter_prompt_mapping.clear() + + else: + self.input_tokens = input_tokens or [] + self.input_positions = input_positions or [] + self.mrope_input_positions = mrope_input_positions or None + self.seq_lens = seq_lens or [] + self.orig_seq_lens = orig_seq_lens or [] + self.query_lens = query_lens or [] + self.context_lens = context_lens or [] + self.curr_sliding_window_blocks = \ + curr_sliding_window_blocks or [] + + self.lora_index_mapping = lora_index_mapping or [] + self.lora_prompt_mapping = lora_prompt_mapping or [] + self.lora_requests = lora_requests or set() + + self.prompt_adapter_index_mapping = ( + prompt_adapter_index_mapping or []) + self.prompt_adapter_prompt_mapping = ( + prompt_adapter_prompt_mapping or []) + + self.prompt_adapter_request = prompt_adapter_request + self.multi_modal_inputs = multi_modal_inputs + self.prefix_cache_hit = prefix_cache_hit + + self.n_seqs = len(self.seq_ids) + + if 
not reinit: + self.__post_init__() + + def __post_init__(self): + self.n_seqs = len(self.seq_ids) + + self.input_tokens = [[] for _ in range(self.n_seqs)] + self.input_positions = [[] for _ in range(self.n_seqs)] + self.mrope_input_positions = None + self.seq_lens = [0] * self.n_seqs + self.orig_seq_lens = [0] * self.n_seqs + self.query_lens = [0] * self.n_seqs + self.context_lens = [0] * self.n_seqs + self.curr_sliding_window_blocks = [0] * self.n_seqs + + self.lora_index_mapping = [] + self.lora_prompt_mapping = [] + + def gen_inter_data_builder(self, num_seqs: int): + return lambda: ModelInputForGPUBuilder.InterDataForSeqGroup( + request_id="", + seq_ids=[0] * num_seqs, + is_prompt=True, + block_tables=None, + computed_block_nums=[]) + + def init_cached_inter_data(self, *args, **kwargs): + assert len(args) == 0 + assert "seq_ids" in kwargs + seq_ids = kwargs["seq_ids"] + num_seqs = len(seq_ids) + + # The inter-data cache is per model_runner + inter_data_cache = self.runner.inter_data_cache + if num_seqs not in inter_data_cache: + inter_data_cache[num_seqs] = PyObjectCache( + self.gen_inter_data_builder(num_seqs)) + + obj = inter_data_cache[num_seqs].get_object() + obj.__init__(*args, **kwargs) + return obj + + def reset_cached_inter_data(self): + for cache in self.runner.inter_data_cache.values(): + cache.reset() + + def __init__(self, + runner: "GPUModelRunnerBase", + finished_requests_ids: Optional[List[str]] = None): + super().__init__() + # Compute functions for each sequence in a sequence group. + # WARNING: The order of the functions matters! + self.per_seq_compute_fns = [ + self._compute_lens, + self._compute_for_prefix_cache_hit, + self._compute_for_sliding_window, + self._compute_lora_input, + ] + # Compute functions for each sequence group. + # WARNING: The order of the functions matters! + self.per_seq_group_compute_fns = [ + self._compute_prompt_adapter_input, + self._compute_multi_modal_input, + ] + + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.scheduler_config = self.runner.scheduler_config + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.enable_lora = self.runner.lora_config is not None + self.enable_prompt_adapter = (self.runner.prompt_adapter_config + is not None) + self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.finished_requests_ids = finished_requests_ids + self.decode_only = True + + # Intermediate data (data in CPU before going to GPU) for + # the current sequence group. + self.inter_data_list: List[ + ModelInputForGPUBuilder.InterDataForSeqGroup] = [] + + # Attention metadata inputs. + self.attn_metadata_builder = self.attn_backend.make_metadata_builder( + weakref.proxy(self)) + + # Engine/Model configurations. + self.chunked_prefill_enabled = ( + self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled) + if self.sliding_window is not None: + self.sliding_window_blocks = ( + self.sliding_window + self.block_size - 1) // self.block_size + self.block_aligned_sliding_window = \ + self.sliding_window_blocks * self.block_size + + def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Compute context length, sequence length and tokens + for the given sequence data. 
+ """ + seq_data = seq_group_metadata.seq_data[inter_data.seq_ids[seq_idx]] + token_chunk_size = seq_group_metadata.token_chunk_size + + # Compute context length (the number of tokens that are + # already computed) and sequence length (total number of tokens). + + seq_len = seq_data.get_len() + if inter_data.is_prompt: + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + elif self.runner.scheduler_config.is_multi_step or \ + self.runner.model_config.is_encoder_decoder_model: + context_len = seq_len - 1 + else: + context_len = seq_data.get_num_computed_tokens() + + # Compute tokens. + tokens = seq_data.get_token_ids()[context_len:seq_len] + + inter_data.seq_lens[seq_idx] = seq_len + inter_data.orig_seq_lens[seq_idx] = seq_len + inter_data.context_lens[seq_idx] = context_len + inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.query_lens[seq_idx] = seq_len - context_len + + if seq_data.mrope_position_delta is not None: + if inter_data.mrope_input_positions is None: + inter_data.mrope_input_positions = [None] * inter_data.n_seqs + + inter_data.mrope_input_positions[ + seq_idx] = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + + def _compute_for_prefix_cache_hit( + self, inter_data: InterDataForSeqGroup, seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Check if hit prefix cache (i.e., some blocks are already computed). + If hit, update input tokens and positions to only compute the + remaining blocks. + """ + computed_block_nums = inter_data.computed_block_nums + + # Note that prefix caching does not support sliding window. + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and inter_data.is_prompt) + inter_data.prefix_cache_hit = prefix_cache_hit + + if not prefix_cache_hit: + return + + assert computed_block_nums is not None + # The cache hit prompt tokens in this sequence. Note that + # this may be larger than the sequence length if chunked + # prefill is enabled. + prefix_cache_len = len(computed_block_nums) * self.block_size + # The number of so far computed prompt tokens in this sequence. + context_len = inter_data.context_lens[seq_idx] + # The total number of prompt tokens in this sequence. + # When chunked prefill is enabled, this is the token number of + # computed chunks + current chunk. + seq_len = inter_data.seq_lens[seq_idx] + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + uncomputed_start = prefix_cache_len - context_len + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][uncomputed_start:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][uncomputed_start:] + context_len = prefix_cache_len + + inter_data.context_lens[seq_idx] = context_len + inter_data.query_lens[ + seq_idx] = inter_data.seq_lens[seq_idx] - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. 
+ inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][-1:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][-1:] + inter_data.query_lens[seq_idx] = 1 + inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 + + def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """Update seq_len and curr_sliding_window_block for the given + sequence data (only required by decoding) if sliding window is enabled. + """ + curr_sliding_window_block = 0 + sliding_seq_len = inter_data.seq_lens[seq_idx] + if not inter_data.is_prompt and self.sliding_window is not None: + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + curr_sliding_window_block = self.sliding_window_blocks + if self.scheduler_config.use_v2_block_manager: + # number of elements in last block + suff_len = inter_data.seq_lens[seq_idx] % self.block_size + sliding_seq_len = min( + inter_data.seq_lens[seq_idx], + self.block_aligned_sliding_window + suff_len) + if suff_len > 0: + curr_sliding_window_block += 1 + else: + sliding_seq_len = min(inter_data.seq_lens[seq_idx], + self.sliding_window) + + inter_data.curr_sliding_window_blocks[ + seq_idx] = curr_sliding_window_block + inter_data.seq_lens[seq_idx] = sliding_seq_len + + def _compute_lora_input(self, inter_data: InterDataForSeqGroup, + seq_idx: int, + seq_group_metadata: SequenceGroupMetadata): + """If LoRA is enabled, compute LoRA index and prompt mapping.""" + if not self.enable_lora: + return + + lora_id = seq_group_metadata.lora_int_id + if lora_id > 0: + inter_data.lora_requests.add(seq_group_metadata.lora_request) + query_len = inter_data.query_lens[seq_idx] + inter_data.lora_index_mapping.append([lora_id] * query_len) + inter_data.lora_prompt_mapping.append( + [lora_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs is not None + else 1)) + + def _compute_prompt_adapter_input( + self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If prompt adapter is enabled, compute index and prompt mapping. + """ + # Note that when is_prompt=True, we expect only one sequence + # in the group. + if not self.enable_prompt_adapter: + return + + prompt_adapter_id = seq_group_metadata.prompt_adapter_id + if prompt_adapter_id <= 0 or not inter_data.is_prompt: + return + + # We expect only one sequence in the group when is_prompt=True. 
+ assert inter_data.n_seqs == 1 + query_len = inter_data.query_lens[0] + inter_data.prompt_adapter_request = ( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.prompt_adapter_num_virtual_tokens + inter_data.prompt_adapter_index_mapping = [ + prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + inter_data.prompt_adapter_prompt_mapping = [prompt_adapter_id] * ( + query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs else 1) + + def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, + seq_group_metadata: SequenceGroupMetadata): + """If multi-modal data is given, add it to the input.""" + mm_data = seq_group_metadata.multi_modal_data + if not mm_data: + return + + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs) + inter_data.multi_modal_inputs = mm_kwargs + + # special processing for mrope position deltas. + if self.runner.model_is_mrope: + image_grid_thw = mm_kwargs.get("image_grid_thw", None) + video_grid_thw = mm_kwargs.get("video_grid_thw", None) + assert image_grid_thw is not None or video_grid_thw is not None, ( + "mrope embedding type requires multi-modal input mapper " + "returns 'image_grid_thw' or 'video_grid_thw'.") + + hf_config = self.runner.model_config.hf_config + + inter_data.mrope_input_positions = [None] * inter_data.n_seqs + for seq_idx in range(inter_data.n_seqs): + seq_data = seq_group_metadata.seq_data[ + inter_data.seq_ids[seq_idx]] + token_ids = seq_data.get_token_ids() + + mrope_input_positions, mrope_position_delta = \ + MRotaryEmbedding.get_input_positions( + token_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + image_token_id=hf_config.image_token_id, + video_token_id=hf_config.video_token_id, + vision_start_token_id=hf_config.vision_start_token_id, + vision_end_token_id=hf_config.vision_end_token_id, + spatial_merge_size=hf_config.vision_config. 
+ spatial_merge_size, + context_len=inter_data.context_lens[seq_idx], + ) + + seq_data.mrope_position_delta = mrope_position_delta + inter_data.mrope_input_positions[ + seq_idx] = mrope_input_positions + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + """Add a sequence group to the builder.""" + seq_ids = seq_group_metadata.seq_data.keys() + n_seqs = len(seq_ids) + is_prompt = seq_group_metadata.is_prompt + + if is_prompt: + assert n_seqs == 1 + self.decode_only = False + + encoder_seq_len = 0 + + if self.runner.model_config.is_encoder_decoder_model: + encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len() + + inter_data = self.init_cached_inter_data( + request_id=seq_group_metadata.request_id, + seq_ids=seq_ids, + is_prompt=is_prompt, + block_tables=seq_group_metadata.block_tables, + computed_block_nums=seq_group_metadata.computed_block_nums, + reinit=True, + reinit_use_defaults=True, + encoder_seq_len=encoder_seq_len) + + self.inter_data_list.append(inter_data) + + for seq_idx in range(n_seqs): + for per_seq_fn in self.per_seq_compute_fns: + per_seq_fn(inter_data, seq_idx, seq_group_metadata) + for per_seq_group_fn in self.per_seq_group_compute_fns: + per_seq_group_fn(inter_data, seq_group_metadata) + + def _use_captured_graph(self, + batch_size: int, + decode_only: bool, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> bool: + return (decode_only and not self.runner.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_decode_seq_len <= self.runner.max_seq_len_to_capture + and max_encoder_seq_len <= self.runner.max_seq_len_to_capture + and batch_size <= self.runner.max_batchsize_to_capture) + + def _get_cuda_graph_pad_size(self, + num_seqs: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> int: + """ + Determine the number of padding sequences required for running in + CUDA graph mode. Returns -1 if CUDA graphs cannot be used. + + In the multi-step + chunked-prefill case, only the first step + has Prefills (if any). The rest of the steps are guaranteed to be all + decodes. In this case, we set up the padding as if all the sequences + are decodes so we may run all steps except the first step in CUDA graph + mode. The padding is accounted for in the multi-step `advance_step` + family of functions. + + Args: + num_seqs (int): Number of sequences scheduled to run. + max_decode_seq_len (int): Greatest of all the decode sequence + lengths. Used only in checking the viablility of using + CUDA graphs. + max_encoder_seq_len (int, optional): Greatest of all the encode + sequence lengths. Defaults to 0. Used only in checking the + viability of using CUDA graphs. + Returns: + int: Returns the determined number of padding sequences. If + CUDA graphs is not viable, returns -1. + """ + is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ + self.runner.scheduler_config.chunked_prefill_enabled + decode_only = self.decode_only or is_mscp + if not decode_only: + # Early exit so we can treat num_seqs as the batch_size below. + return -1 + + # batch_size out of this function refers to the number of input + # tokens being scheduled. This conflation of num_seqs as batch_size + # is valid as this is a decode-only case. 
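+        # In a decode-only batch each sequence contributes exactly one token,
+        # e.g. 13 decode sequences schedule 13 tokens.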
+        batch_size = num_seqs
+        if not self._use_captured_graph(batch_size, decode_only,
+                                        max_decode_seq_len,
+                                        max_encoder_seq_len):
+            return -1
+
+        graph_batch_size = _get_graph_batch_size(batch_size)
+        assert graph_batch_size >= batch_size
+        return graph_batch_size - batch_size
+
+    def build(self) -> ModelInputForGPU:
+        """Finalize the builder intermediate data and
+        create on-device tensors.
+        """
+        # Combine and flatten intermediate data.
+        input_tokens = []
+        for inter_data in self.inter_data_list:
+            for cur_input_tokens in inter_data.input_tokens:
+                input_tokens.extend(cur_input_tokens)
+
+        if not input_tokens:
+            # This may happen when all prefill requests hit
+            # prefix caching and there is no decode request.
+            return self.model_input_cls()
+
+        mrope_input_positions: Optional[List[List[int]]] = None
+        if any(inter_data.mrope_input_positions is not None
+               for inter_data in self.inter_data_list):
+            mrope_input_positions = [[] for _ in range(3)]
+            for idx in range(3):
+                for inter_data in self.inter_data_list:
+                    msections = inter_data.mrope_input_positions
+                    if msections is None:
+                        for _seq_input_positions in inter_data.input_positions:
+                            mrope_input_positions[idx].extend(
+                                _seq_input_positions)
+                    else:
+                        for _seq_mrope_input_positions in msections:
+                            mrope_input_positions[idx].extend(
+                                _seq_mrope_input_positions[idx])
+            input_positions = None
+        else:
+            input_positions = []
+            for inter_data in self.inter_data_list:
+                for cur_input_positions in inter_data.input_positions:
+                    input_positions.extend(cur_input_positions)
+
+        seq_lens = []
+        query_lens = []
+        max_decode_seq_len = 0
+        max_encoder_seq_len = 0
+        for inter_data in self.inter_data_list:
+            seq_lens.extend(inter_data.seq_lens)
+            query_lens.extend(inter_data.query_lens)
+            if not inter_data.is_prompt:
+                max_decode_seq_len = max(max_decode_seq_len,
+                                         max(inter_data.seq_lens))
+                if self.runner.model_config.is_encoder_decoder_model:
+                    max_encoder_seq_len = max(max_encoder_seq_len,
+                                              inter_data.encoder_seq_len)
+
+        # Mapping from request IDs to sequence IDs. Used for Jamba models
+        # that manage the cache by themselves.
+        request_ids_to_seq_ids = {
+            data.request_id: data.seq_ids
+            for data in self.inter_data_list
+        }
+
+        cuda_graph_pad_size = self._get_cuda_graph_pad_size(
+            num_seqs=len(seq_lens),
+            max_decode_seq_len=max_decode_seq_len,
+            max_encoder_seq_len=max_encoder_seq_len)
+
+        batch_size = len(input_tokens)
+        if cuda_graph_pad_size != -1:
+            # If cuda graph can be used, pad tensors accordingly.
+            # See `capture_model` API for more details.
+            # vLLM uses cuda graph only for decoding requests.
+            batch_size += cuda_graph_pad_size
+
+        # Tokens and positions.
+        if cuda_graph_pad_size:
+            input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
+        assert self.runner.device is not None
+        input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long,
+                                               self.runner.device,
+                                               self.runner.pin_memory)
+        if mrope_input_positions is not None:
+            for idx in range(3):
+                mrope_input_positions[idx].extend(
+                    itertools.repeat(0, cuda_graph_pad_size))
+            input_positions_tensor = async_tensor_h2d(mrope_input_positions,
+                                                      torch.long,
+                                                      self.runner.device,
+                                                      self.runner.pin_memory)
+        else:
+            input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
+            input_positions_tensor = async_tensor_h2d(input_positions,
+                                                      torch.long,
+                                                      self.runner.device,
+                                                      self.runner.pin_memory)
+        # Sequence and query lengths.
+        if cuda_graph_pad_size:
+            seq_lens.extend(itertools.repeat(1, cuda_graph_pad_size))
+
+        # Attention metadata.
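+        # The backend-specific metadata builder consumes the (possibly padded)
+        # sequence and query lengths computed above.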
+ attn_metadata = self.attn_metadata_builder.build( + seq_lens, query_lens, cuda_graph_pad_size, batch_size) + + # LoRA data. + lora_requests = set() + lora_mapping = None + if self.enable_lora: + lora_requests = set(r for data in self.inter_data_list + for r in data.lora_requests) + lora_index_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_index_mapping) + for inter_data in self.inter_data_list + ]) + if cuda_graph_pad_size: + lora_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) + lora_prompt_mapping = flatten_2d_lists([ + flatten_2d_lists(inter_data.lora_prompt_mapping) + for inter_data in self.inter_data_list + ]) + + lora_mapping = LoRAMapping( + **dict(index_mapping=lora_index_mapping, + prompt_mapping=lora_prompt_mapping, + is_prefill=not self.decode_only)) + + # Prompt adapter data. + prompt_adapter_requests: Set[PromptAdapterRequest] = set() + prompt_adapter_mapping = None + if self.enable_prompt_adapter: + prompt_adapter_requests = set( + data.prompt_adapter_request for data in self.inter_data_list + if data.prompt_adapter_request is not None) + prompt_adapter_index_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_index_mapping + for inter_data in self.inter_data_list + ]) + if cuda_graph_pad_size: + prompt_adapter_index_mapping.extend( + itertools.repeat(0, cuda_graph_pad_size)) + prompt_adapter_prompt_mapping = flatten_2d_lists([ + inter_data.prompt_adapter_prompt_mapping + for inter_data in self.inter_data_list + ]) + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + + # Multi-modal data. + multi_modal_inputs_list = [ + data.multi_modal_inputs for data in self.inter_data_list + if data.multi_modal_inputs is not None + ] + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return self.model_input_cls( + input_tokens=input_tokens_tensor, + input_positions=input_positions_tensor, + attn_metadata=attn_metadata, + seq_lens=seq_lens, + query_lens=query_lens, + lora_mapping=lora_mapping, + lora_requests=lora_requests, + multi_modal_kwargs=multi_modal_kwargs, + request_ids_to_seq_ids=request_ids_to_seq_ids, + finished_requests_ids=self.finished_requests_ids, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests) + + +class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): + """ + Helper class for shared methods between GPU model runners. 
+ """ + _model_input_cls: Type[TModelInputForGPU] + _builder_cls: Type[ModelInputForGPUBuilder] + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config + self.return_hidden_states = return_hidden_states + self.observability_config = observability_config + + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture + self.max_batchsize_to_capture = _get_max_graph_batch_size( + self.scheduler_config.max_num_seqs) + + self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ + {} for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.graph_memory_pool: Optional[Tuple[ + int, int]] = None # Set during graph capture. + + self.has_inner_state = model_config.has_inner_state + + # When using CUDA graph, the input block tables must be padded to + # max_seq_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = np.zeros( + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), + dtype=np.int32) + + # Attention-free but stateful models like Mamba need a placeholder attn + # backend, as the attention metadata is needed to manage internal state. 
+ # However we must bypass attention selection altogether for some models + # used for speculative decoding to avoid a divide-by-zero in + # model_config.get_head_size() + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) if needs_attn_backend else None + if self.attn_backend: + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) + else: + self.attn_state = CommonAttentionState(weakref.proxy(self)) + + # Multi-modal data support + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization + self.model: nn.Module # Set after load_model + # Set after load_model. + self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None + + set_cpu_offload_max_bytes( + int(self.cache_config.cpu_offload_gb * 1024**3)) + + # Used to cache python objects + self.inter_data_cache: Dict[int, PyObjectCache] = {} + + # Using the PythonizationCache in Pipeline-Parallel clobbers the + # SequenceGroupToSample object. In Pipeline-Parallel, we have + # more than 1 Scheduler, resulting in a potential back-to-back + # prepare_model_inputs() call. This clobbers the cached + # SequenceGroupToSample objects, as we reset the cache during + # every prepare_model_inputs() call. + self.sampling_metadata_cache: SamplingMetadataCache = \ + SamplingMetadataCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + def load_model(self) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + if self.lora_config: + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
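+            # VLMs keep max_position_embeddings on the nested text_config, so
+            # fall back to it when the attribute is absent at the top level.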
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.vocab_size, + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + self.model = self.lora_manager.create_lora_manager(self.model) + + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + + if self.kv_cache_dtype == "fp8" and is_hip(): + # Currently only ROCm accepts kv-cache scaling factors + # via quantization_param_path and this will be deprecated + # in the future. + if self.model_config.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + warnings.warn( + "Loading kv cache scaling factor from JSON is " + "deprecated and will be removed. Please include " + "kv cache scaling factors in the model checkpoint.", + FutureWarning, + stacklevel=2) + self.model.load_kv_cache_scales( + self.model_config.quantization_param_path) + logger.info("Loaded KV cache scaling factors from %s", + self.model_config.quantization_param_path) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__) + else: + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!") + + if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \ + and supports_dynamo(): + from vllm.plugins import get_torch_compile_backend + backend = get_torch_compile_backend() or "eager" + self.model = torch.compile( + self.model, + fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE, + backend=backend) + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + from vllm.model_executor.model_loader.loader import ShardedStateLoader + ShardedStateLoader.save_model( + self.model, + path, + pattern=pattern, + max_size=max_size, + ) + + def save_tensorized_model( + self, + tensorizer_config: TensorizerConfig, + ) -> None: + from vllm.model_executor.model_loader.loader import TensorizerLoader + TensorizerLoader.save_model( + self.model, + tensorizer_config=tensorizer_config, + ) + + def get_max_block_per_batch(self) -> int: + block_size = self.block_size + return (self.max_seq_len_to_capture + block_size - 1) // block_size + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> TModelInputForGPU: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. 
For example,
+
+        - input_tokens[:num_prefill_tokens] contains prefill tokens.
+        - input_tokens[num_prefill_tokens:] contains decode tokens.
+
+        If cuda graph is required, this API automatically pads inputs.
+        """
+        builder = self._builder_cls(weakref.proxy(self), finished_requests_ids)
+        for seq_group_metadata in seq_group_metadata_list:
+            builder.add_seq_group(seq_group_metadata)
+
+        builder.reset_cached_inter_data()
+
+        return builder.build()  # type: ignore
+
+    @torch.inference_mode()
+    def profile_run(self) -> None:
+        # Enable top-k sampling to reflect the accurate memory usage.
+        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
+        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
+        max_num_seqs = self.scheduler_config.max_num_seqs
+        # This represents the maximum number of different requests
+        # that will have unique loras, and therefore the max amount of memory
+        # consumption. Create dummy lora request copies from the lora request
+        # passed in, which contains a lora from the lora warmup path.
+        dummy_lora_requests: List[LoRARequest] = []
+        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        if self.lora_config:
+            assert self.lora_manager is not None
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+                dummy_lora_requests_per_seq = [
+                    dummy_lora_requests[idx % len(dummy_lora_requests)]
+                    for idx in range(max_num_seqs)
+                ]
+
+        # Profile memory usage with max_num_sequences sequences and the total
+        # number of tokens equal to max_num_batched_tokens.
+        seqs: List[SequenceGroupMetadata] = []
+        # Additional GPU memory may be needed for multi-modal encoding, which
+        # needs to be accounted for when calculating the GPU blocks for
+        # the vLLM block manager.
+        # To exercise the worst scenario for GPU memory consumption,
+        # the number of seqs (batch_size) is chosen to maximize the number
+        # of images processed.
+
+        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
+            self.model_config)
+        if max_mm_tokens > 0:
+            max_num_seqs_orig = max_num_seqs
+            max_num_seqs = min(max_num_seqs,
+                               max_num_batched_tokens // max_mm_tokens)
+            if max_num_seqs < 1:
+                expr = (f"min({max_num_seqs_orig}, "
+                        f"{max_num_batched_tokens} // {max_mm_tokens})")
+                logger.warning(
+                    "Computed max_num_seqs (%s) to be less than 1. "
+                    "Setting it to the minimum value of 1.", expr)
+                max_num_seqs = 1
+
+        batch_size = 0
+        for group_id in range(max_num_seqs):
+            seq_len = (max_num_batched_tokens // max_num_seqs +
+                       (group_id < max_num_batched_tokens % max_num_seqs))
+            batch_size += seq_len
+
+            seq_data, dummy_multi_modal_data = self.input_registry \
+                .dummy_data_for_profiling(self.model_config,
+                                          seq_len,
+                                          self.mm_registry)
+
+            seq = SequenceGroupMetadata(
+                request_id=str(group_id),
+                is_prompt=True,
+                seq_data={group_id: seq_data},
+                sampling_params=sampling_params,
+                block_tables=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
+                multi_modal_data=dummy_multi_modal_data,
+            )
+            seqs.append(seq)
+
+        # Run the model with the dummy inputs.
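+        # The KV caches are passed as empty placeholder tensors below, so this
+        # forward pass exercises peak activation memory before the real cache
+        # is allocated.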
+ num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(num_layers) + ] + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + + graph_batch_size = self.max_batchsize_to_capture + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + if self.model_config.enforce_eager: + batch_size_capture_list = [] + with set_compile_context(batch_size_capture_list): + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.cuda.synchronize() + return + + def remove_all_loras(self): + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.remove_all_adapters() + + def set_active_loras(self, lora_requests: Set[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_adapter(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_adapter(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.pin_adapter(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_adapters() + + def remove_all_prompt_adapters(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.remove_all_adapters() + + def set_active_prompt_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping) -> None: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.set_active_adapters( + prompt_adapter_requests, prompt_adapter_mapping) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise 
RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.list_adapters() + + @property + def model_is_mrope(self) -> bool: + """Detect if the model has "mrope" rope_scaling type. + mrope requires keep "rope_deltas" between prompt and decoding phases.""" + rope_scaling = getattr(self.model_config.hf_config, "rope_scaling", {}) + if rope_scaling is None: + return False + return rope_scaling.get("type", None) == "mrope" + + @torch.inference_mode() + def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: + """Cuda graph capture a model. + + Note that CUDA graph's performance gain is negligible if number + of batched tokens are larger than 200. And since CUDA graph + requires fixed sized tensors, supporting large/variable batch + size requires high GPU memory overhead. Thus, vLLM only captures + decoding requests. Mixed batch (chunked prefill + decoding) or + prefill requests are not captured. + + Since it is used for decoding-only, it assumes there's only 1 token + per sequence in the batch. + """ + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode. " + "You can also reduce the `max_num_seqs` as needed " + "to decrease memory usage.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = self.max_batchsize_to_capture + input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() + if self.model_is_mrope: + input_positions = torch.tile(input_positions, (3, 1)) + # Prepare dummy previous_hidden_states only if needed by the model. + # This is used by draft models such as EAGLE. + previous_hidden_states = None + if "previous_hidden_states" in inspect.signature( + self.model.forward).parameters: + previous_hidden_states = torch.empty( + [max_batch_size, + self.model_config.get_hidden_size()], + dtype=self.model_config.dtype, + device=self.device) + + intermediate_inputs = None + if not get_pp_group().is_first_rank: + intermediate_inputs = self.model.make_empty_intermediate_tensors( + batch_size=max_batch_size, + dtype=self.model_config.dtype, + device=self.device) + + # Prepare buffer for outputs. These will be reused for all batch sizes. + # It will be filled after the first graph capture. + hidden_or_intermediate_states: List[Optional[torch.Tensor]] = [ + None + ] * self.parallel_config.pipeline_parallel_size + + graph_batch_size = self.max_batchsize_to_capture + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + + with self.attn_state.graph_capture( + max_batch_size), graph_capture() as graph_capture_context: + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. 
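+            # One graph is captured per (virtual engine, padded batch size)
+            # pair; subsequent captures reuse the same memory pool.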
+ for virtual_engine in range( + self.parallel_config.pipeline_parallel_size): + for batch_size in reversed(batch_size_capture_list): + attn_metadata = ( + self.attn_state.graph_capture_get_metadata_for_batch( + batch_size, + is_encoder_decoder_model=self.model_config. + is_encoder_decoder_model)) + + if self.lora_config: + lora_mapping = LoRAMapping( + **dict(index_mapping=[0] * batch_size, + prompt_mapping=[0] * batch_size, + is_prefill=False)) + self.set_active_loras(set(), lora_mapping) + + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters( + set(), prompt_adapter_mapping) + graph_runner = CUDAGraphRunner( + self.model, self.attn_backend.get_name(), + self.attn_state.graph_clone(batch_size), + self.model_config.is_encoder_decoder_model) + + capture_inputs = { + "input_ids": + input_tokens[:batch_size], + "positions": + input_positions[..., :batch_size], + "hidden_or_intermediate_states": + hidden_or_intermediate_states[ + virtual_engine] # type: ignore + [:batch_size] + if hidden_or_intermediate_states[virtual_engine] + is not None else None, + "intermediate_inputs": + intermediate_inputs[:batch_size] + if intermediate_inputs is not None else None, + "kv_caches": + kv_caches[virtual_engine], + "attn_metadata": + attn_metadata, + "memory_pool": + self.graph_memory_pool, + "stream": + graph_capture_context.stream + } + if previous_hidden_states is not None: + capture_inputs[ + "previous_hidden_states"] = previous_hidden_states[: + batch_size] + + if self.has_inner_state: + # Only used by Mamba-based models CUDA graph atm (Jamba) + capture_inputs.update({ + "seqlen_agnostic_capture_inputs": + self.model.get_seqlen_agnostic_capture_inputs( + batch_size) + }) + if self.model_config.is_encoder_decoder_model: + # add the additional inputs to capture for + # encoder-decoder models. + self._update_inputs_to_capture_for_enc_dec_model( + capture_inputs) + + with set_forward_context(attn_metadata): + graph_runner.capture(**capture_inputs) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[virtual_engine][batch_size] = ( + graph_runner) + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. + logger.info("Graph capturing finished in %.0f secs.", elapsed_time) + + def _update_inputs_to_capture_for_enc_dec_model(self, + capture_inputs: Dict[str, + Any]): + """ + Updates the set of input tensors needed for CUDA graph capture in an + encoder-decoder model. + + This method modifies the provided `capture_inputs` dictionary by + adding tensors specific to encoder-decoder specific models that + need to be captured for CUDA Graph replay. + """ + # During the decode phase encoder_input_ids and encoder_positions are + # unset. Do the same thing for graph capture. + capture_inputs["encoder_input_ids"] = torch.tensor( + [], dtype=torch.long).cuda() + capture_inputs["encoder_positions"] = torch.tensor( + [], dtype=torch.long).cuda() + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + +class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]): + """ + GPU model runner with sampling step. 
+ """ + _model_input_cls: Type[ModelInputForGPUWithSamplingMetadata] = ( + ModelInputForGPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForGPUBuilder] = ModelInputForGPUBuilder + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> ModelInputForGPUWithSamplingMetadata: + model_input = \ + ModelInputForGPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForGPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + + If cuda graph is required, this API automatically pads inputs. + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + if get_pp_group().is_last_rank: + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, model_input.seq_lens, + model_input.query_lens, self.device, self.pin_memory, + generators, self.sampling_metadata_cache) + else: + sampling_metadata = None + is_prompt = (seq_group_metadata_list[0].is_prompt + if seq_group_metadata_list else None) + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + is_prompt=is_prompt, + virtual_engine=virtual_engine) + + @torch.inference_mode() + # @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) + def execute_model( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in ModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + self.attn_state.begin_forward(model_input) + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. 
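+        # Decode-only batches whose size matches a captured graph are replayed
+        # through the CUDA graph runner; everything else runs eagerly.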
+ virtual_engine = model_input.virtual_engine + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_inner_state else {} + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start = torch.cuda.Event(enable_timing=True) + model_forward_end = torch.cuda.Event(enable_timing=True) + model_forward_start.record() + + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.record() + + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + if (self.is_driver_worker + and hidden_or_intermediate_states is not None + and isinstance(hidden_or_intermediate_states, + IntermediateTensors) + and self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + hidden_or_intermediate_states.tensors["model_forward_time"] = ( + torch.tensor(model_forward_time + orig_model_forward_time)) + return hidden_or_intermediate_states + + logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_end.synchronize() + model_forward_time = model_forward_start.elapsed_time( + model_forward_end) + orig_model_forward_time = 0.0 + if intermediate_tensors is not None: + orig_model_forward_time = intermediate_tensors.tensors.get( + "model_forward_time", torch.tensor(0.0)).item() + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. 
+ output.model_forward_time = (orig_model_forward_time + + model_forward_time) + + if self.return_hidden_states: + # we only need to pass hidden states of most recent token + assert model_input.sampling_metadata is not None + indices = model_input.sampling_metadata.selected_token_indices + if model_input.is_prompt: + hidden_states = hidden_or_intermediate_states.index_select( + 0, indices) + output.prefill_hidden_states = hidden_or_intermediate_states + elif decode_meta.use_cuda_graph: + hidden_states = hidden_or_intermediate_states[:len(indices)] + else: + hidden_states = hidden_or_intermediate_states + + output.hidden_states = hidden_states + + return [output] + + +class CUDAGraphRunner: + + def __init__(self, model: nn.Module, backend_name: str, + attn_state: AttentionState, is_encoder_decoder_model: bool): + self.model = model + self.backend_name = backend_name + self.attn_state = attn_state + + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + self._graph: Optional[torch.cuda.CUDAGraph] = None + self._is_encoder_decoder_model = is_encoder_decoder_model + + @property + def graph(self): + assert self._graph is not None + return self._graph + + def capture( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_or_intermediate_states: Optional[Union[IntermediateTensors, + torch.Tensor]], + intermediate_inputs: Optional[IntermediateTensors], + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + memory_pool: Optional[Tuple[int, int]], + stream: torch.cuda.Stream, + **kwargs, + ) -> Union[torch.Tensor, IntermediateTensors]: + assert self._graph is None + # Run the model a few times without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + # Note one iteration is not enough for torch.jit.script + for _ in range(_NUM_WARMUP_ITERS): + self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, + **kwargs, + ) + # Wait for the warm up operations to finish before proceeding with + # Graph Capture. + torch.cuda.synchronize() + # Capture the graph. + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream): + output_hidden_or_intermediate_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_inputs, + **kwargs, + ) + if hidden_or_intermediate_states is not None: + if get_pp_group().is_last_rank: + hidden_or_intermediate_states.copy_( + output_hidden_or_intermediate_states) + else: + for key in hidden_or_intermediate_states.tensors: + hidden_or_intermediate_states[key].copy_( + output_hidden_or_intermediate_states[key]) + else: + hidden_or_intermediate_states = ( + output_hidden_or_intermediate_states) + + del output_hidden_or_intermediate_states + # make sure `output_hidden_states` is deleted + # in the graph's memory pool + gc.collect() + torch.cuda.synchronize() + + # Save the input and output buffers. 
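+        # Replaying the graph only copies new data into these fixed buffers,
+        # so they must be the exact tensors that were used at capture time.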
+ self.input_buffers = { + "input_ids": + input_ids, + "positions": + positions, + "kv_caches": + kv_caches, + **self.attn_state.get_graph_input_buffers( + attn_metadata, self._is_encoder_decoder_model), + **kwargs, + } + if intermediate_inputs is not None: + self.input_buffers.update(intermediate_inputs.tensors) + if get_pp_group().is_last_rank: + self.output_buffers = { + "hidden_states": hidden_or_intermediate_states + } + else: + self.output_buffers = hidden_or_intermediate_states + return hidden_or_intermediate_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + **kwargs, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. + self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + + if self.backend_name != "placeholder-attn": + self.input_buffers["slot_mapping"].copy_( + attn_metadata.slot_mapping, non_blocking=True) + + self.attn_state.prepare_graph_input_buffers( + self.input_buffers, attn_metadata, self._is_encoder_decoder_model) + + if "seqlen_agnostic_capture_inputs" in self.input_buffers: + self.model.copy_inputs_before_cuda_graphs(self.input_buffers, + **kwargs) + + if "previous_hidden_states" in self.input_buffers: + self.input_buffers["previous_hidden_states"].copy_( + kwargs["previous_hidden_states"], non_blocking=True) + + if intermediate_tensors is not None: + for key in intermediate_tensors.tensors: + if key != "model_execute_time" and key != "model_forward_time": + self.input_buffers[key].copy_(intermediate_tensors[key], + non_blocking=True) + if self._is_encoder_decoder_model: + self.input_buffers["encoder_input_ids"].copy_( + kwargs['encoder_input_ids'], non_blocking=True) + self.input_buffers["encoder_positions"].copy_( + kwargs['encoder_positions'], non_blocking=True) + + # Run the graph. + self.graph.replay() + # Return the output tensor. + if get_pp_group().is_last_rank: + return self.output_buffers["hidden_states"] + + return self.output_buffers + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _get_max_graph_batch_size(max_num_seqs: int) -> int: + """ + max_num_seqs: Maximum number of sequences in a batch. + _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. + + pad the max_num_seqs if necessary by calling _get_graph_batch_size, + which will deal with some edge cases like 1, 2, 4. + + if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. + if not, it means the padded size is larger than the largest size in + _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. 
+ """ + padded_size = _get_graph_batch_size(max_num_seqs) + if padded_size in _BATCH_SIZES_TO_CAPTURE: + return padded_size + assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py new file mode 100644 index 0000000..86883cf --- /dev/null +++ b/vllm/worker/model_runner_base.py @@ -0,0 +1,275 @@ +import dataclasses +import pickle +from abc import ABC, abstractmethod +from datetime import datetime +from functools import wraps +from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List, + Optional, Type, TypeVar) + +import torch +from torch import is_tensor + +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata + +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.attention.backends.abstract import AttentionBackend + from vllm.model_executor import SamplingMetadata + +logger = init_logger(__name__) + +T = TypeVar('T', bound="BroadcastableModelInput") + + +def _add_attn_metadata_broadcastable_dict( + tensor_dict: Dict[str, Any], + attn_metadata: Optional["AttentionMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + AttentionMetadata fields. + """ + if attn_metadata is not None: + tensor_dict.update(attn_metadata.asdict_zerocopy()) + + +def _init_attn_metadata_from_tensor_dict( + attn_backend: "AttentionBackend", + tensor_dict: Dict[str, Any], +) -> Dict[str, Any]: + """ + Helper method to initialize AttentionMetadata based on an + AttentionBackend and broadcastable AttentionMetadata fields. + """ + # Extract the fields used to create AttentionMetadata. + valid_attn_kwargs = {} + for field in dataclasses.fields(attn_backend.get_metadata_cls()): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_attn_kwargs[field.name] = val + + attn_metadata = attn_backend.make_metadata(**valid_attn_kwargs) + tensor_dict["attn_metadata"] = attn_metadata + return tensor_dict + + +def _init_sampling_metadata_from_tensor_dict( # type: ignore + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper method to initialize SamplingMetadata based on broadcastable + SamplingMetadata fields. + """ + from vllm.model_executor import SamplingMetadata + + selected_token_indices = tensor_dict.pop("selected_token_indices", None) + # An empty SamplingMetadata to signal that the worker should skip + # sampling. + if selected_token_indices is not None: + tensor_dict["sampling_metadata"] = SamplingMetadata( + seq_groups=None, + selected_token_indices=selected_token_indices, + categorized_sample_indices=None, + num_prompts=0, + ) + return tensor_dict + + +def _add_sampling_metadata_broadcastable_dict( + tensor_dict: Dict[str, Any], + sampling_metadata: Optional["SamplingMetadata"]) -> None: + """ + Helper method to update tensor_dict with broadcastable + SamplingMetadata fields. 
+ """ + if sampling_metadata is not None: + tensor_dict["selected_token_indices"] = ( + sampling_metadata.selected_token_indices) + + +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Helper method to initialize a frozen ModelInput based on broadcastable + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + + +def dump_input_when_exception(exclude_args: Optional[List[int]] = None, + exclude_kwargs: Optional[List[str]] = None): + + def _inner(func): + + @wraps(func) + def _wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as err: + timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") + filename = f"/tmp/err_{func.__name__}_input_{timestamp}.pkl" + logger.info("Writing input of failed execution to %s...", + filename) + with open(filename, "wb") as filep: + dumped_inputs = { + k: v + for k, v in kwargs.items() + if k not in (exclude_kwargs or []) + } + for i, arg in enumerate(args): + if i not in (exclude_args or []): + dumped_inputs[f"arg_{i}"] = arg + + # Only persist dtype and shape for kvcache tensors + # (can be way to big otherwise) + if (kv_caches := dumped_inputs.get("kv_caches")) \ + and isinstance(kv_caches, Iterable): + dumped_inputs["kv_caches"] = [(t.dtype, t.shape) + for t in kv_caches + if is_tensor(t)] + + try: + pickle.dump(dumped_inputs, filep) + except Exception as pickle_err: + logger.warning( + "Failed to pickle inputs of failed execution: %s", + str(pickle_err)) + raise type(err)(f"Error in model execution: " + f"{str(err)}") from err + + logger.info( + "Completed writing input of failed execution to %s.", + filename) + raise type(err)( + f"Error in model execution (input dumped to {filename}): " + f"{str(err)}") from err + + return _wrapper + + return _inner + + +class BroadcastableModelInput(ABC): + + @abstractmethod + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + """ + Extract broadcastable fields. Override for fields that require some + custom deserialization. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def from_broadcasted_tensor_dict( + cls: Type[T], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> T: + """ + Pop fields from the given tensor_dict and populate a new instance of + BroadcastableModelInput. + """ + raise NotImplementedError + + +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + +class ModelRunnerInputBuilderBase(ABC, Generic[T]): + """A builder to create ModelRunnerInputBase objects. 
+ """ + + @abstractmethod + def add_seq_group(self, seq_group_metadata): + """TBA""" + raise NotImplementedError + + @abstractmethod + def build(self, *args, **kwargs) -> T: + """Build metadata with on-device tensors.""" + raise NotImplementedError + + +class ModelRunnerBase(ABC, Generic[T]): + """ + Model runner interface that abstracts a particular hardware and/or type of + model. Model execution may communicate data with model runners in other + processes, but it should not include control plane metadata communication. + + Each ModelRunnerBase subclass should define a corresponding + ModelRunnerInputBase subclass. + """ + + # Map of request_id -> generator used for seeded random sampling + generators: Dict[str, torch.Generator] = {} + + @abstractmethod + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, Any], + ) -> T: + """ + Make an instance of a ModelRunnerInputBase from the broadcasted tensor + dict. + """ + raise NotImplementedError + + @abstractmethod + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> T: + """ + Prepare the inputs to ModelRunnerBase.execute_model from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @current_platform.inference_mode() + def execute_model( + self, + model_input: T, + kv_caches: Optional[List[torch.Tensor]], + intermediate_tensors: Optional[IntermediateTensors], + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + """ + Execute the model on the given input. + """ + raise NotImplementedError + + def get_generators(self, finished_request_ids: Optional[List[str]] = None): + """ + Return dict of per-request generators used for random sampling. 
+ """ + + # Clean up generators from completed requests + if finished_request_ids: + for request_id in finished_request_ids: + self.generators.pop(request_id, None) + + return self.generators diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 0000000..0cd0047 --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,901 @@ +import dataclasses +import functools +from dataclasses import dataclass, field +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) + +import torch + +from vllm.distributed import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, + SamplerOutput, + SamplingMetadata, get_logprobs, + get_pythonized_sample_results) +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.utils import PyObjectCache, async_tensor_h2d +from vllm.worker.model_runner import (GPUModelRunnerBase, + ModelInputForGPUWithSamplingMetadata) +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_attn_metadata_from_tensor_dict, + _init_frozen_model_input_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "rocm-flash-attn", "flashinfer"] +MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["flash-attn"] + +def _get_supported_attention_backends(chunked_prefill_enabled: bool) \ + -> List[str]: + if chunked_prefill_enabled: + return MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS + else: + return MULTI_STEP_ATTENTION_BACKENDS + + +def seq_output_builder(): + return SequenceOutput( + 0, 0, + {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)}) + + +def completion_seq_group_output_builder(): + return CompletionSequenceGroupOutput([], None) + + +# Used by pythonization to reduce python object allocations +class PythonizationCache: + + def __init__(self): + self.cached_seq_output = PyObjectCache(seq_output_builder) + self.cached_completion_seq_group_output = PyObjectCache( + completion_seq_group_output_builder) + + def reset(self): + self.cached_seq_output.reset() + self.cached_completion_seq_group_output.reset() + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be + ready. + """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + # On-device tensor containing the logprobs of each token. 
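+    # (The GPU-side tensor is consumed and then cleared during pythonization;
+    # see _pythonize_sampler_output below.)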
+ logprobs: Optional["torch.Tensor"] = None + pythonization_cache: Optional[PythonizationCache] = None + + def pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output(input_metadata, copy_stream, + pinned_sampled_token_buffer, True) + self.pythonized = True + + def maybe_pythonize(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output( + input_metadata, copy_stream, pinned_sampled_token_buffer, + False) + + def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor, + blocking: bool) -> bool: + """ + If blocking is set, will block until the forward pass for the output is + ready and pythonize the output. Upon completing Pythonization, erases + self.logprobs (note that a non-blocking call that is performed when + the sampler output is not yet ready, will not erase self.logprobs.) + """ + assert self.sampled_token_ids is not None + if not blocking and not self.sampler_output_ready_event.query(): + return False + + if blocking: + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids, self.logprobs, + self.pythonization_cache) + + # Erase the logprobs GPU-side tensor. + # Note that although _pythonize_sampler_output() runs in its + # own CUDA stream, nonetheless _pythonize_sampler_output() + # cannot return until Pythonization is complete; therefore + # we know that by the time the CPU reaches this point, + # `self.logprobs` is no longer needed. + self.logprobs = None + return True + + +@dataclass(frozen=False) +class StatefulModelInput(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + + # list of model outputs for each step, may not be all pythonized + cached_outputs: List[ModelOutput] = field(default_factory=list) + + # used to pass sampled token ids from the last step to the current step for + # TP workers. 
Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + base_output_proc_callback: Optional[Callable] = None + # ping-pong data structures for multi-step to wait on the previous step + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + num_single_step_prefills: int = 0 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + 'num_single_step_prefills': self.num_single_step_prefills, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "StatefulModelInput": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + # record the event for the current step so that the next step can sync + # on it. We modulo by 2 to keep the events in a circular buffer and + # support any attn backends that may be supported in the future. ie + # Flashinfer would want two DecodeWrappers to overlap the CPU and GPU. + self.step_cuda_events[self.current_step & 1] = \ + torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step & 1].record(current_stream) + + def wait_previous_step(self): + # These cuda events are an explicit synchronization to ensure that + # advance_step() (for other attn backends that may be supported in the + # future) do not clobber any data structures that is also used by any + # enqueued forwards steps. For distributed case, only a single event is + # needed, but for single GPU case, since we can let the CPU run much + # further ahead, two events allow us to overlap the advance_step with + # the previous forward (ie using two DecodeWrappers for flashinfer + # backend) + self.step_cuda_events[(self.current_step + 1) & 1].wait() + + def add_sampler_output(self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None): + self.cached_outputs.append( + ModelOutput(sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, + pythonized=False)) + + def maybe_advance_sampling_metadata(self, device: str, pin_memory: bool): + """ + sampling_metadata.selected_token_indices is constructed for the + first-step in Multi-Step. However, when chunked-prefill is enabled with + multi-step, the scheduled prompts are fully processed in the + first-step and are processed as decodes in the rest of the steps. + This function updates the sampling_metadata.selected_token_indices + to account for this conversion. 
+ + Example: + Let 2 prompts and 2 decodes be scheduled together. Let the + num-tokens to process for the 2 prompts be 5 and 8 respectively. + + In that case, sampling_metadata.sampled_token_indices will be, + [4, 12, 13, 14] as it is constructed for the first-step in + multi-step. + However, the prompts turns to decodes after the first-step + and the num-tokens for the previously-prompt sequences will + be 1 and 1 as they are decodes now. The self.sampled_token_indices + must be updated to [0,1,2,3]. + """ + assert self.current_step == 1 and self.num_single_step_prefills > 0 + if not get_pp_group().is_last_rank: + return + + assert self.frozen_model_input is not None + assert self.frozen_model_input.sampling_metadata is not None + self.frozen_model_input.sampling_metadata.selected_token_indices = \ + async_tensor_h2d(list(range(self.num_queries)), + dtype=torch.long, + target_device=device, + pin_memory=pin_memory) + + def maybe_advance_frozen_model_input(self, device: str, pin_memory: bool): + """ + Advancing the datastructures of StatefulModelInput::frozen_model_input + is only required when prefills are scheduled with decodes to run in + multi-step. This advancement/correction is required to account for + the conversion of Prefills to Decodes after the first multi-step. + """ + if self.current_step != 1 or self.num_single_step_prefills == 0: + return + + assert self.frozen_model_input is not None + fmi = self.frozen_model_input + + # Truncate input_tokens + assert fmi.input_tokens is not None + assert fmi.input_tokens.shape[0] >= self.num_seqs + fmi_new_input_tokens: torch.Tensor = fmi.input_tokens[:self.num_seqs] + + # Update frozen_model_input::input_positons. + assert fmi.input_positions is not None + assert fmi.input_positions.shape[0] >= self.num_seqs + fmi_new_input_positions: torch.Tensor = fmi.input_positions[:self. + num_seqs] + + # Assert unsupported + assert fmi.lora_mapping is None + assert fmi.lora_requests is not None + assert len(fmi.lora_requests) == 0 + assert fmi.attn_metadata is not None + assert fmi.prompt_adapter_mapping is None + assert fmi.prompt_adapter_requests is not None + assert len(fmi.prompt_adapter_requests) == 0 + assert fmi.multi_modal_kwargs is not None + assert len(fmi.multi_modal_kwargs) == 0 + + self.frozen_model_input = dataclasses.replace( + self.frozen_model_input, + input_tokens=fmi_new_input_tokens, + input_positions=fmi_new_input_positions) + + self.maybe_advance_sampling_metadata(device, pin_memory) + + +# MutableModelInputForGPUWithMultiStepMetadata is not subclass of +# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step +# metadata +# mypy: disable-error-code=type-var +class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) + + # Check attention backend support. + supported_attention_backends: List[str] = \ + _get_supported_attention_backends( + self.scheduler_config.chunked_prefill_enabled) + if self.attn_backend.get_name() not in supported_attention_backends: + ms_config_str: str = "Multi-Step + Chunked-Prefill" \ + if self.scheduler_config.chunked_prefill_enabled \ + else "Multi-Step" + raise ValueError( + f"{ms_config_str} not supported for attention backend: " + f"{self.attn_backend.get_name()}. 
Set VLLM_ATTENTION_BACKEND " + f"to a value from {supported_attention_backends}.") + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: GPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + # Using the PythonizationCache in Pipeline-Parallel clobbers the + # SequenceOutput and CompletionSequenceGroupOutput object. + # When cache-reset happens at the last step of a multi-step + # execution, there may be other on-going single-step/multi-step + # executions. The current caching implementation does not check + # for this. + self.pythonization_cache = PythonizationCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + @functools.cached_property + def _copy_stream(self): + # used to copy tensors from GPU to CPU asynchronously + return torch.cuda.Stream() + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: + model_input = (StatefulModelInput.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> StatefulModelInput: + frozen_model_input: ModelInputForGPUWithSamplingMetadata = \ + self._base_model_runner.prepare_model_input( + seq_group_metadata_list, + virtual_engine, + finished_requests_ids) + + assert frozen_model_input.query_lens is not None + assert frozen_model_input.seq_lens is not None + assert frozen_model_input.attn_metadata is not None + num_queries = len(frozen_model_input.query_lens) + num_seqs = len(frozen_model_input.seq_lens) + num_single_step_prefills = frozen_model_input.attn_metadata.num_prefills + + model_input = StatefulModelInput( + frozen_model_input=frozen_model_input, + num_seqs=num_seqs, + num_queries=num_queries, + num_single_step_prefills=num_single_step_prefills) + + return model_input + + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. 
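+ # (Each cached step output can be pythonized only once its CUDA
+ # ready-event reports completion; see ModelOutput.maybe_pythonize.)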
+ # Stop on the first one that fails to pythonize + output_proc_callback() + + cont = True + for step_num, model_output in enumerate(model_input.cached_outputs): + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + ctx = output_proc_callback.keywords["ctx"] + ctx.append_output( + outputs=[model_output.sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=step_num == 0) + + output_proc_callback() + else: + cont = False + + if not cont: + break + + def _final_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]): + assert model_input.frozen_model_input is not None + + has_async_callback = output_proc_callback is not None + + outputs = [] + for step_num, output in enumerate(model_input.cached_outputs): + is_last_step = step_num == len(model_input.cached_outputs) - 1 + + # For non-async case: + # -- We simply add the outputs + # For async case: + # -- Invoke callback, pythonize, add to callback queue and repeat + # -- For last output, just add to callback queue + if has_async_callback: + assert output_proc_callback is not None + + # Invoke callback before pythonize (to overlap with GPU) + output_proc_callback() + + # Pythonize + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + # For non last step, add to callback queue to chain + # callbacks=>pythonize pairs (for GPU overlap) + if not is_last_step: + ctx = output_proc_callback.keywords[ # type: ignore + "ctx"] # type: ignore + ctx.append_output( + outputs=[output.sampler_output], + seq_group_metadata_list=ctx. + seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=step_num == 0) + else: + outputs.append(output.sampler_output) + else: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + + return outputs + + @torch.inference_mode() + def execute_model( + self, + model_input: StatefulModelInput, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the lask rank and only pythonize + # if CPU is ahead. 
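+ # (include_gpu_probs_tensor keeps the sampled ids/probs on the GPU and
+ # skip_sampler_cpu_output defers the GPU->CPU copy, so the transfer can
+ # be overlapped with subsequent forward passes.)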
+ if self.is_driver_worker and get_pp_group().is_last_rank: + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = ( + True) + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = ( + True) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of GPU + + current_stream = torch.cuda.current_stream() + if not model_input.is_first_multi_step: + # Explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use. + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs extra CPU operations on + # input metadata we may need to synchronize any CPU operations that + # might clobber enqueued forwards. (prevents CPU from running too + # far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.cached_outputs[-1].sampler_output) + + # frozen_model_input may have been updated + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + if model_input.base_output_proc_callback is None: + assert frozen_model_input is not None + model_input.base_output_proc_callback = \ + frozen_model_input.async_callback + + if frozen_model_input.async_callback is not None: + assert model_input.base_output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=model_input.base_output_proc_callback) + + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + # Update the local instance + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(current_stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. May be able to be combined with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(current_stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_cpu = output[ + 0].sampled_token_ids.cpu() + model_input.cached_outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False, + output[0].logprobs, self.pythonization_cache)) + + # These GPU tensors are not required by multi-step; + # erase them to ensure they are not pythonized or + # transferred to CPU + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + + # Pythonize the output if CPU is ahead and the previous step is + # ready. 
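+ # (When an async output-processing callback is installed, pythonization
+ # is driven from _async_process_outputs instead, so it is skipped here.)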
+ if frozen_model_input.async_callback is None: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = self._final_process_outputs( + model_input, model_input.base_output_proc_callback) + if self.pythonization_cache: + self.pythonization_cache.reset() + return outputs + + # should be [SamplerOutput] + return output + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step(self, model_input: StatefulModelInput, + out: SamplerOutput) -> StatefulModelInput: + + model_input.maybe_advance_frozen_model_input(self.device, + self.pin_memory) + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.input_tokens is not None + assert frozen_model_input.input_tokens.shape[0] == model_input.num_seqs + assert frozen_model_input.attn_metadata is not None + + sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + attn_metadata = frozen_model_input.attn_metadata + assert attn_metadata is not None + + turn_prefills_into_decodes: bool = model_input.current_step == 1 and \ + model_input.num_single_step_prefills != 0 + attn_metadata.advance_step( + frozen_model_input, + sampled_token_ids, + self.block_size, + num_seqs, + num_queries, + turn_prefills_into_decodes=turn_prefills_into_decodes) + + return model_input + + def load_model(self) -> None: + return self._base_model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], + Optional[List[SampleLogprobs]]] + + +def 
deferred_pythonize_logprobs( + output: SamplerOutput, + sampling_metadata: SamplingMetadata, + logprobs_tensor: Optional[torch.Tensor], +) -> DeferredLogprobsReturnType: + """Perform deferred logprob Pythonization. + + 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. + 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, + utilizing the Pythonized sampler result computed in step 1. + + These deferred computations are not required for single-step scheduling + or the `profile_run()` phase of multi-step scheduling. + + Args: + output: sampler output (under deferred Pythonization) + sampling_metadata + + Returns: + prompt_logprobs (CPU), sample_logprobs (CPU) + """ + + # - Deferred pythonization of sample result + sampler_result = get_pythonized_sample_results( + output.deferred_sample_results_args) + + # - Erase the GPU-side deferred sample_result + # computation args to ensure it is never + # pythonized or transferred to CPU + output.deferred_sample_results_args = None + + # - Deferred pythonization of logprobs + ( + prompt_logprobs, + sample_logprobs, + ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) + assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) + assert len(sample_logprobs) == len(sampling_metadata.seq_groups) + + return prompt_logprobs, sample_logprobs + + +def _pythonize_sampler_output( + model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor, + logprobs_tensor: Optional[torch.Tensor], + cache: Optional[PythonizationCache], +) -> None: + """ This function is only called when the output tensors are ready. + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + adding a Pythonized output data structure + (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. + + Args: + model_input + output: sampler output + pinned_sampled_token_token_buffer: CPU-side pinned memory + (receives copy of + GPU-side token buffer.) + sampled_token_ids: GPU-side token buffer + logprobs_tensor: GPU-side tensor containing + logprobs computed during sampling + """ + + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + sampling_metadata = frozen_model_input.sampling_metadata + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # We guarantee output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However we should check whether logprobs pythonization may + # be skipped entirely, i.e. because no logprobs were requested + # or pythonization was not deferred. To that end, + # + # * `prompt_logprobs_are_requested_for_prefill` signals that + # there are *any* prefill-phase requests which specify that + # prompt logprobs should be returned. + # + # * `any_logprobs_are_requested` signals that there are any + # requests which (1) specify that sample logprobs should be + # returned, or (2) are in the prefill phase AND specify that + # prompt logprobs should be returned. + # + # Later on, these flags cause adjustments to the pythonization + # process to accommodate logprobs. 
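+ # Roughly: logprobs are pythonized here only when the sampler deferred its
+ # CPU output (skip_sampler_cpu_output) AND some request asked for logprobs;
+ # otherwise any already-computed logprobs are reused from output.outputs.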
+ + seq_groups = sampling_metadata.seq_groups + prompt_logprobs_are_requested_for_prefill = any([ + sg.sampling_params.prompt_logprobs is not None and sg.is_prompt + for sg in seq_groups + ]) + any_logprobs_are_requested = ( + prompt_logprobs_are_requested_for_prefill + or any([sg.sampling_params.logprobs is not None for sg in seq_groups])) + + if prompt_logprobs_are_requested_for_prefill: + # CPU GPU sync, after gathering *only* sampled tokens (since + # requesting prompt logprobs leads `sampled_token_ids` to + # include prompt token ids in addition to sampled token ids.) + sample_idx_tensor = torch.tensor( + [sdx for sg in seq_groups for sdx in sg.sample_indices]) + pinned_buffer = pinned_buffer.copy_( + sampled_token_ids[sample_idx_tensor, :], non_blocking=False) + else: + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, + non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + skip_sampler_cpu_output = ( + frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + + # *Don't* skip logprobs pythonization *if*: + # * Any requests require logprobs to be returned in this + # iteration AND + # * These requests are being scheduled in a fashion which + # defers pythonization (i.e. multi-step scheduling.) + do_pythonize_logprobs = (skip_sampler_cpu_output + and any_logprobs_are_requested) + ( + prompt_logprobs, + sample_logprobs, + ) = (deferred_pythonize_logprobs(output, sampling_metadata, + logprobs_tensor) + if do_pythonize_logprobs else (None, None)) + + for sgdx, (seq_group, + sample_result) in enumerate(zip(seq_groups, samples_list)): + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + # (Check for Guided Decoding) + if seq_group.sampling_params.logits_processors: + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + + if do_pythonize_logprobs: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( # Utilize deferred pythonization results + prompt_logprobs[sgdx], + sample_logprobs[sgdx], + ) + elif any_logprobs_are_requested: + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( + # profile_run: use already-computed logprobs + output.outputs[sgdx].prompt_logprobs, + [sample.logprobs for sample in output.outputs[sgdx].samples]) + + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + + if cache is not None: + completion_seq_group_output: CompletionSequenceGroupOutput = \ + cache.cached_completion_seq_group_output.get_object() + completion_seq_group_output.samples.clear() + seq_outputs: List[ + SequenceOutput] = completion_seq_group_output.samples + else: + seq_outputs = [] + + for tdx, (parent_id, + next_token_id) in enumerate(zip(parent_ids, next_token_ids)): + if cache is not None: + seq_output: SequenceOutput = cache.cached_seq_output.get_object( + ) + seq_output.parent_seq_id = seq_ids[parent_id] + seq_output.output_token = next_token_id + + if any_logprobs_are_requested: + seq_output.logprobs = group_sample_logprobs[tdx] + else: + logprobs = next(iter(seq_output.logprobs.values())) + seq_output.logprobs.clear() + + logprobs.logprob = float('inf') + logprobs.rank = None + logprobs.decoded_token = None + + seq_output.logprobs[next_token_id] = logprobs + + seq_outputs.append(seq_output) + + else: + seq_outputs.append( + 
SequenceOutput(seq_ids[parent_id], next_token_id, + (group_sample_logprobs[tdx] + if any_logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + if cache is not None: + completion_seq_group_output.prompt_logprobs = \ + group_prompt_logprobs if any_logprobs_are_requested else None + output.outputs.append(completion_seq_group_output) + else: + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, (group_prompt_logprobs + if any_logprobs_are_requested else None))) + + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_tpu_worker.py b/vllm/worker/multi_step_tpu_worker.py new file mode 100644 index 0000000..e654f71 --- /dev/null +++ b/vllm/worker/multi_step_tpu_worker.py @@ -0,0 +1,105 @@ +import dataclasses +from typing import Dict, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict +from vllm.sequence import ExecuteModelRequest +from vllm.worker.tpu_model_runner import ModelInputForTPU +from vllm.worker.tpu_worker import TPUWorker +from vllm.worker.worker_base import WorkerInput + + +class MultiStepTPUWorker(TPUWorker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.cached_model_input: Optional[ModelInputForTPU] = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]: + assert self.is_driver_worker + assert execute_model_req.virtual_engine == 0 + + is_first_multi_step = execute_model_req.is_first_multi_step + is_last_step = execute_model_req.is_last_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + worker_input = dataclasses.replace( + worker_input, + num_steps=execute_model_req.num_lookahead_slots + 1) + model_input: ModelInputForTPU = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( + model_input, + async_callback=execute_model_req.async_callback) + else: + assert self.cached_model_input is not None + model_input = self.cached_model_input + worker_input = WorkerInput() + model_input = dataclasses.replace( + model_input, + is_first_multi_step=is_first_multi_step, + is_last_step=is_last_step) + + if self.do_metadata_broadcast: + if is_first_multi_step: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update( + model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + else: + broadcast_data = { + "is_first_multi_step": is_first_multi_step, + "is_last_step": is_last_step, + } + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str, + torch.Tensor]]]: + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + broadcast_tensor_dict({}, src=0) + return None + + model_input, worker_input, _ = self._get_driver_input_and_broadcast( + execute_model_req) + if model_input.is_first_multi_step: + self.cached_model_input = model_input + return model_input, 
worker_input, {} + else: + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + if len(broadcast_data) == 2: + assert self.cached_model_input is not None + self.cached_model_input = dataclasses.replace( + self.cached_model_input, + is_first_multi_step=broadcast_data["is_first_multi_step"], + is_last_step=broadcast_data["is_last_step"]) + empty_worker_input = WorkerInput() + return self.cached_model_input, empty_worker_input, {} + + worker_input = WorkerInput.from_broadcasted_tensor_dict( + broadcast_data) + model_input = ( + self.model_runner. + make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + self.cached_model_input = model_input + return model_input, worker_input, {} diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000..bf66f32 --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,202 @@ +import dataclasses +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, + StatefulModelInput) +from vllm.worker.worker import Worker, WorkerInput + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: StatefulModelInput + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + base_model_runner = self.model_runner + # for multi-step model, wrap the model runner with MultiStepModelRunner + self.model_runner = MultiStepModelRunner( + base_model_runner, + base_model_runner.model_config, + base_model_runner.parallel_config, + base_model_runner.scheduler_config, + base_model_runner.device_config, + base_model_runner.cache_config, + load_config=base_model_runner.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=base_model_runner.is_driver_worker, + prompt_adapter_config=base_model_runner.prompt_adapter_config, + observability_config=base_model_runner.observability_config, + ) + + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: + """ + Get the driver input and broadcast it to other workers. 
+ """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + # on first step we prepare the worker input and model input normally + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: StatefulModelInput = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback) + else: + # on subsequent steps we reuse the worker input and model input + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + # clear the cached metadata so that it can be recomputed on + # the workers. + frozen_model_input.attn_metadata._cached_prefill_metadata = None + frozen_model_input.attn_metadata._cached_decode_metadata = None + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + if not is_first_multi_step: + # we broadcast the last sampled token ids to all TP workers so they + # can update their model input metadata in-place. + self._prepare_last_sampled_token_ids_for_tp_workers( + execute_model_req=execute_model_req, model_input=model_input) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + # Retuning empty dict here to keep this compatible with + # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast` + return model_input, worker_input, {} + + def _prepare_last_sampled_token_ids_for_tp_workers( + self, + execute_model_req: ExecuteModelRequest, + model_input: StatefulModelInput, + ) -> None: + """ + Prepare the last sampled token ids for TP workers. If it's the last + PP rank, then the last sampled token ids are already in the model_input. + If it is NOT the last PP rank, then we need to get the last sampled + token that is cached in the execute_model_req. + """ + if get_pp_group().is_last_rank: + assert model_input.cached_outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.cached_outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.cached_outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = ( + execute_model_req.last_sampled_token_ids.cuda()) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. 
+ # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. + for output in model_input.cached_outputs[:-1]: + output.sampled_token_ids = None + assert model_input.cached_outputs[-1].sampled_token_ids is not None + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str, + torch.Tensor]]]: + """ + Depending on the current state of the request and multi step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead used cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + (model_input, worker_input, + kwargs) = self._get_driver_input_and_broadcast(execute_model_req) + assert isinstance(model_input, StatefulModelInput) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input, kwargs = broadcast_data + assert isinstance(model_input, StatefulModelInput) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # TODO(will) Can cache the worker input and model input for the + # next steps. See below for details + else: + # TODO(will) possible to also cache and reuse the cached worker + # input and model input. The idea is essentially the delta + # optimization for model_inputs. 
Where the TP workers can cache + # the model input states and we only broadcast the delta need + # for the next step (sampled_token_ids from the previous step) + + assert isinstance(model_input, StatefulModelInput) + # we need to update the last sampled token ids in the model + # input for the workers so that they can run inplace + # advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input, kwargs diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py new file mode 100644 index 0000000..b8c760c --- /dev/null +++ b/vllm/worker/neuron_model_runner.py @@ -0,0 +1,346 @@ +import os +from dataclasses import dataclass +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers_neuronx.config import GenerationConfig + +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.neuron import get_neuron_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available, make_tensor_with_pad +from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass(frozen=True) +class ModelInputForNeuron(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + input_block_ids: Optional[torch.Tensor] = None + sampling_metadata: Optional["SamplingMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + raise NotImplementedError("ModelInputForNeuron cannot be broadcast.") + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForNeuron": + assert attn_backend is None + return cls.from_broadcasted_tensor_dict(tensor_dict) + + +class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]): + + # NEURON has an upper limit on the top_k + _MAX_NEURON_SAMPLING_TOP_K = 256 + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + + if model_config is not None and model_config.get_sliding_window(): + logger.warning("Sliding window is not supported on Neuron. " + "The model will run without sliding window.") + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + + # Lazy initialization. 
+ self.model: nn.Module # initialize after load_model. + + # Once NEURON_ON_DEVICE_SAMPLING_DISABLED is set to a non-zero value, + # turn off on-device sampling. + self._on_device_sampling_disabled = int( + os.getenv("NEURON_ON_DEVICE_SAMPLING_DISABLED", "0")) + + # NEURON needs to update sampling parameters when request IDs change + # across batches. This variable stores the previous batch's request IDs + # to determine if an update is needed. + self._previous_batch_request_ids: List[str] = [] + + if not self._on_device_sampling_disabled: + logger.warning( + "On-device sampling is turned on in Neuron by default, only " + "top_k, top_p, and temperature are current supported sampling " + "parameters. To turn off the on-device sampling, please set " + "the environment variable NEURON_ON_DEVICE_SAMPLING_DISABLED=1." + ) + self.model_config.neuron_sampling_params = GenerationConfig( + max_length=self.scheduler_config.max_model_len, + do_sample=True, + per_batch_line=True, + top_k=[self._MAX_NEURON_SAMPLING_TOP_K] \ + * self.scheduler_config.max_num_seqs, + top_p=[1.0] * self.scheduler_config.max_num_seqs, + temperature=[1.0] * self.scheduler_config.max_num_seqs, + dynamic=True, + global_top_k=self._MAX_NEURON_SAMPLING_TOP_K) + + def load_model(self) -> None: + if find_spec("transformers_neuronx") is not None: + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + input_block_ids: List[int] = [] + + seq_lens: List[int] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) + + input_tokens.append(prompt_tokens) + input_positions.append(list(range(seq_len))) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == 1 + input_block_ids.append(block_table[0]) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + # Process multi-modal data + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata.mm_processor_kwargs, + ) + multi_modal_inputs_list.append(mm_kwargs) + + max_seq_len = max(seq_lens) + assert max_seq_len > 0 + input_tokens = make_tensor_with_pad(input_tokens, + pad=0, + max_len=max_seq_len, + dtype=torch.long, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + pad=0, + max_len=max_seq_len, + dtype=torch.long, + device=self.device) + input_block_ids = torch.tensor(input_block_ids, + dtype=torch.long, + device=self.device) + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return (input_tokens, input_positions, input_block_ids, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, 
torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + input_block_ids: List[int] = [] + context_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == 1 + input_block_ids.append(block_table[0]) + + input_tokens = make_tensor_with_pad(input_tokens, + pad=0, + max_len=1, + dtype=torch.long, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + pad=0, + max_len=1, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + input_block_ids = torch.tensor(input_block_ids, + dtype=torch.long, + device=self.device) + + return input_tokens, input_positions, input_block_ids + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForNeuron: + return ModelInputForNeuron.from_broadcasted_tensor_dict(tensor_dict) + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForNeuron: + multi_modal_kwargs = None + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_block_ids, seq_lens, + multi_modal_kwargs + ) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + input_block_ids) = self._prepare_decode(seq_group_metadata_list) + seq_lens = None + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since neuron worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + self.pin_memory, + generators=self.get_generators(finished_requests_ids)) + + if not self._on_device_sampling_disabled: + # Once the request IDs are changed in current iteration, we will + # update the on-device sampling parameters. 
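+ # (The top_k/top_p/temperature lists in neuron_sampling_params are
+ # positional per batch line, so a change in batch composition invalidates
+ # the previously pushed values.)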
+ current_batch_request_ids = [ + seq_group_meta_data.request_id + for seq_group_meta_data in seq_group_metadata_list + ] + if current_batch_request_ids != self._previous_batch_request_ids: + self._update_neuron_sampling_params(sampling_metadata) + self._previous_batch_request_ids = current_batch_request_ids + + return ModelInputForNeuron(input_tokens=input_tokens, + input_positions=input_positions, + input_block_ids=input_block_ids, + sampling_metadata=sampling_metadata, + multi_modal_kwargs=multi_modal_kwargs) + + def _update_neuron_sampling_params(self, + sampling_metadata: SamplingMetadata): + # Update Neuron sampling parameters (GenerationConfig in Neuron) + current_sampling_params = self.model_config.neuron_sampling_params + assert current_sampling_params is not None, ( + f"Failed to update sampling_params, " + f"current sampling params is {current_sampling_params}") + + top_k = current_sampling_params.top_k + top_p = current_sampling_params.top_p + temperature = current_sampling_params.temperature + for index, sequence_group_to_sample in enumerate( + sampling_metadata.seq_groups): + top_k[index] = self._convert_to_neuron_top_k( + sequence_group_to_sample.sampling_params.top_k) + top_p[index] = sequence_group_to_sample.sampling_params.top_p + temperature[index] = \ + sequence_group_to_sample.sampling_params.temperature + + self.model.model.update_generation_config(current_sampling_params) + + def _convert_to_neuron_top_k(self, top_k: int) -> int: + if top_k < 0 or top_k > self._MAX_NEURON_SAMPLING_TOP_K: + return self._MAX_NEURON_SAMPLING_TOP_K + return top_k + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForNeuron, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "NeuronModelRunner does not support multi-step execution.") + + hidden_states = self.model( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + input_block_ids=model_input.input_block_ids, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device), + ) + + # Compute the logits only if the on-device sampling is turned off as + # on-device sampling outputs the token ids. + if self._on_device_sampling_disabled: + logits = self.model.compute_logits(hidden_states, + model_input.sampling_metadata) + else: + logits = hidden_states + + # Sample the next token. 
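+ # (With on-device sampling enabled, `logits` already carries the token ids
+ # chosen on the Neuron device; presumably sample() mostly repackages them
+ # into a SamplerOutput.)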
+ output = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + return [output] + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py new file mode 100644 index 0000000..fff14d6 --- /dev/null +++ b/vllm/worker/neuron_worker.py @@ -0,0 +1,127 @@ +"""A Neuron worker class.""" +from typing import List, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.worker.neuron_model_runner import NeuronModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase, WorkerInput) + + +class NeuronWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): + """A worker class that executes the model on a group of neuron cores. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner: NeuronModelRunner = NeuronModelRunner( + model_config, parallel_config, scheduler_config, device_config) + self.is_driver_worker = True + + def init_device(self) -> None: + self.init_distributed_environment() + + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. + num_gpu_blocks = self.scheduler_config.max_num_seqs + + # Swap not yet supported with Neuron backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. 
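+ # (determine_num_available_blocks above pins num_gpu_blocks to
+ # max_num_seqs and num_cpu_blocks to 0, so anything else here indicates a
+ # misconfigured cache.)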
+ assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + @property + def do_metadata_broadcast(self) -> bool: + return False + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return None + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + return WorkerInput(num_seq_groups=len( + execute_model_req.seq_group_metadata_list), ) + + def execute_worker(self, worker_input: WorkerInput) -> None: + pass + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError + + def init_distributed_environment(self): + """Neuron uses transformers-neuronx for tensor parallelism. + + vLLM still needs the environment inited when TP/PP > 1 + """ + init_distributed_environment( + world_size=1, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + 1, + 1, + ) diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py new file mode 100644 index 0000000..760b184 --- /dev/null +++ b/vllm/worker/openvino_model_runner.py @@ -0,0 +1,353 @@ +from typing import List, NamedTuple, Optional, Tuple + +import openvino as ov +import torch +from torch import nn + +from vllm.attention import get_attn_backend +from vllm.attention.backends.openvino import OpenVINOAttentionMetadata +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.openvino import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs) +from vllm.sequence import SequenceGroupMetadata + +logger = init_logger(__name__) + + +class ModelInput(NamedTuple): + input_tokens: torch.Tensor + input_positions: torch.Tensor + attn_metadata: Optional[OpenVINOAttentionMetadata] + seq_lens: List[int] + query_lens: List[int] + multi_modal_kwargs: BatchedTensorInputs + + @classmethod + def empty(cls, device): + return ModelInput(input_tokens=torch.empty(0, device=device), + input_positions=torch.empty(0, device=device), + attn_metadata=None, + seq_lens=[], + query_lens=[], + multi_modal_kwargs={}) + + +class OpenVINOModelRunner: + + def __init__( + self, + ov_core: ov.Core, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + multimodal_config: Optional[MultiModalConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + *args, + **kwargs, + ): + self.ov_core = ov_core + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.multimodal_config = multimodal_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.device = 
self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Multi-modal data support + self.multi_modal_input_mapper = MULTIMODAL_REGISTRY \ + .create_input_mapper(self.model_config) + + # Lazy initialization. + self.model: nn.Module # Set after init_Model + + def load_model(self) -> None: + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + kv_cache_dtype=self.kv_cache_dtype, + ov_core=self.ov_core) + + def _prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> ModelInput: + """Prepare the model input based on a given sequence group. + + The API assumes seq_group_metadata_list is sorted by prefill -> decode. + + The result tensors and data structure also batches input in prefill + -> decode order. For example, + + - input_tokens[:num_prefill_tokens] contains prefill tokens. + - input_tokens[num_prefill_tokens:] contains decode tokens. + """ + input_tokens: List[int] = [] + input_positions: List[int] = [] + + seq_lens: List[int] = [] + past_lens: List[int] = [] + query_lens: List[int] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + + subsequence_begins: List[int] = [] + block_indices: List[int] = [] + block_indices_begins: List[int] = [] + + # initialize beginning of prefix sums + subsequence_begins.append(0) + block_indices_begins.append(0) + + if len(seq_group_metadata_list) == 0: + return ModelInput.empty(self.device) + + for seq_group_metadata in seq_group_metadata_list: + seq_ids = list(seq_group_metadata.seq_data.keys()) + is_prompt = seq_group_metadata.is_prompt + + for seq_id in seq_ids: + computed_block_nums = seq_group_metadata.computed_block_nums + if (self.scheduler_config is not None + and self.scheduler_config.chunked_prefill_enabled + and not (computed_block_nums is None + or computed_block_nums == [])): + raise RuntimeError( + "chunked prefill cannot be used with prefix caching " + "now.") + + seq_data = seq_group_metadata.seq_data[seq_id] + if is_prompt: + computed_len = seq_data.get_num_computed_tokens() + else: + # get_num_computed_tokens is incorrect for spec decoding. + # So, we should have a special logic here. + # TODO(sang): Fix it. + computed_len = seq_data.get_len() - 1 + + seq_len = min( + seq_data.get_len(), + computed_len + seq_group_metadata.token_chunk_size, + ) + if is_prompt: + tokens = seq_data.get_token_ids()[computed_len:seq_len] + else: + # Optimization. get_token_ids requires the entire copy of + # tokens. + tokens = [seq_data.get_last_token_id()] + + # Prefix cache was hit. + # Prefix is not supported with sliding_window + prefix_cache_hit = (computed_block_nums is not None + and len(computed_block_nums) > 0 + and self.sliding_window is None + and is_prompt) + + mm_data = seq_group_metadata.multi_modal_data + if mm_data: + mm_kwargs = self.multi_modal_input_mapper( + mm_data, + mm_processor_kwargs=seq_group_metadata. + mm_processor_kwargs, + ) + multi_modal_inputs_list.append(mm_kwargs) + + block_table = seq_group_metadata.block_tables[seq_id] + # TODO(sang): Combine chunked prefill and prefix caching by + # only allowing multiple of block_size chunk size. + # NOTE: This only works for oooooooxxx style attention. 
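
# Illustrative sketch (not applied by this diff): _prepare_model_input()
# flattens every sequence into one token stream and keeps two prefix-sum
# arrays so the OpenVINO attention op can find each sub-sequence again:
# subsequence_begins accumulates query lengths and block_indices_begins
# accumulates per-sequence block-table lengths, as filled in over the loop
# below. The numbers here are invented; only the accumulation pattern is
# taken from the surrounding code.

query_lens = [5, 3]             # e.g. a 5-token prefill and a 3-token chunk
block_tables = [[7, 9], [4]]    # hypothetical physical block ids per sequence

subsequence_begins = [0]
block_indices = []
block_indices_begins = [0]

for q_len, block_table in zip(query_lens, block_tables):
    subsequence_begins.append(subsequence_begins[-1] + q_len)
    block_indices.extend(block_table)
    block_indices_begins.append(block_indices_begins[-1] + len(block_table))

assert subsequence_begins == [0, 5, 8]
assert block_indices == [7, 9, 4]
assert block_indices_begins == [0, 2, 3]
# (end of sketch)
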
+ if prefix_cache_hit: + assert computed_block_nums is not None + computed_len = len(computed_block_nums) * self.block_size + tokens = tokens[computed_len:] + elif (self.scheduler_config.chunked_prefill_enabled + or not is_prompt): + if seq_group_metadata.block_tables is not None: + # chunked prefill or decode + block_table = seq_group_metadata.block_tables[seq_id] + if self.sliding_window is not None: + # chunked prefill doesn't support sliding window. + assert not self.scheduler_config.chunked_prefill_enabled # noqa: E501 + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + else: + # Only happens when memory profiling runs. + block_table = [] + else: + # prompt phase w/o prefix_caching, chunked_prefill + pass + + block_indices.extend(block_table) + block_indices_begins.append(block_indices_begins[-1] + + len(block_table)) + + # TODO(sang): This is a hack to make sliding window work with + # paged attn. We can remove it if we make paged attn kernel + # to properly handle slinding window attn. + if self.sliding_window is not None and not is_prompt: + seq_len = min(seq_len, self.sliding_window) + computed_len = seq_len - 1 + + seq_lens.append(seq_len) + + query_len = seq_len - computed_len + query_lens.append(query_len) + + input_tokens.extend(tokens) + input_positions.extend(list(range(computed_len, seq_len))) + + past_lens.append(computed_len) + subsequence_begins.append(subsequence_begins[-1] + query_len) + + if is_prompt: + assert len(seq_ids) == 1 + else: + assert ( + query_len == 1 + ), "seq_len: {}, computed_len: {}, query_len: {}".format( + seq_len, computed_len, query_len) + + max_query_len = max(query_lens) + assert max_query_len > 0, "query_lens: {}".format(query_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + + past_lens_tensor = torch.tensor(past_lens, + dtype=torch.int32, + device=self.device) # type: ignore + subsequence_begins_tensor = torch.tensor( + subsequence_begins, dtype=torch.int32, + device=self.device) # type: ignore + block_indices_tensor = torch.tensor(block_indices, + dtype=torch.int32, + device=self.device) # type: ignore + block_indices_begins_tensor = torch.tensor( + block_indices_begins, dtype=torch.int32, + device=self.device) # type: ignore + + max_context_len = max(seq_lens) + max_context_len_tensor = torch.tensor( + max_context_len, dtype=torch.int32, + device=self.device) # type: ignore + + attn_metadata = self.attn_backend.make_openvino_metadata( + past_lens=past_lens_tensor, + subsequence_begins=subsequence_begins_tensor, + block_indices=block_indices_tensor, + block_indices_begins=block_indices_begins_tensor, + max_context_len=max_context_len_tensor, + ) + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return ModelInput( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + multi_modal_kwargs=multi_modal_kwargs, + ) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, OpenVINOAttentionMetadata, + SamplingMetadata, BatchedTensorInputs]: + # Prepare input tensors. 
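
# Illustrative sketch (not applied by this diff): on a prefix-cache hit the
# runner above treats everything covered by the cached blocks as already
# computed. computed_len is rederived from the number of cached blocks, the
# prompt tokens are trimmed accordingly, and only the remaining
# query_len = seq_len - computed_len tokens run through the model. All
# numbers below are invented; block_size=16 is an assumption, not a
# requirement.

block_size = 16
computed_block_nums = [0, 1]          # two fully cached blocks
prompt_token_ids = list(range(40))    # a 40-token prompt

seq_len = len(prompt_token_ids)
computed_len = len(computed_block_nums) * block_size   # 32 tokens cached
tokens_to_run = prompt_token_ids[computed_len:]        # only the last 8 tokens
query_len = seq_len - computed_len

assert computed_len == 32
assert len(tokens_to_run) == query_len == 8
assert list(range(computed_len, seq_len)) == list(range(32, 40))  # positions fed in
# (end of sketch)
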
+ ( + input_tokens, + input_positions, + attn_metadata, + seq_lens, + query_lens, + multi_modal_kwargs, + ) = self._prepare_model_input(seq_group_metadata_list) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + query_lens, + self.device, + pin_memory=False, + ) + + return ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + multi_modal_kwargs, + ) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple["ov.Tensor", "ov.Tensor"]], + ) -> Optional[SamplerOutput]: + ( + input_tokens, + input_positions, + attn_metadata, + sampling_metadata, + multi_modal_kwargs, + ) = self.prepare_input_tensors(seq_group_metadata_list) + + model_executable = self.model + execute_model_kwargs = { + "input_ids": + input_tokens, + "positions": + input_positions, + "kv_caches": + kv_caches, + "attn_metadata": + attn_metadata, + **MultiModalInputs.as_kwargs(multi_modal_kwargs or {}, + device=self.device), + } + + hidden_states = model_executable(**execute_model_kwargs) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + return output diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py new file mode 100644 index 0000000..24425fe --- /dev/null +++ b/vllm/worker/openvino_worker.py @@ -0,0 +1,611 @@ +"""An OpenVINO worker class.""" +from typing import Any, Dict, List, Optional, Tuple + +import openvino as ov +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, MultiModalConfig, ParallelConfig, + SchedulerConfig) +from vllm.distributed import (broadcast_tensor_dict, + ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.executor.openvino_executor import is_openvino_cpu +from vllm.inputs import INPUT_REGISTRY +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sampling_params import SamplingParams +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata +from vllm.worker.openvino_model_runner import OpenVINOModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + +logger = init_logger(__name__) + + +class OpenVINOCacheEngine: + """Manages the KV cache for OpenVINO backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig, + ov_core: ov.Core, + ov_device: str, + ) -> None: + assert device_config.device_type == "openvino" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + if device_config.device.type == "cpu" and \ + cache_config.cache_dtype == ov.Type.u8: + # Scale, zero point and quantized data will be stored together. 
+ # The layout for per token per head: + # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + # so, we have to extend head_size by 8, which is sizeof(float) + # for scale and sizeof(float) for zeropoint + self.head_size += 8 + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for OpenVINO backend with a CPU target device, because we want + # to reuse KV cache management in the scheduler. + self.num_device_blocks = cache_config.num_gpu_blocks + self.num_swap_blocks = cache_config.num_cpu_blocks + + # Get attention backend. + self.attn_backend = get_attn_backend( + self.head_size, + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Initialize the cache. + self.kv_cache: List[Tuple[ov.Tensor, + ov.Tensor]] = self._allocate_kv_cache( + self.num_device_blocks, ov_core, + ov_device) + + # Initialize the swap. + self.swap_cache: List[Tuple[ov.Tensor, + ov.Tensor]] = self._allocate_swap_cache( + self.num_swap_blocks, ov_device) + + def _allocate_kv_cache( + self, + num_blocks: int, + ov_core: ov.Core, + ov_device: str, + ) -> List[Tuple[ov.Tensor, ov.Tensor]]: + """Allocates KV cache.""" + k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:] + kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = [] + + if is_openvino_cpu(): + for _ in range(self.num_layers): + key_blocks = ov.Tensor(self.cache_config.cache_dtype, + k_block_shape) + value_blocks = ov.Tensor(self.cache_config.cache_dtype, + v_block_shape) + kv_cache.append((key_blocks, value_blocks)) + else: + # Update key_cache shape: + k_block_shape = (v_block_shape[0], v_block_shape[1], + v_block_shape[3], v_block_shape[2]) + + remote_context = ov_core.get_default_context(ov_device) + + for _ in range(self.num_layers): + key_blocks = \ + remote_context.create_tensor(self.cache_config.cache_dtype, + ov.Shape(k_block_shape), + {}) + + value_blocks = \ + remote_context.create_tensor(self.cache_config.cache_dtype, + ov.Shape(v_block_shape), + {}) + + kv_cache.append((key_blocks, value_blocks)) + + return kv_cache + + def _allocate_swap_cache( + self, + num_blocks: int, + ov_device: str, + ) -> List[Tuple[ov.Tensor, ov.Tensor]]: + """Allocates swap cache.""" + k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:] + swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = [] + + if num_blocks == 0: + return swap_cache + + assert not is_openvino_cpu(), \ + "CPU device isn't supposed to have swap cache" + + # Update key_cache shape: + k_block_shape = (v_block_shape[0], v_block_shape[1], v_block_shape[3], + v_block_shape[2]) + + for _ in range(self.num_layers): + key_blocks = ov.Tensor(self.cache_config.cache_dtype, + k_block_shape) + value_blocks = ov.Tensor(self.cache_config.cache_dtype, + v_block_shape) + swap_cache.append((key_blocks, value_blocks)) + + return swap_cache + + def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None: + for i in range(self.num_layers): + for swap_tensor, kv_tensor in zip(self.swap_cache[i], + self.kv_cache[i]): + self.attn_backend.swap_blocks(swap_tensor, kv_tensor, + 
src_to_dst) + + def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None: + for i in range(self.num_layers): + for swap_tensor, kv_tensor in zip(self.swap_cache[i], + self.kv_cache[i]): + self.attn_backend.swap_blocks(kv_tensor, swap_tensor, + src_to_dst) + + def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None: + if (len(src_to_dsts) > 0): + self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: ov.Type, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + if cache_dtype == ov.Type.u8: + # Scale, zero point and quantized data will be stored together. + # The layout for per token per head: + # |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)| # noqa: E501 + # so, we have to extend head_size by 8, which is sizeof(float) + # for scale and sizeof(float) for zeropoint + head_size += 8 + + key_cache_block = block_size * num_kv_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = cache_dtype.size + return dtype_size * total + + +class OpenVINOWorker(LoraNotSupportedWorkerBase): + """A worker class that executes the model on OpenVINO backend. + + Each worker is associated with a single OpenVINO device. The worker is + responsible for maintaining the KV cache and executing the model on the + OpenVINO backend. + """ + + def __init__( + self, + ov_core: ov.Core, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + multimodal_config: Optional[MultiModalConfig] = None, + kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, + is_driver_worker: bool = False, + ) -> None: + self.ov_core = ov_core + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.multimodal_config = multimodal_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + + init_cached_hf_modules() + self.model_runner = OpenVINOModelRunner( + self.ov_core, + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + multimodal_config=self.multimodal_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: OpenVINOCacheEngine + self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] + + def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. 
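
# Illustrative sketch (not applied by this diff): get_cache_block_size()
# above prices one KV block as
#   dtype_size * num_layers * 2 (key + value) * block_size * num_kv_heads * head_size,
# and for the u8 cache each (token, head) slot grows by 8 bytes, because a
# float32 scale and a float32 zero point sit in front of the quantized
# values. The model dimensions below are invented only to make the
# arithmetic concrete.

def cache_block_size_sketch(block_size: int, num_kv_heads: int,
                            head_size: int, num_layers: int,
                            dtype_size: int, is_u8: bool) -> int:
    if is_u8:
        head_size += 8  # |scale f32|zero point f32| + head_size u8 values
    per_layer = 2 * block_size * num_kv_heads * head_size  # key + value
    return dtype_size * num_layers * per_layer


# fp16 cache: 2 bytes per element
assert cache_block_size_sketch(16, 8, 128, 32, 2, is_u8=False) == 2_097_152
# u8 cache: 1 byte per element, plus 8 extra bytes per (token, head)
assert cache_block_size_sketch(16, 8, 128, 32, 1, is_u8=True) == 1_114_112
# (end of sketch)
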
+ set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured + KV cache space. + """ + # For OpenVINO backend, in case of CPU device, the block number will be + # calculated based on the openvino_kvcache_space_bytes. + cache_block_size = self.get_cache_block_size_bytes() + kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes + + if is_openvino_cpu(): + num_device_blocks = int(kvcache_space_bytes // cache_block_size) + num_swap_blocks = 0 + else: + if kvcache_space_bytes > 0: + logger.info("KV_CACHE size was explicitly configured via " + "VLLM_OPENVINO_KVCACHE_SPACE environment " + "variable, ignoring profiling run.") + kv_cache_size = kvcache_space_bytes + else: + try: + kv_cache_size = self.profile_run() + except Exception as err: + raise RuntimeError( + "The error occurred during profile run. This might be " + "due to insufficient GPU memory. Consider decreasing " + "`max_model_len` to limit the maximum simultaneously " + "processed tokens.") from err + + num_device_blocks = int(kv_cache_size // cache_block_size) + num_swap_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + + return num_device_blocks, num_swap_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Swappable CPU memory is only + supported on GPU. + + For CPU, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + + num_device_blocks = num_gpu_blocks + num_swap_blocks = num_cpu_blocks + + if is_openvino_cpu(): + assert (num_swap_blocks == 0 + ), f"{type(self)} does not support swappable cache for CPU" + + self._validate_num_blocks(num_device_blocks) + self.cache_config.num_gpu_blocks = num_device_blocks + self.cache_config.num_cpu_blocks = num_swap_blocks + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_blocks(self, num_blocks: int) -> None: + """Raise errors if the num_blocks is invalid.""" + if num_blocks <= 0: + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). 
Try increasing " + "`VLLM_OPENVINO_KVCACHE_SPACE` or decreasing `max_model_len` " + "when initializing the engine.") + + def _init_cache_engine(self) -> None: + ov_device = envs.VLLM_OPENVINO_DEVICE + self.cache_engine = OpenVINOCacheEngine( + self.cache_config, + self.model_config, + self.parallel_config, + self.device_config, + self.ov_core, + ov_device, + ) + self.kv_cache = self.cache_engine.kv_cache + self.model_runner.block_size = self.cache_engine.block_size + + assert self.kv_cache is not None + + # Populate the cache to warmup the memory + if is_openvino_cpu(): + for key_cache, value_cache in self.kv_cache: + key_cache.data[:] = 0 + value_cache.data[:] = 0 + + def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None: + self.cache_engine.swap_in(src_to_dst) + + def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None: + self.cache_engine.swap_out(src_to_dst) + + def cache_copy( + self, + blocks_to_copy: List[Tuple[int, int]], + ) -> None: + self.cache_engine.copy(blocks_to_copy) # type: ignore + + @torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> List[SamplerOutput]: + if execute_model_req is None: + seq_group_metadata_list = None + else: + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups: int = len(seq_group_metadata_list) + assert execute_model_req is not None + blocks_to_copy = execute_model_req.blocks_to_copy + blocks_to_swap_in = execute_model_req.blocks_to_swap_in + blocks_to_swap_out = execute_model_req.blocks_to_swap_out + data: Dict[str, Any] = { + "num_seq_groups": num_seq_groups, + "blocks_to_copy": execute_model_req.blocks_to_copy, + "blocks_to_swap_in": execute_model_req.blocks_to_swap_in, + "blocks_to_swap_out": execute_model_req.blocks_to_swap_out, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_copy = data["blocks_to_copy"] + blocks_to_swap_in = data["blocks_to_swap_in"] + blocks_to_swap_out = data["blocks_to_swap_out"] + + if is_openvino_cpu(): + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + else: + self.cache_swap_in(blocks_to_swap_in) + self.cache_swap_out(blocks_to_swap_out) + + self.cache_copy(blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.kv_cache) + + # OpenVINO worker only supports single-step execution. + return [output] + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + backend="gloo", + ) + + # A small all_reduce for warmup. 
+ torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + ) + + def get_cache_block_size_bytes(self) -> int: + """Return the size in bytes of a single KV cache block.""" + return OpenVINOCacheEngine.get_cache_block_size( + self.cache_config.block_size, + self.cache_config.cache_dtype, + self.model_config, + self.parallel_config, + ) + + def profile_run(self) -> int: + ov_device = envs.VLLM_OPENVINO_DEVICE + + assert not is_openvino_cpu(), \ + "CPU device isn't supposed to use profile run." + + import openvino.properties.device as device + import openvino.properties.intel_gpu as intel_gpu + + ov_core = self.ov_core + cache_config = self.cache_config + model_config = self.model_config + parallel_config = self.parallel_config + device_config = self.device_config + input_registry = INPUT_REGISTRY + mm_registry = MULTIMODAL_REGISTRY + mm_registry.init_mm_limits_per_prompt(model_config) + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + def model_profile_run(): + top_k = model_config.get_vocab_size() - 1 + sampling_params = SamplingParams(top_p=0.99, top_k=top_k) + + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + tmp_cache_config = CacheConfig(cache_config.block_size, + cache_config.gpu_memory_utilization, + cache_config.swap_space_bytes, + "auto") + tmp_cache_config.num_gpu_blocks = 1 + tmp_cache_config.num_cpu_blocks = 0 + tmp_cache_config.cache_dtype = cache_config.cache_dtype + + profiling_cache_engine = OpenVINOCacheEngine( + tmp_cache_config, model_config, parallel_config, device_config, + ov_core, ov_device) + + # Profile memory usage with max_num_sequences sequences and the + # total # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + block_size = cache_config.block_size + seq_num_blocks = (seq_len + block_size - 1) // block_size + + seq_data, dummy_multi_modal_data = input_registry \ + .dummy_data_for_profiling(model_config, + seq_len, + mm_registry) + + block_tables = [[0] * seq_num_blocks] * max_num_seqs + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=None, + multi_modal_data=dummy_multi_modal_data) + seqs.append(seq) + + self.model_runner.block_size = tmp_cache_config.block_size + + # Run the model with the dummy inputs. + self.model_runner.execute_model(seqs, + profiling_cache_engine.kv_cache) + + # explicitly delete temporary KV cache manager to free KV cache + # when real inputs will be passed to OV + del profiling_cache_engine + + logger.info( + "Start profiling run with dummy inputs to evaluate " + "memory usage for %s. It might take a while.", ov_device) + + model_profile_run() + + gpu_device_type = ov_core.get_property(ov_device, device.type) + memory_statistics = \ + ov_core.get_property(ov_device, intel_gpu.memory_statistics) + memory_utilization = cache_config.gpu_memory_utilization + + if gpu_device_type == device.Type.INTEGRATED and \ + memory_utilization >= 0.9: + logger.warning( + "iGPU is used with high gpu_memory_utilization=%f " + "value. 
This may cause low performance due to " + "occupying the majority of available system " + "memory. Please consider decreasing " + "gpu_memory_utilization or explicitly setting" + "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment " + "variable.", memory_utilization) + + # sum up all used device memory + device_memory_types = ["cl_mem", "usm_device"] + used_device_mem = \ + sum(memory_statistics.get(key, 0) for key in device_memory_types) + + if gpu_device_type == device.Type.INTEGRATED: + used_device_mem += memory_statistics.get("usm_host", 0) + + # there could be unaccounted extra memory reserved by kernels, kept + # in memory pools, etc + # therefore, add a threshold to account for this + used_memory_threshold = 1.1 + used_device_mem *= used_memory_threshold + + total_device_memory = \ + ov_core.get_property(ov_device, intel_gpu.device_total_mem_size) + + def format_memory_size(size) -> str: + units = ["B", "KB", "MB", "GB"] + unit_index = 0 + + while size > 1024 and unit_index < len(units) - 1: + size /= 1024 + unit_index += 1 + + return f"{size:.2f} {units[unit_index]}" + + total_device_memory_str = \ + format(format_memory_size(total_device_memory)) + used_device_memory_str = \ + format(format_memory_size(used_device_mem)) + + logger.info( + "Total %s memory: %s. " + "Amount of memory required to run the model with " + "max_num_batched_tokens=%d: %s.", ov_device, + total_device_memory_str, + self.scheduler_config.max_num_batched_tokens, + used_device_memory_str) + + if used_device_mem >= total_device_memory: + raise RuntimeError( + f"The required memory size {used_device_memory_str} for model " + "is higher than the total available device " + "memory {total_device_memory_str}. Please consider to " + "decrease `max_num_batched_tokens` or increase " + "`gpu_memory_utilization`") + + return total_device_memory * memory_utilization - used_device_mem diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py new file mode 100644 index 0000000..c13e95f --- /dev/null +++ b/vllm/worker/tpu_model_runner.py @@ -0,0 +1,842 @@ +import time +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) +from unittest.mock import patch + +import numpy as np +import torch +import torch.nn as nn +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, + _add_attn_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +# Here we utilize the behavior that out-of-bound index is ignored. +# FIXME(woosuk): Find a more reliable way to prevent possible bugs. +_PAD_SLOT_ID = 1_000_000_000 +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. 
+_ENABLE_TOP_P = False +# FIXME(woosuk): A temporary hack to support `n > 1`. +# This can significantly affect the performance if too large. +_MAX_NUM_SAMPLES = 128 + + +@dataclass(frozen=True) +class ModelInputForTPU(ModelRunnerInputBase): + token_ids: torch.Tensor + position_ids: torch.Tensor + attn_metadata: AttentionMetadata + input_lens: torch.Tensor + t: torch.Tensor + p: torch.Tensor + num_samples: int + n: List[int] + seq_groups: List[List[int]] + is_first_multi_step: bool = True + is_last_step: bool = True + virtual_engine: int = 0 + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + tensor_dict = { + "token_ids": self.token_ids, + "position_ids": self.position_ids, + "input_lens": self.input_lens, + "t": self.t, + "p": self.p, + "num_samples": self.num_samples, + "n": self.n, + "seq_groups": self.seq_groups, + "is_first_multi_step": self.is_first_multi_step, + "is_last_step": self.is_last_step, + "virtual_engine": self.virtual_engine, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["ModelInputForTPU"], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForTPU": + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + is_driver_worker: bool = False, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + + self.block_size = self.cache_config.block_size + self.max_num_blocks_per_seq = (self.model_config.max_model_len // + self.block_size) + self.block_tables = np.zeros( + (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq), + dtype=np.int32) + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + False, + ) + self.cached_step_outputs: List[torch.Tensor] = [] + + def load_model(self) -> None: + self.device = self.device_config.device + + # NOTE(woosuk): While the executor assigns the TP ranks to the worker + # process, the ranks can be different from the ranks internally assigned + # by the xm runtime. Therefore, there is a mismatch in the rank + # assignment between the gloo (cpu) runtime and the xm (tpu) runtime. + # This is not a problem in linear layers because all-reduce is + # rank-agnostic. However, it matters for all-gather as the ranks + # determine the order of concatenating the output tensors. + # As a workaround, we use the xm's rank assignment only when loading + # the embedding weights. + xm_tp_rank = xr.global_ordinal() + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding." 
+ "get_tensor_model_parallel_rank", + return_value=xm_tp_rank): + model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + parallel_config=self.parallel_config, + cache_config=self.cache_config, + scheduler_config=self.scheduler_config, + lora_config=None, + ) + model = model.eval() + xm.wait_device_ops() + self.model = ModelWrapper(model) + + def _dummy_run( + self, + batch_size: int, + seq_len: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + is_prompt: bool, + ) -> None: + if is_prompt: + seq_len = (seq_len + 15) // 16 * 16 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + block_tables=None, + context_lens=None, + ) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + else: + assert seq_len == 1 + token_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + position_ids = torch.zeros((batch_size, seq_len), + dtype=torch.int32, + device=self.device) + slot_mapping = torch.zeros((batch_size, seq_len), + dtype=torch.int64, + device=self.device) + block_tables = torch.zeros( + (batch_size, self.max_num_blocks_per_seq), + dtype=torch.int32, + device=self.device) + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + input_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size * seq_len, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) + num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + + # NOTE(woosuk): There are two stages of compilation: torch.compile and + # XLA compilation. Using `mark_dynamic` can reduce the torch.compile + # overhead by reusing the FX graph for different shapes. + # However, the XLA graph will still require static shapes and needs to + # be re-compiled for every different shapes. This overhead is inevitable + # in the first run, but can be skipped afterwards as we cache the XLA + # graphs in the disk (VLLM_XLA_CACHE_PATH). + if is_prompt: + # Prefll + torch._dynamo.mark_dynamic(token_ids, 1) + torch._dynamo.mark_dynamic(position_ids, 1) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1) + else: + # Decode + torch._dynamo.mark_dynamic(token_ids, 0) + torch._dynamo.mark_dynamic(position_ids, 0) + torch._dynamo.mark_dynamic(input_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0) + torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0) + torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0) + torch._dynamo.mark_dynamic(t, 0) + torch._dynamo.mark_dynamic(p, 0) + # Dummy run. 
+ self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + num_samples, + kv_caches, + is_prompt=is_prompt) + + def warmup_model( + self, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> None: + # Prefill + logger.info("Compiling the model with different input shapes...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while True: + self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if seq_len >= self.model_config.max_model_len: + break + num_tokens = batch_size * seq_len + if num_tokens >= self.scheduler_config.max_num_batched_tokens: + break + seq_len = seq_len * 2 + + end = time.time() + logger.info("Compilation for prefill done in %.2f s.", end - start) + + # Decode + start = time.time() + seq_len = 1 + batch_size = 8 # Must be in sync with _get_padded_batch_size() + while True: + self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) + + if batch_size >= self.scheduler_config.max_num_seqs: + break + batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2 + + end = time.time() + logger.info("Compilation for decode done in %.2f s.", end - start) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + prompt_lens: List[int] = [] + slot_mapping: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + # Could include output tokens when a request is preempted. + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + + input_tokens.extend(prompt_tokens) + input_positions.extend(list(range(prompt_len))) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + for i in range(prompt_len): + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + # Add paddings to EACH prompt to the smallest power of 2 that is + # greater than or equal to the prompt length. + # We pad the seq_len to reduce the compilation overhead. + # We execute each prompt individually (i.e., with batch_size 1) + # because the FlashAttention kernel does not support ragged inputs. + # TODO(woosuk): Use SplashAttention to support ragged inputs. 
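
# Illustrative sketch (not applied by this diff): warmup_model() above
# pre-compiles a finite ladder of shapes. Prefill (batch size 1) doubles the
# padded sequence length starting at 16; decode grows the batch size
# 8 -> 16 -> 32 -> ... Serving-time requests are then padded up to one of
# these shapes by the _get_padded_* helpers defined later in this file, so no
# compilation should happen on the hot path. The limits below are invented;
# the enumeration mirrors the loops above.

def prefill_warmup_seq_lens(max_model_len: int, max_num_batched_tokens: int):
    seq_len, lens = 16, []
    while True:
        lens.append(seq_len)
        # batch size is 1 during prefill warmup, so num_tokens == seq_len
        if seq_len >= max_model_len or seq_len >= max_num_batched_tokens:
            break
        seq_len *= 2
    return lens


def decode_warmup_batch_sizes(max_num_seqs: int):
    batch_size, sizes = 8, []
    while True:
        sizes.append(batch_size)
        if batch_size >= max_num_seqs:
            break
        batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
    return sizes


assert prefill_warmup_seq_lens(2048, 2048) == [16, 32, 64, 128, 256, 512, 1024, 2048]
assert decode_warmup_batch_sizes(64) == [8, 16, 32, 48, 64]
# (end of sketch)
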
+ padded_prompt_len = _get_padded_prefill_len(prompt_len) + num_paddings = padded_prompt_len - prompt_len + input_tokens += [0] * num_paddings + input_positions += [0] * num_paddings + slot_mapping += [_PAD_SLOT_ID] * num_paddings + + assert len(prompt_lens) > 0 + num_prefills = len(prompt_lens) + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + prompt_lens = torch.tensor(prompt_lens, + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=num_prefills, + num_prefill_tokens=0, # NOTE: This is not used. + num_decode_tokens=0, + slot_mapping=slot_mapping, + block_tables=None, + context_lens=None, + ) + return input_tokens, input_positions, attn_metadata, prompt_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + + batch_idx = 0 + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + self.block_tables[batch_idx, :len(block_table)] = block_table + batch_idx += 1 + + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + + batch_size = _get_padded_batch_size(batch_idx) + num_paddings = batch_size - batch_idx + input_tokens = input_tokens + [[0]] * num_paddings + input_positions = input_positions + [[0]] * num_paddings + slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings + context_lens = context_lens + [0] * num_paddings + + input_tokens = torch.tensor(input_tokens, + dtype=torch.int32, + device="cpu") + input_positions = torch.tensor(input_positions, + dtype=torch.int32, + device="cpu") + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.int64, + device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device="cpu") + input_lens = torch.tensor([1] * batch_size, + dtype=torch.int32, + device="cpu") + attn_metadata = self.attn_backend.make_metadata( + num_prefills=0, + num_prefill_tokens=0, + num_decode_tokens=batch_size, + slot_mapping=slot_mapping, + block_tables=block_tables, + context_lens=context_lens, + ) + return input_tokens, input_positions, attn_metadata, input_lens + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + padded_batch_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]: + assert len(seq_group_metadata_list) > 0 + t = [] + p = [] + n = [] + for seq_group_metadata in seq_group_metadata_list: + sampling_params = 
seq_group_metadata.sampling_params + t.append(sampling_params.temperature) + if sampling_params.top_p != 1 and not _ENABLE_TOP_P: + raise NotImplementedError( + "Top-p sampling is currently disabled for the TPU backend " + "due to performance issues.") + p.append(sampling_params.top_p) + if sampling_params.top_k != -1: + raise NotImplementedError( + "Top-k sampling is currently disabled for the TPU backend " + "due to performance issues.") + if sampling_params.n > _MAX_NUM_SAMPLES: + raise NotImplementedError( + f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " + "backend.") + n.append(sampling_params.n) + if sampling_params.logprobs is not None: + raise NotImplementedError( + "logprobs is not currently supported by the TPU backend.") + if sampling_params.prompt_logprobs is not None: + raise NotImplementedError( + "prompt_logprobs is not currently supported by the TPU " + "backend.") + + # Repeat the sampling params if the seq group has multiple seqs. + num_seqs = len(seq_group_metadata.seq_data) + t += [t[-1]] * (num_seqs - 1) + p += [p[-1]] * (num_seqs - 1) + n += [n[-1]] * (num_seqs - 1) + + num_paddings = padded_batch_size - len(t) + t += [1.0] * num_paddings + p += [1.0] * num_paddings + + t = torch.tensor(t, dtype=torch.float32, device="cpu") + p = torch.tensor(p, dtype=torch.float32, device="cpu") + return t, p, n + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None, + ) -> ModelInputForTPU: + del finished_requests_ids # Unused. + assert virtual_engine == 0 + assert len(seq_group_metadata_list) > 0 + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + inputs = self._prepare_prompt(seq_group_metadata_list) + else: + inputs = self._prepare_decode(seq_group_metadata_list) + input_tokens, input_positions, attn_metadata, input_lens = inputs + padded_batch_size = input_tokens.shape[0] + t, p, n = self._prepare_sample(seq_group_metadata_list, + padded_batch_size) + num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + + seq_groups = [ + list(metadata.seq_data.keys()) + for metadata in seq_group_metadata_list + ] + return ModelInputForTPU(input_tokens, input_positions, attn_metadata, + input_lens, t, p, num_samples, n, seq_groups) + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU: + model_input = ModelInputForTPU.from_broadcasted_tensor_dict( + tensor_dict, attn_backend=self.attn_backend) + return model_input + + @torch.no_grad() + def execute_model( + self, + model_input: ModelInputForTPU, + kv_caches: Optional[List[Any]], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> List[SamplerOutput]: + assert intermediate_tensors is None + if not model_input.is_first_multi_step: + if not model_input.is_last_step: + return [] + + use_async_out_proc = model_input.async_callback is not None + sampler_outputs = [] + num_outputs = len(self.cached_step_outputs) + for i in range(num_outputs): + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + sampler_outputs.append(sampler_output) + + if i < num_outputs - 1 and use_async_out_proc: + assert model_input.async_callback is not None + ctx = model_input.async_callback.keywords[ # type: ignore + "ctx"] 
+ ctx.append_output( + outputs=[sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False, + is_first_step_output=i == 0) + model_input.async_callback() + if use_async_out_proc: + return [sampler_outputs[-1]] + else: + return sampler_outputs + + is_prompt = model_input.attn_metadata.num_prefills > 0 + if is_prompt: + assert num_steps == 1 + # NOTE(woosuk): Since the FlashAttention kernel does not support + # ragged inputs, we split the prompts into different batches and + # process them separately. This is a temporary hack that should be + # optimized by using SplashAttention. + orig_slot_mapping = model_input.attn_metadata.slot_mapping + batch_size = model_input.input_lens.shape[0] + start_idx = 0 + next_token_ids = [] + for i in range(batch_size): + # Get the actual prefill_len. + prefill_len = model_input.input_lens[i:i + 1].item() + prefill_len = _get_padded_prefill_len(prefill_len) + end_idx = start_idx + prefill_len + + token_ids = model_input.token_ids[None, start_idx:end_idx].to( + self.device) + position_ids = model_input.position_ids[None, + start_idx:end_idx].to( + self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.num_prefills = 1 + attn_metadata.slot_mapping = orig_slot_mapping[ + None, start_idx:end_idx].to(self.device) + input_lens = model_input.input_lens[i:i + 1].to(self.device) + t = model_input.t[i:i + 1].to(self.device) + p = model_input.p[i:i + 1].to(self.device) + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=True) + next_token_ids.append(output_token_ids[0]) + start_idx = end_idx + + if model_input.async_callback is not None: + model_input.async_callback() + # Retrieve the outputs to CPU. + next_token_ids = [ + output_token_ids.cpu().tolist() + for output_token_ids in next_token_ids + ] + + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support advanced sampling parameters such as logprobs. + zero_logprob = Logprob(0.0) + sampler_outputs = [] + for i, seq_group in enumerate(model_input.seq_groups): + seq_ids = seq_group + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + seq_outputs = [] + for j in range(model_input.n[i]): + next_token_id = next_token_ids[i][j] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + sampler_outputs.append( + CompletionSequenceGroupOutput(seq_outputs, None)) + return [SamplerOutput(sampler_outputs)] + else: + token_ids = model_input.token_ids.to(self.device) + position_ids = model_input.position_ids.to(self.device) + attn_metadata = model_input.attn_metadata + attn_metadata.slot_mapping = attn_metadata.slot_mapping.to( + self.device) + attn_metadata.block_tables = attn_metadata.block_tables.to( + self.device) + attn_metadata.context_lens = attn_metadata.context_lens.to( + self.device) + t = model_input.t.to(self.device) + p = model_input.p.to(self.device) + input_lens = model_input.input_lens.to(self.device) + for i in range(num_steps): + slot_mapping = attn_metadata.slot_mapping + output_token_ids = self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + model_input.num_samples, + kv_caches, + is_prompt=False) + self.cached_step_outputs.append(output_token_ids) + + if i < num_steps - 1: + # Prepare the inputs for the next step. 
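
# Illustrative sketch (not applied by this diff): in the prompt branch above,
# all prefills arrive flattened into a single 1-D token tensor (each prompt
# already padded to a power of two) and the loop walks it with
# start_idx/end_idx so every prompt runs as its own batch of size 1, since the
# kernel cannot handle ragged inputs. The sketch reproduces just that slicing
# with invented lengths.

def split_flat_prefill_batch(flat_tokens, padded_lens):
    # padded_lens are per-prompt lengths after _get_padded_prefill_len().
    start_idx, batches = 0, []
    for prefill_len in padded_lens:
        end_idx = start_idx + prefill_len
        batches.append(flat_tokens[start_idx:end_idx])
        start_idx = end_idx
    return batches


flat = list(range(16 + 32))          # two prompts padded to 16 and 32 tokens
per_prompt = split_flat_prefill_batch(flat, [16, 32])
assert [len(b) for b in per_prompt] == [16, 32]
assert per_prompt[1][0] == 16        # second prompt starts right after the first
# (end of sketch)
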
+ token_ids = output_token_ids.unsqueeze(dim=1).int() + position_ids = position_ids + 1 + attn_metadata.context_lens = attn_metadata.context_lens + 1 + + block_tables = attn_metadata.block_tables + block_number = block_tables.gather( + 1, + position_ids.long() // self.block_size) + block_offset = position_ids % self.block_size + + is_padding = slot_mapping == _PAD_SLOT_ID + slot_mapping = block_number * self.block_size + block_offset + slot_mapping = slot_mapping.long() + slot_mapping = torch.where(is_padding, _PAD_SLOT_ID, + slot_mapping) + attn_metadata.slot_mapping = slot_mapping + + if model_input.async_callback is not None: + model_input.async_callback() + + if num_steps > 1: + return [] + # Retrieve the outputs to CPU. + next_token_ids = self.cached_step_outputs.pop(0) + next_token_ids = next_token_ids.cpu().tolist() + sampler_output = _make_decode_output(next_token_ids, + model_input.seq_groups) + return [sampler_output] + + +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model: nn.Module): + self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) + + def forward( + self, + token_ids: torch.Tensor, + position_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + input_lens: torch.Tensor, + t: torch.Tensor, + p: torch.Tensor, + num_samples: int, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> torch.Tensor: + """Executes the forward pass of the model and samples the next token. + + Args: + token_ids: The input token IDs of shape [batch_size, seq_len]. + position_ids: The input position IDs of shape [batch_size, seq_len]. + attn_metadata: The Pallas attention metadata. + input_lens: The actual input lengths of shape [batch_size]. + t: The sampling temperature of shape [batch_size]. + p: The top-p probability of shape [batch_size]. + num_samples: Number of samples to draw from each logits vector. + kv_caches: The key and value caches. They can be None during the + memory profiling at initialization. + """ + batch_size, seq_len = token_ids.shape + # Calculate the positions to sample from. + start_indicies = torch.arange( + batch_size, dtype=torch.int32, device=input_lens.device) * seq_len + logits_indices = start_indicies + input_lens - 1 + + # FIXME(woosuk): This is a temporary hack to avoid using the existing + # sampler and sampling metadata. + sampling_metadata = SamplingMetadata( + seq_groups=[], + selected_token_indices=logits_indices, + categorized_sample_indices={}, + num_prompts=attn_metadata.num_prefills, + ) + + # Skip this in memory profiling at initialization. + if kv_caches[0][0].numel() > 0: + # index_copy_(slot_mapping) only works when the inserted dimension + # is 0. However, the KV cache in the Pallas backend has the shape + # [num_kv_heads, num_blocks, block_size, head_size]. 
To make it + # work, we need to flatten the first three dimensions and modify + # the slot_mapping accordingly. + num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape + slot_mapping = attn_metadata.slot_mapping + slot_mapping = slot_mapping.flatten() + head_indicies = torch.arange(0, + num_kv_heads, + device=slot_mapping.device, + dtype=slot_mapping.dtype) + head_indicies *= block_size * num_blocks + slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view( + -1, num_kv_heads) + slot_mapping = slot_mapping + head_indicies.view(1, -1) + slot_mapping = slot_mapping.flatten() + attn_metadata.slot_mapping = slot_mapping + + hidden_states = self.model( + token_ids, + position_ids, + kv_caches, + attn_metadata, + ) + hidden_states = hidden_states.flatten(0, 1) + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Argmax sampling. + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True) + argmax_token_ids = argmax_token_ids.repeat(1, num_samples) + + # Zero temperature means greedy decoding. Avoid division by zero. + nonzero_t = torch.where(t != 0, t, 1.0) + logits = logits / nonzero_t.unsqueeze(dim=1) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + + # Random sampling. + probs = torch.softmax(logits, dim=-1, dtype=torch.float32) + sampled_token_ids = torch.multinomial(probs, + num_samples, + replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) + next_token_ids = torch.where(t != 0, sampled_token_ids, + argmax_token_ids) + return next_token_ids + + +def _get_padded_prefill_len(x: int) -> int: + # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence + # length to be a multiple of 16. We pad the prompt length to the nearest + # multiple of 16. This is also good for performance. + if x <= 16: + return 16 + return 1 << (x - 1).bit_length() + + +def _get_padded_batch_size(batch_size: int) -> int: + # The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16. + # To meet this requirement in the simplest way, we set the minimal batch + # size to 8. 
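
# Illustrative sketch (not applied by this diff): ModelWrapper.forward() above
# rewrites slot_mapping because index_copy_ can only scatter along dim 0. The
# Pallas KV cache [num_kv_heads, num_blocks, block_size, head_size] is viewed
# with its first three dims flattened, so a logical slot s for head h becomes
# h * num_blocks * block_size + s. A tiny numeric check of that remapping
# (dimensions invented):

import torch

num_kv_heads, num_blocks, block_size = 2, 4, 8
slot_mapping = torch.tensor([0, 9, 31])          # logical slots for 3 tokens

head_indices = torch.arange(num_kv_heads) * (num_blocks * block_size)
flat = slot_mapping.repeat_interleave(num_kv_heads).view(-1, num_kv_heads)
flat = (flat + head_indices.view(1, -1)).flatten()

# token 0 -> heads 0,1 ; token 1 -> heads 0,1 ; token 2 -> heads 0,1
assert flat.tolist() == [0, 32, 9, 41, 31, 63]
# (end of sketch)
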
+ if batch_size <= 8: + return 8 + else: + return ((batch_size + 15) // 16) * 16 + + +def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor: + logits_sorted = torch.sort(logits, dim=-1, descending=True).values + sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1) + cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True) + cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index) + logits = logits.masked_fill_(logits < cutoff_logit, -float("inf")) + return logits + + +def _make_decode_output( + next_token_ids: List[int], + seq_groups: List[List[int]], +) -> SamplerOutput: + zero_logprob = Logprob(0.0) + sampler_outputs = [] + batch_idx = 0 + for seq_group in seq_groups: + seq_ids = seq_group + seq_outputs = [] + for seq_id in seq_ids: + next_token_id = next_token_ids[batch_idx] + seq_outputs.append( + SequenceOutput(seq_id, next_token_id, + {next_token_id: zero_logprob})) + batch_idx += 1 + sampler_outputs.append(CompletionSequenceGroupOutput( + seq_outputs, None)) + return SamplerOutput(sampler_outputs) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py new file mode 100644 index 0000000..fe819b9 --- /dev/null +++ b/vllm/worker/tpu_worker.py @@ -0,0 +1,309 @@ +import os +from typing import List, Optional, Tuple, Union + +import torch +import torch_xla.core.xla_model as xm +import torch_xla.runtime as xr + +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.sequence import ExecuteModelRequest +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.worker.tpu_model_runner import TPUModelRunner +from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, + LoraNotSupportedWorkerBase, WorkerInput) + +logger = init_logger(__name__) + + +class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase): + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + + assert self.device_config.device_type == "tpu" + if self.cache_config.cache_dtype == "auto": + self.cache_dtype = self.model_config.dtype + else: + self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[ + self.cache_config.cache_dtype] + + self.model_runner: TPUModelRunner = TPUModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config, + is_driver_worker=is_driver_worker) + + def init_device(self) -> None: + os.environ["PJRT_DEVICE"] = "TPU" + torch.set_grad_enabled(False) + torch.set_default_dtype(self.model_config.dtype) + + # NOTE(woosuk): This is just to initialize the TP group and broadcast + # the input objects on CPU. 
The all-reduce and all-gather ops on TPU + # are invoked by `xm.all_reduce` and `xm.all_gather` which use their + # own context. + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size) + + # Device initialization should happen after initializing the distributed + # runtime. + self.device = xm.xla_device() + self.device_config.device = self.device + + # Set random seed. + set_random_seed(self.model_config.seed) + xm.set_rng_state(self.model_config.seed, self.device) + + # Increase the cache size limit, which is the maximum number of + # dynamo graphs that can be compiled. + # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and + # 30-40 graphs for decode. 128 is an arbitrary safe number. + torch._dynamo.config.cache_size_limit = 128 + # Use persistent cache to avoid XLA recompilation. + # NOTE(woosuk): Set per-rank cache path since different ranks + # can have slightly different XLA graphs. + world_size = self.parallel_config.world_size + rank = xr.global_ordinal() + per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, + f"tp{world_size}_rank{rank}") + xr.initialize_cache(per_rank_path, readonly=False) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + num_layers = self.model_config.get_num_layers(self.parallel_config) + head_size = self.model_config.get_head_size() + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + kv_caches = [(torch.tensor([], dtype=torch.float32, + device=self.device), + torch.tensor([], dtype=torch.float32, + device=self.device)) + for _ in range(num_layers)] + self.model_runner._dummy_run( + batch_size=1, + seq_len=self.scheduler_config.max_num_batched_tokens, + kv_caches=kv_caches, + is_prompt=True, + ) + # Synchronize before measuring the memory usage. + xm.wait_device_ops() + + dtype_btyes = get_dtype_size(self.cache_dtype) + block_size = self.cache_config.block_size + block_size_bytes = (dtype_btyes * block_size * num_layers * 2 * + head_size * num_kv_heads) + + # Calculate the TPU KV cache size based on profiling. + m = xm.get_memory_info(self.device) + total_memory_size = m["bytes_limit"] + usable_memory_size = int(total_memory_size * + self.cache_config.gpu_memory_utilization) + profiled = m["bytes_used"] # Weights + intermediate activations. + tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0) + num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes + num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8. + + # Calculate the CPU KV cache size based on the config. + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + block_size_bytes) + num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. 
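+        # The first element is the number of KV-cache blocks that fit in TPU
+        # memory; the second is the number of swappable blocks in host memory.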
+ return num_tpu_blocks, num_cpu_blocks + + def initialize_cache( + self, + num_gpu_blocks: int, + num_cpu_blocks: int, + ) -> None: + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + self.block_size = self.cache_config.block_size + + dtype = self.cache_dtype + num_layers = self.model_config.get_num_layers(self.parallel_config) + num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config) + head_size = self.model_config.get_head_size() + + self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = [] + tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_gpu_blocks, self.block_size, num_kv_heads, head_size) + cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape( + num_cpu_blocks, self.block_size, num_kv_heads, head_size) + for _ in range(num_layers): + tpu_k_cache = torch.zeros(tpu_cache_shape, + dtype=dtype, + device=self.device) + tpu_v_cache = torch.zeros_like(tpu_k_cache) + self.tpu_cache.append((tpu_k_cache, tpu_v_cache)) + cpu_k_cache = torch.zeros(cpu_cache_shape, + dtype=dtype, + device="cpu") + cpu_v_cache = torch.zeros_like(cpu_k_cache) + self.cpu_cache.append((cpu_k_cache, cpu_v_cache)) + self._warmup_model() + + def _warmup_model(self) -> None: + # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined + # for CUDA graphs. We should refactor this part. + if not self.model_config.enforce_eager: + # Warm up the model with all possible input shapes so that + # compilation never happens during the actual execution. + # This may take ~30 mins for the first run and ~20 mins for the + # subsequent runs. + # If `enforce_eager` is True, the ahead-of-time compilation is + # skipped and the compilation happens during the actual execution, + # which is bad for performance but useful for development. + self.model_runner.warmup_model(self.tpu_cache) + + def get_cache_block_size_bytes(self) -> int: + head_size = self.model_config.get_head_size() + num_heads = self.model_config.get_num_kv_heads(self.parallel_config) + num_layers = self.model_config.get_num_layers(self.parallel_config) + + key_cache_block = self.cache_config.block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + dtype_size = get_dtype_size(self.cache_dtype) + return dtype_size * total + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline + # parallelism. 
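+        # Wrapping the single TPU cache in a list keeps the return value
+        # indexable by virtual_engine, which is always 0 here.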
+ return [self.tpu_cache] + + def prepare_worker_input( + self, + execute_model_req: ExecuteModelRequest, + ) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + blocks_to_swap_in = _make_src_to_dst( + execute_model_req.blocks_to_swap_in, "cpu", self.device) + blocks_to_swap_out = _make_src_to_dst( + execute_model_req.blocks_to_swap_out, self.device, "cpu") + blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy, + self.device, self.device) + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + ) + + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + assert virtual_engine == 0 + attn_backend = self.model_runner.attn_backend + num_layers = self.model_config.get_num_layers(self.parallel_config) + + # Issue cache operations. + if worker_input.blocks_to_swap_in is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_in + if src_indices.numel() > 0: + # Swap from CPU to TPU. + for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + k = cpu_k_cache[:, src_indices].to(self.device) + v = cpu_v_cache[:, src_indices].to(self.device) + _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache) + + if worker_input.blocks_to_swap_out is not None: + src_indices, dst_indices = worker_input.blocks_to_swap_out + if src_indices.numel() > 0: + # Swap from TPU to CPU. + for i in range(num_layers): + tpu_k_cache, tpu_v_cache = self.tpu_cache[i] + cpu_k_cache, cpu_v_cache = self.cpu_cache[i] + cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices] + cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices] + + if worker_input.blocks_to_copy is not None: + src_indices, dst_indices = worker_input.blocks_to_copy + if src_indices.numel() > 0: + attn_backend.copy_blocks(self.tpu_cache, + (src_indices, dst_indices)) + + +def _make_src_to_dst( + mapping: List[Tuple[int, int]], + src_device: Union[torch.device, str], + dst_device: Union[torch.device, str], +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + if not mapping: + return None + + src_indices = [i for i, _ in mapping] + dst_indices = [i for _, i in mapping] + src_indices = torch.tensor(src_indices, + device=src_device, + dtype=torch.int64) + dst_indices = torch.tensor(dst_indices, + device=dst_device, + dtype=torch.int64) + return src_indices, dst_indices + + +@torch.compile(backend="openxla") +def _insert_kv( + k: torch.Tensor, + v: torch.Tensor, + indices: torch.Tensor, + tpu_k_cache: torch.Tensor, + tpu_v_cache: torch.Tensor, +) -> None: + torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True) + torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True) + tpu_k_cache[:, indices] = k + tpu_v_cache[:, indices] = v diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py new file mode 100644 index 0000000..f436354 --- /dev/null +++ b/vllm/worker/utils.py @@ -0,0 +1,51 @@ +''' +Worker-related helper functions. +''' + +from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS +from vllm.worker.model_runner import GPUModelRunnerBase + + +def assert_enc_dec_mr_supported_scenario( + enc_dec_mr: GPUModelRunnerBase) -> None: + ''' + Asserted that the provided encoder/decoder model runner instance reflects + a supported scenario. 
+ ''' + + # Reminder: Please update docs/source/serving/compatibility_matrix.rst + # If the feature combo become valid + + if enc_dec_mr.cache_config.enable_prefix_caching: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE']) + + if enc_dec_mr.sliding_window is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SWA']) + + if enc_dec_mr.scheduler_config.chunked_prefill_enabled: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ + 'STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL']) + + if getattr(enc_dec_mr.model_config.hf_config, 'attn_logit_softcapping', + None) is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP'] + ) + + if enc_dec_mr.lora_config is not None: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA']) + + if enc_dec_mr.parallel_config.pipeline_parallel_size > 1: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) + + if enc_dec_mr.scheduler_config.num_lookahead_slots > 0: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_SPEC_DEC']) + + if enc_dec_mr.prompt_adapter_config is not None: + raise NotImplementedError(STR_NOT_IMPL_ENC_DEC_ERR_STRS[ + 'STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py new file mode 100644 index 0000000..ab61e43 --- /dev/null +++ b/vllm/worker/worker.py @@ -0,0 +1,497 @@ +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Optional, Set, Tuple, Type, Union + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.platforms import current_platform +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, + SequenceGroupMetadata, SequenceGroupMetadataDelta) +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.embedding_model_runner import EmbeddingModelRunner +from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner +from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput + +logger = init_logger(__name__) + + +class Worker(LocalOrDistributedWorkerBase): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, + observability_config: Optional[ObservabilityConfig] = None, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + self.observability_config = observability_config + + # Return hidden states from target model if the draft model is an + # mlp_speculator + speculative_args = {} if speculative_config is None \ + or (speculative_config.draft_model_config.model == + model_config.model) \ + or (speculative_config.draft_model_config.hf_config.model_type + not in ["medusa", "mlp_speculator", "eagle"]) \ + else {"return_hidden_states": True} + + ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner + if model_runner_cls is not None: + ModelRunnerClass = model_runner_cls + elif self._is_embedding_model(): + ModelRunnerClass = EmbeddingModelRunner + elif self._is_encoder_decoder_model(): + ModelRunnerClass = EncoderDecoderModelRunner + self.model_runner: GPUModelRunnerBase = ModelRunnerClass( + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + observability_config=observability_config, + **speculative_args, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] + # Initialize gpu_cache as embedding models don't initialize kv_caches + self.gpu_cache: Optional[List[List[torch.Tensor]]] = None + self._seq_group_metadata_cache: Dict[str, SequenceGroupMetadata] = {} + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. 
Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + def start_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.start() + + def stop_profile(self): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + self.profiler.stop() + + def _is_encoder_decoder_model(self): + return self.model_config.is_encoder_decoder_model + + def _is_embedding_model(self): + return self.model_config.is_embedding_model + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + gc.collect() + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_worker_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method, + self.local_rank) + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + self.model_runner.save_sharded_state( + path, + pattern=pattern, + max_size=max_size, + ) + + def save_tensorized_model( + self, + tensorizer_config: TensorizerConfig, + ) -> None: + self.model_runner.save_tensorized_model( + tensorizer_config=tensorizer_config, ) + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. 
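+        # peak_memory below is the memory held by the weights plus the peak
+        # activations of the profiling run; the KV cache then gets whatever
+        # remains of total_gpu_memory * gpu_memory_utilization. For example
+        # (illustrative numbers only), with 80 GiB total, utilization 0.9 and
+        # a 20 GiB peak, about 52 GiB is left for cache blocks.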
+ peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + if cache_block_size == 0: + num_gpu_blocks = 0 + num_cpu_blocks = 0 + else: + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Allocate GPU and CPU KV cache with the specified number of blocks. + + This also warms up the model, which may record CUDA graphs. + """ + raise_if_cache_size_invalid(num_gpu_blocks, + self.cache_config.block_size, + self.cache_config.is_attention_free, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._init_cache_engine() + self._warm_up_model() + + def _init_cache_engine(self): + assert self.cache_config.num_gpu_blocks is not None + self.cache_engine = [ + CacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.gpu_cache = [ + self.cache_engine[ve].gpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + + def _warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @property + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + return self.gpu_cache + + @torch.inference_mode() + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + virtual_engine = execute_model_req.virtual_engine + num_steps = execute_model_req.num_steps + num_seq_groups = len(execute_model_req.seq_group_metadata_list) + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. + blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in, + device="cpu", + dtype=torch.int64).view(-1, 2) + blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out, + device="cpu", + dtype=torch.int64).view(-1, 2) + # `blocks_to_copy` is a gpu tensor. The src and tgt of + # blocks to copy are in the same device, and `blocks_to_copy` + # can be used directly within cuda kernels. 
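+        # Each entry is a (src_block, dst_block) pair, so e.g.
+        # blocks_to_copy=[(0, 1), (2, 3)] becomes tensor([[0, 1], [2, 3]]).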
+ blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy, + device=self.device, + dtype=torch.int64).view(-1, 2) + + return WorkerInput( + num_seq_groups=num_seq_groups, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + virtual_engine=virtual_engine, + num_steps=num_steps, + ) + + @torch.inference_mode() + def execute_worker(self, worker_input: WorkerInput) -> None: + virtual_engine = worker_input.virtual_engine + # Issue cache operations. + if (worker_input.blocks_to_swap_in is not None + and worker_input.blocks_to_swap_in.numel() > 0): + self.cache_engine[virtual_engine].swap_in( + worker_input.blocks_to_swap_in) + if (worker_input.blocks_to_swap_out is not None + and worker_input.blocks_to_swap_out.numel() > 0): + self.cache_engine[virtual_engine].swap_out( + worker_input.blocks_to_swap_out) + if (worker_input.blocks_to_copy is not None + and worker_input.blocks_to_copy.numel() > 0): + self.cache_engine[virtual_engine].copy(worker_input.blocks_to_copy) + + def _get_cached_seq_group_metadata( + self, + seq_group_metadata_list: List[Union[SequenceGroupMetadata, + SequenceGroupMetadataDelta]], + finished_request_ids: List[str]) -> List[SequenceGroupMetadata]: + """Return a list of cached Sequence Group Metadata after updating its + state. + + It is used because scheduler only sends delta to workers to reduce + the data payload size. The function also cleans up cache based on + a given `finished_request_ids`. + """ + new_seq_group_metadata_list = [] + for metadata_or_delta in seq_group_metadata_list: + request_id = metadata_or_delta.request_id + if request_id not in self._seq_group_metadata_cache: + # The first prefill. + assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[request_id] = metadata_or_delta + else: + # The first prefill is already cached. + if isinstance(metadata_or_delta, SequenceGroupMetadataDelta): + self._seq_group_metadata_cache[request_id].apply_delta( + metadata_or_delta) + else: + # If metadata snapshot is sent again, it is + # preempted. Reset the cache because we need to start + # from scratch. 
+ assert isinstance(metadata_or_delta, SequenceGroupMetadata) + self._seq_group_metadata_cache[ + request_id] = metadata_or_delta + + new_seq_group_metadata_list.append( + self._seq_group_metadata_cache[request_id]) + + # Clean up finished ids + for finished_id in finished_request_ids: + del self._seq_group_metadata_cache[finished_id] + + return new_seq_group_metadata_list + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Optional[List[SamplerOutput]]: + if execute_model_req is not None: + new_seq_group_metadata_list = self._get_cached_seq_group_metadata( + execute_model_req.seq_group_metadata_list, + execute_model_req.finished_requests_ids) + + execute_model_req.seq_group_metadata_list = ( + new_seq_group_metadata_list) + output = super()._execute_model_spmd(execute_model_req, + intermediate_tensors) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.model_runner.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_runner.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.remove_lora(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_runner.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.model_runner.list_prompt_adapters() + + @property + def max_model_len(self) -> int: + return self.model_config.max_model_len + + @property + def vocab_size(self) -> int: + return self.model_runner.vocab_size + + def get_cache_block_size_bytes(self) -> int: + """Get the size of the KV cache block size in bytes. + """ + return CacheEngine.get_cache_block_size(self.cache_config, + self.model_config, + self.parallel_config) + + +def init_worker_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + local_rank: int = -1, +) -> None: + """Initialize the distributed environment.""" + set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) + + init_distributed_environment(parallel_config.world_size, rank, + distributed_init_method, local_rank) + + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: # noqa: SIM102 + if not current_platform.has_device_capability(80): + capability = current_platform.get_device_capability() + gpu_name = current_platform.get_device_name() + + if capability is None: + compute_str = "does not have a compute capability" + else: + version_str = capability.as_version_str() + compute_str = f"has compute capability {version_str}" + + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. Your {gpu_name} GPU {compute_str}. 
" + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") + + +def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, + max_model_len) -> None: + if is_attention_free and num_gpu_blocks != 0: + raise ValueError("No memory should be allocated for the cache blocks " + f"for an attention-free model, but {num_gpu_blocks}" + "blocks are allocated.") + if not is_attention_free and num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if not is_attention_free and max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py new file mode 100644 index 0000000..6ba4f27 --- /dev/null +++ b/vllm/worker/worker_base.py @@ -0,0 +1,485 @@ +import dataclasses +import importlib +import os +import time +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union + +import torch + +from vllm.config import ObservabilityConfig +from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.platforms import current_platform +from vllm.sequence import ExecuteModelRequest, IntermediateTensors +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) +from vllm.worker.model_runner_base import (BroadcastableModelInput, + ModelRunnerBase, + ModelRunnerInputBase) + +logger = init_logger(__name__) + + +class WorkerBase(ABC): + """Worker interface that allows vLLM to cleanly separate implementations for + different hardware. Also abstracts control plane communication, e.g., to + communicate request metadata to other workers. + """ + + @abstractmethod + def init_device(self) -> None: + """Initialize device state, such as loading the model or other on-device + memory allocations. + """ + raise NotImplementedError + + @abstractmethod + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available blocks for the GPU KV cache and + swappable CPU KV cache. + + The implementation may run profiling or other heuristics to determine + the size of caches. + + Returns a Tuple[num_gpu_blocks, num_cpu_blocks], where num_gpu_blocks + are blocks that are "active" on the device and can be appended to. + num_cpu_blocks refers to "swapped" blocks in CPU memory and cannot be + appended to. + """ + raise NotImplementedError + + @abstractmethod + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache with the given size in blocks. + """ + raise NotImplementedError + + @current_platform.inference_mode() + def start_worker_execution_loop(self) -> None: + """Execute model loop in parallel worker. + + You can stop the loop by executing a driver worker with an empty output. + See `stop_remote_worker_execution_loop` for more details. 
+ """ + while True: + output = self.execute_model(execute_model_req=None) + if output is None: + return None + + @abstractmethod + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + raise NotImplementedError + + @abstractmethod + def get_cache_block_size_bytes(self) -> int: + """Return the size of a single cache block, in bytes. Used in + speculative decoding. + """ + raise NotImplementedError + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_loras(self) -> Set[int]: + raise NotImplementedError + + +class LoraNotSupportedWorkerBase(WorkerBase): + """Partial implementation of WorkerBase that raises exceptions when LoRA + methods are invoked. + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def remove_lora(self, lora_id: int) -> bool: + raise ValueError(f"{type(self)} does not support LoRA") + + def pin_lora(self, lora_id: int) -> bool: + return ValueError( + f"{type(self)} does not support LoRA") # type: ignore + + def list_loras(self) -> Set[int]: + raise ValueError(f"{type(self)} does not support LoRA") + + +@dataclasses.dataclass(frozen=True) +class WorkerInput: + """Local inputs to each worker. May contain device-specific data. These + fields should be broadcastable to other workers. + """ + + num_seq_groups: Optional[int] = None + blocks_to_swap_in: Optional[torch.Tensor] = None + blocks_to_swap_out: Optional[torch.Tensor] = None + blocks_to_copy: Optional[torch.Tensor] = None + virtual_engine: int = 0 + num_steps: int = 1 + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type["WorkerInput"], + tensor_dict: Dict[str, Any], + ) -> "WorkerInput": + """ + Pop fields from the given tensor_dict and populate a new instance of + WorkerInput. + """ + return cls( + num_seq_groups=tensor_dict.pop("num_seq_groups"), + blocks_to_swap_in=tensor_dict.pop("blocks_to_swap_in"), + blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), + blocks_to_copy=tensor_dict.pop("blocks_to_copy"), + virtual_engine=tensor_dict["virtual_engine"], + num_steps=tensor_dict.pop("num_steps"), + ) + + def as_broadcastable_tensor_dict( + self) -> Dict[str, Union[int, torch.Tensor]]: + """ + Extract broadcastable fields. + """ + tensor_dict = { + "num_seq_groups": self.num_seq_groups, + "blocks_to_swap_in": self.blocks_to_swap_in, + "blocks_to_swap_out": self.blocks_to_swap_out, + "blocks_to_copy": self.blocks_to_copy, + "virtual_engine": self.virtual_engine, + "num_steps": self.num_steps, + } + + return tensor_dict + + +class LocalOrDistributedWorkerBase(WorkerBase): + """ + Partial implementation of WorkerBase that has a default `execute_model` + definition to perform metadata transfer between workers when in distributed + mode. Subclasses of this interface should use model runners that inherit + from ModelRunnerBase, and should only need to implement worker-local logic. + If custom control plane logic is needed to transfer metadata, or if the + model runner cannot inherit from ModelRunnerBase, use WorkerBase instead. 
+ """ + is_driver_worker: bool + model_runner: ModelRunnerBase + observability_config: Optional[ObservabilityConfig] = None + + @property + @abstractmethod + def do_metadata_broadcast(self) -> bool: + """ + Used by the default `execute_model` to check whether broadcast is + needed to transfer request inputs from the driver worker to other + workers in the TP group. If WorkerBase subclass only supports + single-worker execution, then this method should return False. + """ + raise NotImplementedError + + @property + @abstractmethod + def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: + """ + Gets the list of kv caches to pass to the worker's model runner. Each + element in the list is a kv cache corresponding to a particular virtual + engine (PP stream). Used by the default `execute_model`. If the worker's + model runner does not follow the ModelRunnerBase interface, then inherit + from WorkerBase instead. + """ + raise NotImplementedError + + @abstractmethod + def prepare_worker_input( + self, execute_model_req: ExecuteModelRequest) -> WorkerInput: + """ + Prepare the inputs to WorkerBase.execute_worker from an execution + request. This method may move data to the worker's local device. It is + not allowed to communicate with other workers or devices. + """ + raise NotImplementedError + + @abstractmethod + def execute_worker(self, worker_input: WorkerInput) -> None: + """ + Process an execution request. + """ + raise NotImplementedError + + def _get_worker_input_from_broadcast( + self + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: + """ Get the worker input from the broadcasted tensor dict. """ + assert self.do_metadata_broadcast + assert not self.is_driver_worker + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) + model_input = ( + self.model_runner.make_model_input_from_broadcasted_tensor_dict( + broadcast_data)) + + kwargs = extract_previous_hidden_states(broadcast_data) + + return model_input, worker_input, kwargs + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]: + """ Get the driver input and broadcast it to other workers. """ + assert self.is_driver_worker + + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + kwargs = extract_previous_hidden_states(execute_model_req) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_data.update(kwargs) + broadcast_tensor_dict(broadcast_data, src=0) + + if execute_model_req.async_callback: + model_input = dataclasses.replace( # type: ignore + model_input, + async_callback=execute_model_req.async_callback) + + return model_input, worker_input, kwargs + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[ + str, torch.Tensor]]]: + """ + Prepare the inputs to ModelRunner and workers. 
+ """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. + broadcast_tensor_dict({}, src=0) + return None + return self._get_driver_input_and_broadcast(execute_model_req) + else: + return self._get_worker_input_from_broadcast() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + start_time = time.perf_counter() + + inputs = self.prepare_input(execute_model_req) + if inputs is None: + return None + + model_input, worker_input, kwargs = inputs + num_steps = worker_input.num_steps + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + intermediate_tensors = None + orig_model_execute_time = 0.0 + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict( + all_gather_group=get_tp_group())) + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + orig_model_execute_time = intermediate_tensors.tensors.get( + "model_execute_time", torch.tensor(0)).item() + + output = self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + num_steps=num_steps, + **kwargs, + ) + + model_execute_time = time.perf_counter() - start_time + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time): + output.tensors["model_execute_time"] = torch.tensor( + model_execute_time + orig_model_execute_time) + get_pp_group().send_tensor_dict(output.tensors, + all_gather_group=get_tp_group()) + return [None] + if (self.observability_config is not None + and self.observability_config.collect_model_execute_time + and output is not None): + for o in output: + o.model_execute_time = (orig_model_execute_time + + model_execute_time) + + # output is List[SamplerOutput] + return output + + def _execute_model_spmd( + self, + execute_model_req: ExecuteModelRequest, + intermediate_tensors: Optional[IntermediateTensors] = None + ) -> Optional[List[SamplerOutput]]: + """ + Execute model in Single Program Multiple Data (SPMD) fashion. + All workers take the same request, prepare the input and + execute the model. + """ + assert execute_model_req is not None, ( + "_execute_model_spmd() requires each worker to take in an " + "ExecuteModelRequest") + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list)) + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. 
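+        # (The cache operations above have already been issued, so a step with
+        # zero sequence groups still performs its swaps/copies before returning.)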
+ if worker_input.num_seq_groups == 0: + return [] + + kwargs = extract_previous_hidden_states(execute_model_req) + + return self.model_runner.execute_model( + model_input=model_input, + kv_caches=self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors=intermediate_tensors, + **kwargs, + ) + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + + If worker_class_fn is specified, it will be executed to get the worker + class. + Otherwise, the worker class will be obtained by dynamically importing it + using worker_module_name and worker_class_name. + """ + + def __init__( + self, + worker_module_name: str, + worker_class_name: str, + trust_remote_code: bool = False, + worker_class_fn: Optional[Callable[[], + Type[WorkerBase]]] = None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker_class_fn = worker_class_fn + self.worker: Optional[WorkerBase] = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. + """ + enable_trace_function_call_for_thread() + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + from vllm.plugins import load_general_plugins + load_general_plugins() + + if self.worker_class_fn: + worker_class = self.worker_class_fn() + else: + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None + + def execute_method(self, method, *args, **kwargs): + try: + target = self if self.worker is None else self.worker + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + +def extract_previous_hidden_states( + data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \ + Dict[str, torch.Tensor]: + """If data contains previous_hidden_states, extract it. This returns a dict + which can be used directly as additional kwargs in any following + execute_model calls. This is used in draft models like EAGLE.""" + output = {} + + # When called from non-driver worker, data is dict but when called from + # driver worker, data is ExecuteModelRequest. 
+ if isinstance(data, dict): + if "previous_hidden_states" in data: + output["previous_hidden_states"] = data["previous_hidden_states"] + elif data.previous_hidden_states is not None: + output["previous_hidden_states"] = data.previous_hidden_states\ + .hidden_states + + return output diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py new file mode 100644 index 0000000..5ff4626 --- /dev/null +++ b/vllm/worker/xpu_model_runner.py @@ -0,0 +1,605 @@ +import dataclasses +import time +import weakref +from dataclasses import dataclass +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, TypeVar) + +import torch +import torch.nn as nn + +from vllm.attention import get_attn_backend +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig) +from vllm.distributed import get_pp_group +from vllm.inputs import INPUT_REGISTRY, InputRegistry +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadataCache +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader import get_model +from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, + MultiModalInputs, MultiModalRegistry) +from vllm.sampling_params import SamplingParams +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata +from vllm.utils import DeviceMemoryProfiler, make_tensor_with_pad +from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata +from vllm.worker.model_runner_base import ( + ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, + _add_attn_metadata_broadcastable_dict, + _add_sampling_metadata_broadcastable_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + +_PAD_SLOT_ID = -1 +_BATCH_SIZE_ALIGNMENT = 8 +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) +] + +TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU") + + +@dataclass(frozen=True) +class ModelInputForXPU(ModelRunnerInputBase): + """ + Used by the NeuronModelRunner. + """ + input_tokens: Optional[torch.Tensor] = None + input_positions: Optional[torch.Tensor] = None + attn_metadata: Optional["AttentionMetadata"] = None + multi_modal_kwargs: Optional[BatchedTensorInputs] = None + virtual_engine: Optional[int] = None + seq_lens: Optional[List[int]] = None + query_lens: Optional[List[int]] = None + async_callback: Optional[Callable] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls: Type[TModelInputForXPU], + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> TModelInputForXPU: + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +@dataclass(frozen=True) +class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU): + """ + Used by the ModelRunner. 
+ """ + sampling_metadata: Optional["SamplingMetadata"] = None + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + tensor_dict = { + "input_tokens": self.input_tokens, + "input_positions": self.input_positions, + } + _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) + _add_sampling_metadata_broadcastable_dict(tensor_dict, + self.sampling_metadata) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "ModelInputForXPUWithSamplingMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + return cls(**tensor_dict) + + +class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]): + + def __init__(self, + runner: "XPUModelRunner", + finished_requests_ids: Optional[List[str]] = None) -> None: + super().__init__() + self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] + self.runner = runner + self.model_input_cls = self.runner._model_input_cls + self.attn_backend = self.runner.attn_backend + self.sliding_window = self.runner.sliding_window + self.block_size = self.runner.block_size + self.device = self.runner.device + + def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): + self.seq_group_metadata_list.append(seq_group_metadata) + + def build(self) -> ModelInputForXPU: + is_prompt = self.seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) = self._prepare_prompt( + self.seq_group_metadata_list) + else: + (input_tokens, input_positions, + attn_metadata) = self._prepare_decode( + self.seq_group_metadata_list) + seq_lens = None + multi_modal_kwargs = None + + return self.model_input_cls( + input_tokens=input_tokens, + input_positions=input_positions, + attn_metadata=attn_metadata, + multi_modal_kwargs=multi_modal_kwargs, + seq_lens=seq_lens, + query_lens=seq_lens, + ) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], + BatchedTensorInputs]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + multi_modal_inputs_list: List[MultiModalInputs] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + computed_len = seq_data.get_num_computed_tokens() + seq_len = len(prompt_tokens) + + seq_lens.append(seq_len) # Prompt token num + input_tokens.extend(prompt_tokens) # Token ids + + # Token position ids + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.extend(list(range(computed_len, seq_len))) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.extend([_PAD_SLOT_ID] * seq_len) + continue + + # Compute the slot mapping. 
+ block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, seq_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + start_idx = max(0, seq_len - self.sliding_window) + + for i in range(computed_len, seq_len): + if i < start_idx: + slot_mapping.append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // + self.block_size] # type: ignore + block_offset = i % self.block_size # type: ignore + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + num_prompt_tokens = len(input_tokens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) # type: ignore + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) # type: ignore + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) # type: ignore + + max_seqlen = max(seq_lens) + tmp = [0] + tmp.extend(seq_lens) + seqlen = torch.tensor(tmp) + seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=True, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seqlen_q=seqlen_q, + max_seqlen=max_seqlen, + seq_lens_tensor=torch.tensor([]), + max_decode_seq_len=0, + num_prefills=len(seq_lens), + num_prefill_tokens=num_prompt_tokens, + num_decode_tokens=0, + block_tables=torch.tensor([], device=self.device, dtype=torch.int), + ) + + multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list) + + return (input_tokens, input_positions, attn_metadata, seq_lens, + multi_modal_kwargs) + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + seq_lens: List[int] = [] + block_tables: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + assert seq_group_metadata.token_chunk_size == 1 + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append(generation_token) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append(position) + + seq_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + seq_lens.append(seq_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append(slot) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + max_decode_seq_len = max(seq_lens) + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + input_positions = torch.tensor(input_positions, + dtype=torch.long, + device=self.device) + slot_mapping = torch.tensor(slot_mapping, + dtype=torch.long, + device=self.device) + 
seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.int, + device=self.device) + + block_tables = make_tensor_with_pad( + block_tables, + pad=0, + dtype=torch.int, + device=self.device, + ) + + attn_metadata = self.attn_backend.make_metadata( + is_prompt=False, + slot_mapping=slot_mapping, + seq_lens=seq_lens, + seqlen_q=torch.tensor([]), + max_seqlen=0, + seq_lens_tensor=seq_lens_tensor, + max_decode_seq_len=max_decode_seq_len, + num_prefill_tokens=0, + num_decode_tokens=len(input_tokens), + num_prefills=0, + block_tables=block_tables, + ) + return ( + input_tokens, + input_positions, + attn_metadata, + ) + + +class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]): + _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = ( + ModelInputForXPUWithSamplingMetadata) + _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + return_hidden_states: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + input_registry: InputRegistry = INPUT_REGISTRY, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.lora_config = lora_config + self.load_config = load_config + self.is_driver_worker = is_driver_worker + self.prompt_adapter_config = prompt_adapter_config + self.observability_config = observability_config + if self.observability_config is not None: + print(f"observability_config is {self.observability_config}") + self.return_hidden_states = return_hidden_states + + self.device = self.device_config.device + + self.kv_cache_dtype = kv_cache_dtype + self.sliding_window = model_config.get_sliding_window() + self.block_size = cache_config.block_size + + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.get_sliding_window(), + self.model_config.dtype, + self.kv_cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Multi-modal data support + self.input_registry = input_registry + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry \ + .create_input_mapper(model_config) + self.mm_registry.init_mm_limits_per_prompt(self.model_config) + + # Lazy initialization. 
+ self.model: nn.Module # Set after init_Model + + self.sampling_metadata_cache: SamplingMetadataCache = \ + SamplingMetadataCache() \ + if self.parallel_config.pipeline_parallel_size == 1 else None + + def load_model(self) -> None: + with DeviceMemoryProfiler() as m: + self.model = get_model( + model_config=self.model_config, + device_config=self.device_config, + load_config=self.load_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + cache_config=self.cache_config, + ) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + # Additional GPU memory may be needed for multi-modal encoding, which + # needs to be accounted for when calculating the GPU blocks for + # vLLM blocker manager. + # To exercise the worst scenario for GPU memory consumption, + # the number of seqs (batch_size) is chosen to maximize the number + # of images processed. + max_mm_tokens = self.mm_registry.get_max_multimodal_tokens( + self.model_config) + if max_mm_tokens > 0: + max_num_seqs_orig = max_num_seqs + max_num_seqs = min(max_num_seqs, + max_num_batched_tokens // max_mm_tokens) + if max_num_seqs < 1: + expr = (f"min({max_num_seqs_orig}, " + f"{max_num_batched_tokens} // {max_mm_tokens})") + logger.warning( + "Computed max_num_seqs (%s) to be less than 1. " + "Setting it to the minimum value of 1.", expr) + max_num_seqs = 1 + + batch_size = 0 + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len + + seq_data, dummy_multi_modal_data = self.input_registry \ + .dummy_data_for_profiling(self.model_config, + seq_len, + self.mm_registry) + + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=None, + multi_modal_data=dummy_multi_modal_data, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. + num_layers = self.model_config.get_num_layers(self.parallel_config) + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value ``None``. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). 
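profile_run above sizes its dummy batch so that max_num_seqs sequences together contain exactly max_num_batched_tokens tokens, spreading the remainder over the first few sequences. The per-sequence length expression can be checked in isolation (the numbers below are made up for illustration):

def dummy_seq_lens(max_num_batched_tokens, max_num_seqs):
    # group_id < remainder evaluates to 1 for the first `remainder` sequences,
    # which each receive one extra token.
    return [
        max_num_batched_tokens // max_num_seqs
        + (group_id < max_num_batched_tokens % max_num_seqs)
        for group_id in range(max_num_seqs)
    ]


lens = dummy_seq_lens(max_num_batched_tokens=2048, max_num_seqs=5)
# [410, 410, 410, 409, 409]; the first (2048 % 5) == 3 sequences get the extra token.
assert sum(lens) == 2048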
+ kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + ] * num_layers + finished_requests_ids = [seq.request_id for seq in seqs] + model_input = self.prepare_model_input( + seqs, finished_requests_ids=finished_requests_ids) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) + torch.xpu.synchronize() + return + + def make_model_input_from_broadcasted_tensor_dict( + self, + tensor_dict: Dict[str, + Any]) -> ModelInputForXPUWithSamplingMetadata: + return ( + ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + )) + + def _prepare_model_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Helper method to prepare the model input based on a given sequence + group. Prepares metadata needed for the base model forward pass but not + metadata for possible additional steps, e.g., sampling. + + """ + builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) + for seq_group_metadata in seq_group_metadata_list: + builder.add_seq_group(seq_group_metadata) + + return builder.build() # type: ignore + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> ModelInputForXPUWithSamplingMetadata: + """Prepare the model input based on a given sequence group, including + metadata for the sampling step. + + """ + model_input = self._prepare_model_input_tensors( + seq_group_metadata_list, finished_requests_ids) + # Sampling metadata is only required for the final pp group + generators = self.get_generators(finished_requests_ids) + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + model_input.seq_lens, + model_input.query_lens, + self.device, + pin_memory=False, + generators=generators, + cache=self.sampling_metadata_cache) + + return dataclasses.replace(model_input, + sampling_metadata=sampling_metadata, + virtual_engine=virtual_engine) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForXPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[List[SamplerOutput]]: + if num_steps > 1: + raise ValueError( + "XPUModelRunner does not support multi-step execution.") + + model_executable = self.model + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_start_time = time.time() + + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, + device=self.device)) + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time): + model_forward_end_time = time.time() + + # Compute the logits. 
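prepare_model_input above builds the tensor-only model input first and then attaches sampling metadata with dataclasses.replace, which copies a (possibly frozen) dataclass while overriding selected fields. A toy version with a made-up dataclass, shown only to illustrate the pattern:

import dataclasses
from typing import List, Optional


@dataclasses.dataclass(frozen=True)
class ToyModelInput:
    input_tokens: List[int]
    sampling_metadata: Optional[dict] = None
    virtual_engine: int = 0


base = ToyModelInput(input_tokens=[1, 2, 3])
# `base` stays untouched; `final` shares its fields plus the new metadata.
final = dataclasses.replace(base, sampling_metadata={"top_k": 50}, virtual_engine=0)
assert base.sampling_metadata is None and final.sampling_metadata == {"top_k": 50}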
+ logits = self.model.compute_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + # Only perform sampling in the driver worker. + if not self.is_driver_worker: + return [] + + if model_input.async_callback is not None: + model_input.async_callback() + + # Sample the next token. + output: SamplerOutput = self.model.sample( + logits=logits, + sampling_metadata=model_input.sampling_metadata, + ) + if (self.observability_config is not None + and self.observability_config.collect_model_forward_time + and output is not None): + model_forward_time = (model_forward_end_time - + model_forward_start_time) + # If there are multiple workers, we are still tracking the latency + # from the start time of the driver worker to the end time of the + # driver worker. The model forward time will then end up covering + # the communication time as well. + output.model_forward_time = model_forward_time + + return [output] diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py new file mode 100644 index 0000000..9ad070d --- /dev/null +++ b/vllm/worker/xpu_worker.py @@ -0,0 +1,206 @@ +"""A XPU worker class.""" +import gc +import os +from typing import List, Optional, Tuple + +import intel_extension_for_pytorch # noqa: F401 +import oneccl_bindings_for_pytorch # noqa: F401 +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, + ModelConfig, ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) +from vllm.distributed.parallel_state import get_pp_group +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.utils import is_xpu +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.worker import Worker +from vllm.worker.worker_base import LoraNotSupportedWorkerBase +from vllm.worker.xpu_model_runner import XPUModelRunner + +logger = init_logger(__name__) + + +class XPUWorker(LoraNotSupportedWorkerBase, Worker): + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single XPU device. The worker is + responsible for maintaining the KV cache and executing the model on the + XPU. In case of distributed inference, each worker is assigned a partition + of the model. 
+ """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + load_config: LoadConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, + is_driver_worker: bool = False, + observability_config: Optional[ObservabilityConfig] = None, + ) -> None: + assert device_config.device_type == "xpu" + assert is_xpu() + + self.model_config = model_config + self.parallel_config = parallel_config + self.parallel_config.rank = rank + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.load_config = load_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config + self.is_driver_worker = is_driver_worker + self.observability_config = observability_config + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." + + self.model_runner = XPUModelRunner( # type: ignore + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=self.load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + observability_config=self.observability_config, + ) + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CacheEngine] + self.gpu_cache: Optional[List[List[torch.Tensor]]] + + def init_device(self) -> None: + if self.device_config.device.type == "xpu" and is_xpu(): + self.device = torch.device(f"xpu:{self.local_rank}") + torch.xpu.set_device(self.device) + torch.xpu.empty_cache() + self.init_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + self.init_worker_distributed_environment() + # Initialize the model. + set_random_seed(self.model_config.seed) + + # keep this method for `empty_cache` and `synchronize` api + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.xpu.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. 
+ torch.xpu.synchronize() + used_memory = torch.xpu.memory_allocated() + total_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + free_gpu_memory = total_gpu_memory - used_memory + + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + peak_memory = self.init_gpu_memory - free_gpu_memory + assert peak_memory > 0, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + cache_block_size = self.get_cache_block_size_bytes() + num_gpu_blocks = int( + (total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) // cache_block_size) + num_cpu_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + gc.collect() + torch.xpu.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def _warm_up_model(self) -> None: + # IPEX don't support capture graph yet + pass + + def init_worker_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + parallel_config = self.parallel_config + rank = self.rank + distributed_init_method = self.distributed_init_method + + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch " + "world size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + # use sockets as default Level zero IPC exchange backend. By + # default oneccl will use `drmfd` as mechanism which need extra + # dependency (libdrm and drm headers) on your system. + ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", + "sockets") + ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", + str(parallel_config.world_size)) + os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE + os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE + os.environ["LOCAL_RANK"] = str(self.local_rank) + init_distributed_environment( + world_size=parallel_config.world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=self.local_rank, + backend="ccl") + + ensure_model_parallel_initialized( + parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) + + if parallel_config.pipeline_parallel_size > 1: + # torch-ccl xpu need a collective API warm up + # before calling send/recv API + get_pp_group().all_reduce(torch.zeros(1).xpu())
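The cache sizing in determine_num_available_blocks above reduces to simple arithmetic: of the memory budget total_gpu_memory * gpu_memory_utilization, whatever the profiled peak did not consume is carved into fixed-size KV-cache blocks, and the configured swap space is carved the same way into CPU blocks. A worked example with purely illustrative numbers (a 16 GiB device, 90% utilization, a 10 GiB profiled peak, 4 GiB swap, 2 MiB per cache block):

GiB = 1 << 30
total_gpu_memory = 16 * GiB
gpu_memory_utilization = 0.90
peak_memory = 10 * GiB              # measured by the profiling forward pass
swap_space_bytes = 4 * GiB
cache_block_size = 2 * (1 << 20)    # bytes needed by one KV-cache block

num_gpu_blocks = max(
    int((total_gpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size), 0)
num_cpu_blocks = max(int(swap_space_bytes // cache_block_size), 0)
# num_gpu_blocks == 2252 (about 4.4 GiB of headroom at 2 MiB per block),
# num_cpu_blocks == 2048.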