diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 0000000..44359c0
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,366 @@
+cmake_minimum_required(VERSION 3.21)
+
+project(vllm_extensions LANGUAGES CXX)
+
+set(VLLM_TARGET_DEVICE "musa" CACHE STRING "Target device backend for vLLM")
+
+message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
+message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
+
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
+#
+# Supported python versions. These versions will be searched in order, the
+# first match will be selected. These should be kept in sync with setup.py.
+#
+set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
+
+# Supported NVIDIA architectures.
+# set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
+
+# Supported MUSA architectures.
+set(MUSA_SUPPORTED_ARCHS "220")
+
+# Supported AMD GPU architectures.
+# set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
+
+#
+# Supported/expected torch versions for CUDA/ROCm.
+#
+# Currently, having an incorrect pytorch version results in a warning
+# rather than an error.
+#
+# Note: the CUDA torch version is derived from pyproject.toml and various
+# requirements.txt files and should be kept consistent. The ROCm torch
+# versions are derived from Dockerfile.rocm
+#
+set(TORCH_SUPPORTED_VERSION_CUDA "2.2.0")
+set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
+set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
+
+#
+# Try to find python package with an executable that exactly matches
+# `VLLM_PYTHON_EXECUTABLE` and is one of the supported versions.
+#
+if (VLLM_PYTHON_EXECUTABLE)
+ find_python_from_executable(${VLLM_PYTHON_EXECUTABLE} "${PYTHON_SUPPORTED_VERSIONS}")
+else()
+ message(FATAL_ERROR
+ "Please set VLLM_PYTHON_EXECUTABLE to the path of the desired python version"
+ " before running cmake configure.")
+endif()
+
+#
+# Update cmake's `CMAKE_PREFIX_PATH` with torch location.
+#
+append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
+
+include(/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/cmake/utils.cmake)
+
+add_definitions(-DTORCH_MUSA_ARCH=220)
+set(MUSA_CSRCS)
+set(CMAKE_MODULE_PATH /opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/cmake/modules)
+set(DEPENDENT_LIBRARIES "")
+set(DEPENDENT_INCLUDE_DIRS "")
+find_package(MUDNN)
+
+if(MUDNN_FOUND)
+ list(APPEND DEPENDENT_INCLUDE_DIRS ${MUDNN_INCLUDE_DIRS})
+ list(APPEND DEPENDENT_LIBRARIES ${MUDNN_LIBRARIES})
+else()
+  message(WARNING " The environment variable MUSA_HOME may not be specified."
+  " Using default MUDNN PATH: /usr/local/musa")
+
+ list(APPEND DEPENDENT_INCLUDE_DIRS "/usr/local/musa/include")
+ list(APPEND DEPENDENT_LIBRARIES "/usr/local/musa/lib/libmudnn.so")
+ set(MUDNN_PATH "/usr/local/musa")
+ set(MUDNN_LIBRARIES "/usr/local/musa/lib/libmudnn.so")
+endif()
+
+find_package(MUSAToolkits)
+
+if(MUSAToolkits_FOUND)
+ list(APPEND DEPENDENT_INCLUDE_DIRS ${MUSAToolkits_INCLUDE_DIRS})
+ list(APPEND DEPENDENT_LIBRARIES ${MUSAToolkits_LIBRARIES})
+else()
+  message(WARNING " The environment variable MUSA_HOME may not be specified."
+  " Using default MUSATOOLKITS PATH: /usr/local/musa")
+
+ list(APPEND DEPENDENT_INCLUDE_DIRS "/usr/local/musa/include/")
+ list(APPEND DEPENDENT_LIBRARIES "/usr/local/musa/lib/libmusart.so")
+ set(ENV{MUSA_HOME} "/usr/local/musa")
+ set(MUSATOOLKITS_PATH "/usr/local/musa")
+ set(MUSAToolkits_LIBRARIES "/usr/local/musa/lib/")
+endif()
+
+if(DEFINED PYTHON_INCLUDE_DIR)
+ include_directories(${PYTHON_INCLUDE_DIR})
+else()
+  message(FATAL_ERROR "Cannot find installed Python header file directory")
+endif()
+
+list(APPEND CMAKE_MODULE_PATH $ENV{MUSA_HOME}/cmake)
+find_package(MUSA REQUIRED)
+
+#
+# Import torch cmake configuration.
+# Torch also imports CUDA (and partially HIP) languages with some customizations,
+# so there is no need to do this explicitly with check_language/enable_language,
+# etc.
+#
+find_package(Torch REQUIRED)
+
+#
+# Normally `torch.utils.cpp_extension.CUDAExtension` would add
+# `libtorch_python.so` for linking against an extension. Torch's cmake
+# configuration does not include this library (presumably since the cmake
+# config is used for standalone C++ binaries that link against torch).
+# The `libtorch_python.so` library defines some of the glue code between
+# torch/python via pybind and is required by VLLM extensions for this
+# reason. So, add it manually with `find_library` using torch's
+# installed library path.
+#
+find_library(torch_python_LIBRARY torch_python PATHS
+ "${TORCH_INSTALL_PREFIX}/lib")
+
+#
+# Forward the non-CUDA device extensions to external CMake scripts.
+#
+if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
+ NOT VLLM_TARGET_DEVICE STREQUAL "musa" AND
+ NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
+ if (VLLM_TARGET_DEVICE STREQUAL "cpu")
+ include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
+ else()
+ message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
+ endif()
+ return()
+endif()
+
+#
+# Set up GPU language and check the torch version and warn if it isn't
+# what is expected.
+#
+if (NOT HIP_FOUND AND MUSA_FOUND)
+ set(VLLM_GPU_LANG "MUSA")
+
+ if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA})
+ message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
+      "expected for MUSA build, saw ${Torch_VERSION} instead.")
+ endif()
+elseif(HIP_FOUND)
+ set(VLLM_GPU_LANG "HIP")
+
+ # Importing torch recognizes and sets up some HIP/ROCm configuration but does
+ # not let cmake recognize .hip files. In order to get cmake to understand the
+ # .hip extension automatically, HIP must be enabled explicitly.
+ enable_language(HIP)
+
+ # ROCm 5.x
+ if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
+ NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
+ message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
+      "expected for ROCm 5.x build, saw ${Torch_VERSION} instead.")
+ endif()
+
+ # ROCm 6.x
+ if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
+ NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
+ message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
+      "expected for ROCm 6.x build, saw ${Torch_VERSION} instead.")
+ endif()
+else()
+  message(FATAL_ERROR "Can't find MUSA or HIP installation.")
+endif()
+
+#
+# Override the GPU architectures detected by cmake/torch and filter them by
+# the supported versions for the current language.
+# The final set of arches is stored in `VLLM_GPU_ARCHES`.
+#
+# override_gpu_arches(VLLM_GPU_ARCHES
+# ${VLLM_GPU_LANG}
+# "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
+
+#
+# Query torch for additional GPU compilation flags for the given
+# `VLLM_GPU_LANG`.
+# The final set of arches is stored in `VLLM_GPU_FLAGS`.
+#
+get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})
+
+#
+# Set nvcc parallelism.
+#
+if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
+ list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
+endif()
+
+#
+# Define extension targets
+#
+
+#
+# _C extension
+#
+
+set(VLLM_EXT_SRC
+ "csrc_musa/cache_kernels.mu"
+ "csrc_musa/attention/attention_kernels.mu"
+ "csrc_musa/pos_encoding_kernels.mu"
+ "csrc_musa/activation_kernels.mu"
+ "csrc_musa/layernorm_kernels.mu"
+ "csrc_musa/quantization/squeezellm/quant_cuda_kernel.mu"
+ "csrc_musa/quantization/gptq/q_gemm.mu"
+ "csrc_musa/quantization/fp8/fp8_cuda_kernels.mu"
+ "csrc_musa/musa_utils_kernels.mu"
+ "csrc_musa/moe_align_block_size_kernels.mu"
+ "csrc_musa/pybind.cpp")
+
+if(VLLM_GPU_LANG STREQUAL "MUSA")
+ list(APPEND VLLM_EXT_SRC
+ "csrc_musa/quantization/aqlm/gemm_kernels.mu"
+ "csrc_musa/quantization/awq/gemm_kernels.mu"
+ "csrc_musa/quantization/marlin/marlin_cuda_kernel.mu"
+ "csrc_musa/quantization/gptq_marlin/gptq_marlin.mu"
+ "csrc_musa/quantization/gptq_marlin/gptq_marlin_repack.mu"
+ "csrc_musa/custom_all_reduce.mu")
+endif()
+
+string(APPEND MUSA_MCC_FLAGS
+
+)
+string(APPEND MUSA_MCC_FLAGS " -U__CUDA__")
+
+set(MUSA_VERBOSE_BUILD ON)
+
+
+musa_include_directories(
+/opt/conda/envs/py39/include/python3.9
+/usr/local/musa/include
+/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/generated_cuda_compatible/aten/src
+/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/generated_cuda_compatible/include
+/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/generated_cuda_compatible/include/torch/csrc/api/include
+/opt/conda/envs/py39/lib/python3.9/site-packages
+/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa
+)
+
+musa_add_library(vllm_C SHARED ${VLLM_EXT_SRC})
+set(INSTALL_BIN_DIR "bin")
+set(INSTALL_LIB_DIR "lib64")
+set(INSTALL_INC_DIR "include")
+set(INSTALL_SHARE_DIR "share")
+set(INSTALL_DOC_DIR "docs")
+
+define_gpu_extension_target(
+ vllm_C
+ DESTINATION vllm
+ LANGUAGE ${VLLM_GPU_LANG}
+ SOURCES ${VLLM_EXT_SRC}
+ COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+ ARCHITECTURES ${VLLM_GPU_ARCHES}
+ WITH_SOABI)
+
+target_link_libraries(vllm_C ${DEPENDENT_LIBRARIES})
+target_link_libraries(vllm_C "/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/lib/libmusa_python.so")
+#
+# _moe_C extension
+#
+
+set(VLLM_MOE_EXT_SRC
+ "csrc_musa/moe/moe_ops.cpp"
+ "csrc_musa/moe/topk_softmax_kernels.mu")
+
+define_gpu_extension_target(
+ _moe_C
+ DESTINATION vllm
+ LANGUAGE ${VLLM_GPU_LANG}
+ SOURCES ${VLLM_MOE_EXT_SRC}
+ COMPILE_FLAGS ${VLLM_GPU_FLAGS}
+ ARCHITECTURES ${VLLM_GPU_ARCHES}
+ WITH_SOABI)
+
+#
+# _punica_C extension
+#
+
+set(VLLM_PUNICA_EXT_SRC
+ "csrc_musa/punica/bgmv/bgmv_bf16_bf16_bf16.mu"
+ "csrc_musa/punica/bgmv/bgmv_bf16_fp32_bf16.mu"
+ "csrc_musa/punica/bgmv/bgmv_fp16_fp16_fp16.mu"
+ "csrc_musa/punica/bgmv/bgmv_fp16_fp32_fp16.mu"
+ "csrc_musa/punica/bgmv/bgmv_fp32_bf16_bf16.mu"
+ "csrc_musa/punica/bgmv/bgmv_fp32_fp16_fp16.mu"
+ "csrc_musa/punica/punica_ops.cc")
+
+#
+# Copy GPU compilation flags+update for punica
+#
+set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
+list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
+ "-D__MUSA_NO_HALF_OPERATORS__"
+ "-D__MUSA_NO_HALF_CONVERSIONS__"
+ "-D__MUSA_NO_BFLOAT16_CONVERSIONS__"
+ "-D__MUSA_NO_HALF2_OPERATORS__")
+
+#
+# Filter out CUDA architectures < 8.0 for punica.
+#
+# if (${VLLM_GPU_LANG} STREQUAL "CUDA")
+# set(VLLM_PUNICA_GPU_ARCHES)
+# foreach(ARCH ${VLLM_GPU_ARCHES})
+# string_to_ver(CODE_VER ${ARCH})
+# if (CODE_VER GREATER_EQUAL 8.0)
+# list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
+# endif()
+# endforeach()
+# message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
+# endif()
+
+if (VLLM_PUNICA_GPU_ARCHES)
+ define_gpu_extension_target(
+ _punica_C
+ DESTINATION vllm
+ LANGUAGE ${VLLM_GPU_LANG}
+ SOURCES ${VLLM_PUNICA_EXT_SRC}
+ COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
+ ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
+ WITH_SOABI)
+else()
+ message(WARNING "Unable to create _punica_C target because none of the "
+ "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
+endif()
+
+#
+# Add the `default` target which detects which extensions should be
+# built based on platform/architecture. This is the same logic that
+# setup.py uses to select which extensions should be built and should
+# be kept in sync.
+#
+# The `default` target makes direct use of cmake easier since knowledge
+# of which extensions are supported has been factored in, e.g.
+#
+# mkdir build && cd build
+# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
+# cmake --build . --target default
+#
+add_custom_target(default)
+
+if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "MUSA" OR VLLM_GPU_LANG STREQUAL "HIP")
+ message(STATUS "Enabling C extension.")
+  add_dependencies(default vllm_C)
+endif()
+
+if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "MUSA")
+ message(STATUS "Enabling moe extension.")
+ add_dependencies(default _moe_C)
+
+ # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
+ # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
+ # there are supported target arches.
+ if (VLLM_PUNICA_GPU_ARCHES AND
+ (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
+ message(STATUS "Enabling punica extension.")
+ add_dependencies(default _punica_C)
+ endif()
+endif()
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..81a8db2
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,56 @@
+# Contributing to vLLM
+
+Thank you for your interest in contributing to vLLM!
+Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
+There are several ways you can contribute to the project:
+
+- Identify and report any issues or bugs.
+- Request or add a new model.
+- Suggest or implement new features.
+
+However, remember that contributions aren't just about code.
+We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
+
+Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
+Talk about it in your blog posts, highlighting how it's driving your incredible projects.
+Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
+
+
+## Setup for development
+
+### Build from source
+
+```bash
+pip install -e . # This may take several minutes.
+```
+
+### Testing
+
+```bash
+pip install -r requirements-dev.txt
+
+# linting and formatting
+bash format.sh
+# Static type checking
+mypy
+# Unit tests
+pytest tests/
+```
+**Note:** Currently, the repository does not pass the mypy tests.
+
+
+## Contributing Guidelines
+
+### Issue Reporting
+
+If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
+If not, please file a new issue, providing as much relevant information as possible.
+
+### Pull Requests & Code Reviews
+
+Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for a detailed guide on contributing.
+
+### Thank You
+
+Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
+Your contributions make vLLM a great tool for everyone!
diff --git a/Dockerfile b/Dockerfile
index 1c0a8e4..90be3a3 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1 +1,163 @@
-FROM registry.mthreads.com/mcconline/vllm-musa-qy2-py310:v0.8.4-release
\ No newline at end of file
+# The vLLM Dockerfile is used to construct vLLM image that can be directly used
+# to run the OpenAI compatible server.
+
+# Please update any changes made here to
+# docs/source/dev/dockerfile/dockerfile.rst and
+# docs/source/assets/dev/dockerfile-stages-dependency.png
+
+#################### BASE BUILD IMAGE ####################
+# prepare basic build environment
+FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
+
+RUN apt-get update -y \
+ && apt-get install -y python3-pip git
+
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.4/compat/
+
+WORKDIR /workspace
+
+# install build and runtime dependencies
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cuda.txt requirements-cuda.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install -r requirements-cuda.txt
+
+# install development dependencies
+COPY requirements-dev.txt requirements-dev.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install -r requirements-dev.txt
+
+# cuda arch list used by torch
+# can be useful for both `dev` and `test`
+# explicitly set the list to avoid issues with torch 2.2
+# see https://github.com/pytorch/pytorch/pull/123243
+ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
+ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
+#################### BASE BUILD IMAGE ####################
+
+
+#################### WHEEL BUILD IMAGE ####################
+FROM dev AS build
+
+# install build dependencies
+COPY requirements-build.txt requirements-build.txt
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install -r requirements-build.txt
+
+# install compiler cache to speed up compilation leveraging local or remote caching
+RUN apt-get update -y && apt-get install -y ccache
+
+# files and directories related to build wheels
+COPY csrc csrc
+COPY setup.py setup.py
+COPY cmake cmake
+COPY CMakeLists.txt CMakeLists.txt
+COPY requirements-common.txt requirements-common.txt
+COPY requirements-cuda.txt requirements-cuda.txt
+COPY pyproject.toml pyproject.toml
+COPY vllm vllm
+
+# max jobs used by Ninja to build extensions
+ARG max_jobs=2
+ENV MAX_JOBS=${max_jobs}
+# number of threads used by nvcc
+ARG nvcc_threads=8
+ENV NVCC_THREADS=$nvcc_threads
+# make sure punica kernels are built (for LoRA)
+ENV VLLM_INSTALL_PUNICA_KERNELS=1
+
+ENV CCACHE_DIR=/root/.cache/ccache
+RUN --mount=type=cache,target=/root/.cache/ccache \
+ --mount=type=cache,target=/root/.cache/pip \
+ python3 setup.py bdist_wheel --dist-dir=dist
+
+# check the size of the wheel, we cannot upload wheels larger than 100MB
+COPY .buildkite/check-wheel-size.py check-wheel-size.py
+RUN python3 check-wheel-size.py dist
+
+# the `vllm_nccl` package must be installed from source distribution
+# pip is too smart to store a wheel in the cache, and other CI jobs
+# will directly use the wheel from the cache, which is not what we want.
+# we need to remove it manually
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip cache remove vllm_nccl*
+#################### EXTENSION Build IMAGE ####################
+
+#################### FLASH_ATTENTION Build IMAGE ####################
+FROM dev as flash-attn-builder
+# max jobs used for build
+ARG max_jobs=2
+ENV MAX_JOBS=${max_jobs}
+# flash attention version
+ARG flash_attn_version=v2.5.8
+ENV FLASH_ATTN_VERSION=${flash_attn_version}
+
+WORKDIR /usr/src/flash-attention-v2
+
+# Download the wheel or build it if a pre-compiled release doesn't exist
+RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
+ --no-build-isolation --no-deps --no-cache-dir
+
+#################### FLASH_ATTENTION Build IMAGE ####################
+
+#################### vLLM installation IMAGE ####################
+# image with vLLM installed
+FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
+WORKDIR /vllm-workspace
+
+RUN apt-get update -y \
+ && apt-get install -y python3-pip git vim
+
+# Workaround for https://github.com/openai/triton/issues/2507 and
+# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
+# this won't be needed for future versions of this docker image
+# or future versions of triton.
+RUN ldconfig /usr/local/cuda-12.4/compat/
+
+# install vllm wheel first, so that torch etc will be installed
+RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
+ --mount=type=cache,target=/root/.cache/pip \
+ pip install dist/*.whl --verbose
+
+RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
+ --mount=type=cache,target=/root/.cache/pip \
+ pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
+#################### vLLM installation IMAGE ####################
+
+
+#################### TEST IMAGE ####################
+# image to run unit testing suite
+# note that this uses vllm installed by `pip`
+FROM vllm-base AS test
+
+ADD . /vllm-workspace/
+
+# install development dependencies (for testing)
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install -r requirements-dev.txt
+
+# doc requires source code
+# we hide them inside `test_docs/` , so that this source code
+# will not be imported by other tests
+RUN mkdir test_docs
+RUN mv docs test_docs/
+RUN mv vllm test_docs/
+
+#################### TEST IMAGE ####################
+
+#################### OPENAI API SERVER ####################
+# openai api server alternative
+FROM vllm-base AS vllm-openai
+
+# install additional dependencies for openai api server
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install accelerate hf_transfer modelscope
+
+ENV VLLM_USAGE_SOURCE production-docker-image
+
+ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
+#################### OPENAI API SERVER ####################
diff --git a/Dockerfile.cpu b/Dockerfile.cpu
new file mode 100644
index 0000000..4251fdd
--- /dev/null
+++ b/Dockerfile.cpu
@@ -0,0 +1,20 @@
+# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
+
+FROM ubuntu:22.04
+
+RUN apt-get update -y \
+ && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
+ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
+
+RUN pip install --upgrade pip \
+ && pip install wheel packaging ninja setuptools>=49.4.0 numpy
+
+COPY ./ /workspace/vllm
+
+WORKDIR /workspace/vllm
+
+RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
+
+RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
+
+CMD ["/bin/bash"]
diff --git a/Dockerfile.neuron b/Dockerfile.neuron
new file mode 100644
index 0000000..fe42b4e
--- /dev/null
+++ b/Dockerfile.neuron
@@ -0,0 +1,36 @@
+# default base image
+ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
+
+FROM $BASE_IMAGE
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/app
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
+RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
+RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
+
+COPY ./vllm /app/vllm/vllm
+COPY ./setup.py /app/vllm/setup.py
+COPY ./requirements-common.txt /app/vllm/requirements-common.txt
+COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
+
+RUN cd /app/vllm \
+ && python3 -m pip install -U -r requirements-neuron.txt
+
+ENV VLLM_BUILD_WITH_NEURON 1
+RUN cd /app/vllm \
+ && pip install -e . \
+ && cd ..
+
+CMD ["/bin/bash"]
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
new file mode 100644
index 0000000..d04bb99
--- /dev/null
+++ b/Dockerfile.rocm
@@ -0,0 +1,107 @@
+# default base image
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+FROM $BASE_IMAGE
+
+ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+RUN echo "Base image is $BASE_IMAGE"
+
+# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
+# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
+
+
+ARG FA_GFX_ARCHS="gfx90a;gfx942"
+RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
+
+ARG FA_BRANCH="ae7928c"
+RUN echo "FA_BRANCH is $FA_BRANCH"
+
+# whether to build flash-attention
+# if 0, will not build flash attention
+# this is useful for gfx target where flash-attention is not supported
+# In that case, we need to use the python reference attention implementation in vllm
+ARG BUILD_FA="1"
+
+# whether to build triton on rocm
+ARG BUILD_TRITON="1"
+
+# Install some basic utilities
+RUN apt-get update && apt-get install python3 python3-pip -y
+
+# Install some basic utilities
+RUN apt-get update && apt-get install -y \
+ curl \
+ ca-certificates \
+ sudo \
+ git \
+ bzip2 \
+ libx11-6 \
+ build-essential \
+ wget \
+ unzip \
+ nvidia-cuda-toolkit \
+ tmux \
+ && rm -rf /var/lib/apt/lists/*
+
+### Mount Point ###
+# When launching the container, mount the code directory to /app
+ARG APP_MOUNT=/vllm-workspace
+VOLUME [ ${APP_MOUNT} ]
+WORKDIR ${APP_MOUNT}
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
+
+ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
+ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
+ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
+ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
+
+# Install ROCm flash-attention
+RUN if [ "$BUILD_FA" = "1" ]; then \
+ mkdir libs \
+ && cd libs \
+ && git clone https://github.com/ROCm/flash-attention.git \
+ && cd flash-attention \
+ && git checkout ${FA_BRANCH} \
+ && git submodule update --init \
+ && export GPU_ARCHS=${FA_GFX_ARCHS} \
+ && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
+ patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
+ && python3 setup.py install \
+ && cd ..; \
+ fi
+
+# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
+# Manually removed it so that later steps of numpy upgrade can continue
+RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
+ rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
+
+# build triton
+RUN if [ "$BUILD_TRITON" = "1" ]; then \
+ mkdir -p libs \
+ && cd libs \
+ && pip uninstall -y triton \
+ && git clone https://github.com/ROCm/triton.git \
+ && cd triton/python \
+ && pip3 install . \
+ && cd ../..; \
+ fi
+
+WORKDIR /vllm-workspace
+COPY . .
+
+RUN python3 -m pip install --upgrade pip numba
+
+RUN --mount=type=cache,target=/root/.cache/pip \
+ pip install -U -r requirements-rocm.txt \
+ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
+ && python3 setup.py install \
+ && cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
+ && cd ..
+
+RUN python3 -m pip install --upgrade pip
+RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
+
+CMD ["/bin/bash"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..da695f2
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,846 @@
+The vllm_musa from Moore Threads is licensed under the Apache License 2.0 listed below.
+Copyright (c) 2022-2024 Moore Threads Technology Co., Ltd ("Moore Threads"). All rights reserved.
+Terms of the Apache License 2.0
+-------------------------------------------------------------------------
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+-------------------------------------------------------------------------
+The following copyright statements and licenses apply to various open source software/model
+packages (or portions thereof) that are distributed with this vllm_musa. A distribution of
+vllm_musa that includes this file does not necessarily use all the open source software
+packages referred to below, and may use only portions of a given package. Some of the
+packages referred to below may have been modified by Moore Threads Technology Co., Ltd.
+
+-------------------------------------------------------------------------
+vllm
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------------------------------------------------------------------------------
+Contains code from https://github.com/punica-ai/punica
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------------------------------------------------------------------------------
+
+This product bundles various third-party components under other open source licenses.
+This section summarizes those components and their licenses. See licenses/
+for text of these licenses.
+
+
+Apache-2.0:
+* third_party/nvbench (with LLVM exception)
+* third_party/flashinfer
+
+BSD-3-Clause:
+* third_party/cutlass
+
+------------------------------------------------------------------------------------
+Contains code from https://github.com/IST-DASLab/marlin
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "{}"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright {yyyy} {name of copyright owner}
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+------------------------------------------------------------------------------------
+
+This product bundles various third-party components under other open source licenses.
+This section summarizes those components and their licenses. See licenses/
+for text of these licenses.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..82be639
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,10 @@
+include LICENSE
+include requirements-common.txt
+include requirements-cuda.txt
+include requirements-rocm.txt
+include requirements-neuron.txt
+include requirements-cpu.txt
+include CMakeLists.txt
+
+recursive-include cmake *
+recursive-include csrc *
diff --git a/README.md b/README.md
index 8a67ab1..ab02cb4 100644
--- a/README.md
+++ b/README.md
@@ -8,4 +8,127 @@ vllm 版本:v0.8.4
源码地址:https://github.com/MooreThreads/vllm_musa
-原始镜像:registry.mthreads.com/mcconline/vllm-musa-qy2-py310:v0.8.4-release
\ No newline at end of file
+镜像:git.modelhub.org.cn:9443/enginex-mthreads/vllm-musa-qy2-py310:v0.8.4-release
+
+
+
+
+
+
+
+
+
+Easy, fast, and cheap LLM serving for everyone
+
+
+
+| Documentation | Blog | Paper | Discord |
+
+
+
+*Latest News* 🔥
+- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
+- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
+- [2024/01] Added ROCm 6.0 support to vLLM.
+- [2023/12] Added ROCm 5.7 support to vLLM.
+- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
+- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
+- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
+- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
+- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
+- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
+- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
+
+---
+## About
+vLLM is a fast and easy-to-use library for LLM inference and serving.
+
+vLLM is fast with:
+
+- State-of-the-art serving throughput
+- Efficient management of attention key and value memory with **PagedAttention**
+- Continuous batching of incoming requests
+- Fast model execution with CUDA/HIP graph
+- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
+- Optimized CUDA kernels
+
+vLLM is flexible and easy to use with:
+
+- Seamless integration with popular Hugging Face models
+- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
+- Tensor parallelism support for distributed inference
+- Streaming outputs
+- OpenAI-compatible API server
+- Support NVIDIA GPUs and AMD GPUs
+- (Experimental) Prefix caching support
+- (Experimental) Multi-lora support
+
+vLLM seamlessly supports many Hugging Face models, including the following architectures:
+
+- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
+- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
+- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
+- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
+- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
+- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
+- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
+- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
+- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
+- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
+- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
+- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
+- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
+- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
+- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
+- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
+- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
+- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
+- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
+- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
+- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
+- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
+- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
+- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
+- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
+- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
+- StableLM (`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
+- Starcoder2 (`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
+- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
+- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
+
+Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
+
+```bash
+pip install vllm
+```
+
+## Getting Started
+
+Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
+- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
+- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
+- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
+
+## Contributing
+
+We welcome and value any contributions and collaborations.
+Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
+
+## Citation
+
+If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
+```bibtex
+@inproceedings{kwon2023efficient,
+ title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
+ author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
+ booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
+ year={2023}
+}
+```
+
+## vllm with MUSA
+
+Please refer to [README_vllm_musa](./README_vllm_musa.md).
diff --git a/README_vllm_musa.md b/README_vllm_musa.md
new file mode 100644
index 0000000..ad6e687
--- /dev/null
+++ b/README_vllm_musa.md
@@ -0,0 +1,66 @@
+# vllm_musa
+
+摩尔线程致力于构建完善好用的国产GPU应用生态,自主研发了MUSA架构及软件平台。vllm项目是业界广泛使用的大语言模型的推理和服务引擎,使用CUDA/ROCm提供GPU加速能力。为了方便摩尔线程GPU用户使用vllm框架,我们发起vllm_musa开源项目为vllm提供MUSA加速,让用户可释放摩尔线程GPU的澎湃算力。
+
+现有的vllm代码不支持摩尔线程GPU作为后端,因此我们新增了MUSA设备后端。vllm_musa接口与官方接口一致,用户无需改动业务代码,开箱即用。
+
+MUSA的一大优势是CUDA兼容,通过musify工具,我们可以快速将官方代码porting至MUSA软件栈,用户可以根据文档自行升级vllm版本并适配MUSA软件栈。
+
+## 依赖
+
+- musa_toolkit >= dev3.0.0
+- pytorch >= v2.2.0
+- [torch_musa](https://github.com/MooreThreads/torch_musa) >= v1.3.0
+- triton >= v2.2.0
+- ray >= 2.9
+- vllm v0.4.2
+
+## 使用
+### 编译
+运行 `bash build_musa.sh`
+### 测试示例
+```
+from vllm import LLM, SamplingParams
+from transformers import AutoTokenizer, LlamaForCausalLM
+import transformers
+import time
+import torch
+import torch_musa
+
+
+model_path = "/path/to/your/model"
+
+prompts = [
+ "Hello, my name is",
+ "The president of the United States is",
+ "The capital of France is",
+ "The future of AI is",
+]
+
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+llm = LLM(model=model_path, trust_remote_code=True, device="musa")
+
+outputs = llm.generate(prompts, sampling_params)
+
+# Print the outputs.
+for output in outputs:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+```
+
+## Porting
+
+当前仓库porting自vllm v0.4.2版本。如果用户希望使用更高版本的vllm,只需要运行`musa_porting.py`将原生CUDA代码适配到MUSA代码即可。当然随着vllm的迭代可能会有些代码成为漏网之鱼,没有porting成功,用户可自行修改`musa_porting.py`文件中的文本替换规则。从而发挥MUSA强大的CUDA兼容能力。
+
+### 步骤
+1. 运行 `python musa_porting.py`
+2. 将`CMakeLists.txt`中需要编译的文件后缀从`.cu`修改为`.mu`
+3. 编译运行vllm_musa
+
+## 贡献
+
+欢迎广大用户及开发者使用、反馈,助力vllm_musa功能及性能持续完善。
+
+社区共建,期待广大开发者与我们一道,共同打造MUSA软件生态。我们将陆续推出一系列开源软件MUSA加速项目。
\ No newline at end of file
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 0000000..192d6c4
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,8 @@
+# Benchmarking vLLM
+
+## Downloading the ShareGPT dataset
+
+You can download the dataset by running:
+```bash
+wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
+```
diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
new file mode 100644
index 0000000..f9d1675
--- /dev/null
+++ b/benchmarks/backend_request_func.py
@@ -0,0 +1,389 @@
+import json
+import os
+import sys
+import time
+import traceback
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import aiohttp
+from tqdm.asyncio import tqdm
+
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
+
+
+@dataclass
+class RequestFuncInput:
+ prompt: str
+ api_url: str
+ prompt_len: int
+ output_len: int
+ model: str
+ best_of: int = 1
+ use_beam_search: bool = False
+
+
+@dataclass
+class RequestFuncOutput:
+ generated_text: str = ""
+ success: bool = False
+ latency: float = 0.0
+ ttft: float = 0.0 # Time to first token
+ itl: List[float] = field(
+ default_factory=list) # List of inter-token latencies
+ prompt_len: int = 0
+ error: str = ""
+
+
+async def async_request_tgi(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ assert api_url.endswith("generate_stream")
+
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+ assert not request_func_input.use_beam_search
+ params = {
+ "best_of": request_func_input.best_of,
+ "max_new_tokens": request_func_input.output_len,
+ "do_sample": True,
+ "temperature": 0.01, # TGI does not accept 0.0 temperature.
+ "top_p": 0.99, # TGI does not accept 1.0 top_p.
+ }
+ payload = {
+ "inputs": request_func_input.prompt,
+ "parameters": params,
+ }
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data:")
+
+ data = json.loads(chunk)
+ timestamp = time.perf_counter()
+ # First token
+ if ttft == 0.0:
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+
+ output.latency = most_recent_timestamp - st
+ output.success = True
+ output.generated_text = data["generated_text"]
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_trt_llm(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ assert api_url.endswith("generate_stream")
+
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+ assert not request_func_input.use_beam_search
+ assert request_func_input.best_of == 1
+ payload = {
+ "accumulate_tokens": True,
+ "text_input": request_func_input.prompt,
+ "temperature": 0.0,
+ "top_p": 1.0,
+ "max_tokens": request_func_input.output_len,
+ "stream": True,
+ }
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data:")
+
+ data = json.loads(chunk)
+ output.generated_text += data["text_output"]
+ timestamp = time.perf_counter()
+ # First token
+ if ttft == 0.0:
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+
+ output.latency = most_recent_timestamp - st
+ output.success = True
+
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_deepspeed_mii(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+ assert request_func_input.best_of == 1
+ assert not request_func_input.use_beam_search
+
+ payload = {
+ "prompt": request_func_input.prompt,
+ "max_tokens": request_func_input.output_len,
+ "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
+ "top_p": 1.0,
+ }
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ # NOTE: DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
+ # will use 0 as placeholder.
+ # See https://github.com/microsoft/DeepSpeed-MII/pull/311
+ output.ttft = 0
+
+ st = time.perf_counter()
+ try:
+ async with session.post(url=request_func_input.api_url,
+ json=payload) as response:
+ if response.status == 200:
+ parsed_resp = await response.json()
+ output.latency = time.perf_counter() - st
+ output.generated_text = parsed_resp["text"][0]
+ output.success = True
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_openai_completions(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ assert api_url.endswith(
+ "v1/completions"
+ ), "OpenAI Completions API URL must end with 'v1/completions'."
+
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+ assert not request_func_input.use_beam_search
+ payload = {
+ "model": request_func_input.model,
+ "prompt": request_func_input.prompt,
+ "temperature": 0.0,
+ "best_of": request_func_input.best_of,
+ "max_tokens": request_func_input.output_len,
+ "stream": True,
+ }
+ headers = {
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
+ }
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload,
+ headers=headers) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data: ")
+ if chunk == "[DONE]":
+ latency = time.perf_counter() - st
+ else:
+ data = json.loads(chunk)
+
+ if data["choices"][0]["text"]:
+ timestamp = time.perf_counter()
+ # First token
+ if ttft == 0.0:
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ # NOTE: Some completion API might have a last
+ # usage summary response without a token so we
+ # do not want to include as inter-token-latency
+ elif data.get("usage", None) is None:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ most_recent_timestamp = timestamp
+ generated_text += data["choices"][0]["text"]
+
+ output.generated_text = generated_text
+ output.success = True
+ output.latency = latency
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+async def async_request_openai_chat_completions(
+ request_func_input: RequestFuncInput,
+ pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+ api_url = request_func_input.api_url
+ assert api_url.endswith(
+ "v1/chat/completions"
+ ), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
+
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+ assert not request_func_input.use_beam_search
+ payload = {
+ "model": request_func_input.model,
+ "messages": [
+ {
+ "role": "user",
+ "content": request_func_input.prompt,
+ },
+ ],
+ "temperature": 0.0,
+ "max_tokens": request_func_input.output_len,
+ "stream": True,
+ }
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
+ }
+
+ output = RequestFuncOutput()
+ output.prompt_len = request_func_input.prompt_len
+
+ generated_text = ""
+ ttft = 0.0
+ st = time.perf_counter()
+ most_recent_timestamp = st
+ try:
+ async with session.post(url=api_url, json=payload,
+ headers=headers) as response:
+ if response.status == 200:
+ async for chunk_bytes in response.content:
+ chunk_bytes = chunk_bytes.strip()
+ if not chunk_bytes:
+ continue
+
+ chunk = remove_prefix(chunk_bytes.decode("utf-8"),
+ "data: ")
+ if chunk == "[DONE]":
+ latency = time.perf_counter() - st
+ else:
+ timestamp = time.perf_counter()
+ data = json.loads(chunk)
+
+ delta = data["choices"][0]["delta"]
+ if delta.get("content", None):
+ # First token
+ if ttft == 0.0:
+ ttft = time.perf_counter() - st
+ output.ttft = ttft
+
+ # Decoding phase
+ else:
+ output.itl.append(timestamp -
+ most_recent_timestamp)
+
+ generated_text += delta["content"]
+
+ most_recent_timestamp = timestamp
+
+ output.generated_text = generated_text
+ output.success = True
+ output.latency = latency
+ else:
+ output.error = response.reason or ""
+ output.success = False
+ except Exception:
+ output.success = False
+ exc_info = sys.exc_info()
+ output.error = "".join(traceback.format_exception(*exc_info))
+
+ if pbar:
+ pbar.update(1)
+ return output
+
+
+# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
+# introduced in Python 3.9
+def remove_prefix(text: str, prefix: str) -> str:
+ if text.startswith(prefix):
+ return text[len(prefix):]
+ return text
+
+
+ASYNC_REQUEST_FUNCS = {
+ "tgi": async_request_tgi,
+ "vllm": async_request_openai_completions,
+ "lmdeploy": async_request_openai_completions,
+ "deepspeed-mii": async_request_deepspeed_mii,
+ "openai": async_request_openai_completions,
+ "openai-chat": async_request_openai_chat_completions,
+ "tensorrt-llm": async_request_trt_llm,
+}
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
new file mode 100644
index 0000000..44da3ba
--- /dev/null
+++ b/benchmarks/benchmark_latency.py
@@ -0,0 +1,195 @@
+"""Benchmark the latency of processing a single batch of requests."""
+import argparse
+import time
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
+
+def main(args: argparse.Namespace):
+    """Measure end-to-end latency of generating one batch of dummy requests.
+
+    Builds an ``LLM`` from the CLI args, creates ``args.batch_size`` random
+    prompts of ``args.input_len`` token ids each, then either profiles a
+    single run (``--profile``) or times ``args.num_iters`` runs and prints
+    the mean and percentile latencies.
+    """
+    print(args)
+
+    # NOTE(woosuk): If the request cannot be processed in a single batch,
+    # the engine will automatically process the request in multiple batches.
+    llm = LLM(model=args.model,
+              tokenizer=args.tokenizer,
+              quantization=args.quantization,
+              tensor_parallel_size=args.tensor_parallel_size,
+              trust_remote_code=args.trust_remote_code,
+              dtype=args.dtype,
+              enforce_eager=args.enforce_eager,
+              kv_cache_dtype=args.kv_cache_dtype,
+              quantization_param_path=args.quantization_param_path,
+              device=args.device,
+              ray_workers_use_nsight=args.ray_workers_use_nsight,
+              enable_chunked_prefill=args.enable_chunked_prefill,
+              download_dir=args.download_dir,
+              block_size=args.block_size)
+
+    sampling_params = SamplingParams(
+        n=args.n,
+        temperature=0.0 if args.use_beam_search else 1.0,
+        top_p=1.0,
+        use_beam_search=args.use_beam_search,
+        ignore_eos=True,  # force exactly args.output_len tokens per request
+        max_tokens=args.output_len,
+    )
+    print(sampling_params)
+    # Random token ids in [0, 10000) stand in for real prompts.
+    dummy_prompt_token_ids = np.random.randint(10000,
+                                               size=(args.batch_size,
+                                                     args.input_len))
+    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
+
+    def run_to_completion(profile_dir: Optional[str] = None):
+        # One full generate() call; returns its latency in seconds, or None
+        # when profiling (the profiler trace is the output in that case).
+        if profile_dir:
+            with torch.profiler.profile(
+                    activities=[
+                        torch.profiler.ProfilerActivity.CPU,
+                        torch.profiler.ProfilerActivity.CUDA,
+                    ],
+                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
+                        str(profile_dir))) as p:
+                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                             sampling_params=sampling_params,
+                             use_tqdm=False)
+            print(p.key_averages())
+        else:
+            start_time = time.perf_counter()
+            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                         sampling_params=sampling_params,
+                         use_tqdm=False)
+            end_time = time.perf_counter()
+            latency = end_time - start_time
+            return latency
+
+    print("Warming up...")
+    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
+        run_to_completion(profile_dir=None)
+
+    if args.profile:
+        profile_dir = args.profile_result_dir
+        if not profile_dir:
+            # Default trace location under the current working directory.
+            profile_dir = Path(
+                "."
+            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
+        print(f"Profiling (results will be saved to '{profile_dir}')...")
+        run_to_completion(profile_dir=profile_dir)
+        return
+
+    # Benchmark.
+    latencies = []
+    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
+        latencies.append(run_to_completion(profile_dir=None))
+    latencies = np.array(latencies)
+    percentages = [10, 25, 50, 75, 90]
+    percentiles = np.percentile(latencies, percentages)
+    print(f'Avg latency: {np.mean(latencies)} seconds')
+    for percentage, percentile in zip(percentages, percentiles):
+        print(f'{percentage}% percentile latency: {percentile} seconds')
+
+
+if __name__ == '__main__':
+    # Command-line interface for the single-batch latency benchmark; the
+    # parsed namespace is forwarded unchanged to main().
+    parser = argparse.ArgumentParser(
+        description='Benchmark the latency of processing a single batch of '
+        'requests till completion.')
+    parser.add_argument('--model', type=str, default='facebook/opt-125m')
+    parser.add_argument('--tokenizer', type=str, default=None)
+    parser.add_argument('--quantization',
+                        '-q',
+                        choices=[*QUANTIZATION_METHODS, None],
+                        default=None)
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+    parser.add_argument('--input-len', type=int, default=32)
+    parser.add_argument('--output-len', type=int, default=128)
+    parser.add_argument('--batch-size', type=int, default=8)
+    parser.add_argument('--n',
+                        type=int,
+                        default=1,
+                        help='Number of generated sequences per prompt.')
+    parser.add_argument('--use-beam-search', action='store_true')
+    parser.add_argument('--num-iters-warmup',
+                        type=int,
+                        default=10,
+                        help='Number of iterations to run for warmup.')
+    parser.add_argument('--num-iters',
+                        type=int,
+                        default=30,
+                        help='Number of iterations to run.')
+    parser.add_argument('--trust-remote-code',
+                        action='store_true',
+                        help='trust remote code from huggingface')
+    parser.add_argument(
+        '--dtype',
+        type=str,
+        default='auto',
+        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
+        help='data type for model weights and activations. '
+        'The "auto" option will use FP16 precision '
+        'for FP32 and FP16 models, and BF16 precision '
+        'for BF16 models.')
+    parser.add_argument('--enforce-eager',
+                        action='store_true',
+                        help='enforce eager mode and disable CUDA graph')
+    parser.add_argument(
+        "--kv-cache-dtype",
+        type=str,
+        choices=['auto', 'fp8'],
+        default='auto',
+        help=
+        'Data type for kv cache storage. If "auto", will use model data type. '
+        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
+        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
+        'common inference criteria.')
+    parser.add_argument(
+        '--quantization-param-path',
+        type=str,
+        default=None,
+        help='Path to the JSON file containing the KV cache scaling factors. '
+        'This should generally be supplied, when KV cache dtype is FP8. '
+        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
+        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
+        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
+        'instead supported for common inference criteria.')
+    parser.add_argument(
+        '--profile',
+        action='store_true',
+        help='profile the generation process of a single batch')
+    parser.add_argument(
+        '--profile-result-dir',
+        type=str,
+        default=None,
+        help=('path to save the pytorch profiler output. Can be visualized '
+              'with ui.perfetto.dev or Tensorboard.'))
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda", "cpu"],
+        help='device type for vLLM execution, supporting CUDA and CPU.')
+    parser.add_argument('--block-size',
+                        type=int,
+                        default=16,
+                        help='block size of key/value cache')
+    parser.add_argument(
+        '--enable-chunked-prefill',
+        action='store_true',
+        help='If True, the prefill requests can be chunked based on the '
+        'max_num_batched_tokens')
+    parser.add_argument(
+        "--ray-workers-use-nsight",
+        action='store_true',
+        help="If specified, use nsight to profile ray workers",
+    )
+    parser.add_argument('--download-dir',
+                        type=str,
+                        default=None,
+                        help='directory to download and load the weights, '
+                        'default to the default cache dir of huggingface')
+    args = parser.parse_args()
+    main(args)
diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
new file mode 100644
index 0000000..0899669
--- /dev/null
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -0,0 +1,62 @@
+import argparse
+import time
+
+from vllm import LLM, SamplingParams
+
+PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
+
+
+def test_prefix(llm=None, sampling_params=None, prompts=None):
+ start_time = time.time()
+
+ llm.generate(prompts, sampling_params=sampling_params)
+
+ end_time = time.time()
+ print(f"cost time {end_time - start_time}")
+
+
+def main(args):
+    """Run the prefix-caching benchmark: one warm-up pass, one timed pass.
+
+    All 100 prompts are identical, so every request shares its entire
+    prefix; with --enable-prefix-caching set, the second pass can reuse
+    cached prefix state populated by the first.
+    """
+    llm = LLM(model=args.model,
+              tokenizer_mode='auto',
+              trust_remote_code=True,
+              enforce_eager=True,
+              use_v2_block_manager=args.use_v2_block_manager,
+              tensor_parallel_size=args.tensor_parallel_size,
+              enable_prefix_caching=args.enable_prefix_caching)
+
+    num_prompts = 100
+    prompts = [PROMPT] * num_prompts
+    # temperature=0 makes generation deterministic across the two passes.
+    sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
+
+    print("------warm up------")
+    test_prefix(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+    print("------start generating------")
+    test_prefix(
+        llm=llm,
+        prompts=prompts,
+        sampling_params=sampling_params,
+    )
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description='Benchmark the performance with or without automatic '
+ 'prefix caching.')
+ parser.add_argument('--model',
+ type=str,
+ default='baichuan-inc/Baichuan2-13B-Chat')
+ parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
+ parser.add_argument('--output-len', type=int, default=10)
+ parser.add_argument('--enable-prefix-caching',
+ action='store_true',
+ help='enable prefix caching')
+ parser.add_argument('--use-v2-block-manager',
+ action='store_true',
+ help='Use BlockSpaceMangerV2')
+ args = parser.parse_args()
+ main(args)
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
new file mode 100644
index 0000000..2c2d69d
--- /dev/null
+++ b/benchmarks/benchmark_serving.py
@@ -0,0 +1,596 @@
+"""Benchmark online serving throughput.
+
+On the server side, run one of the following commands:
+    (vLLM OpenAI API server backend)
+    python -m vllm.entrypoints.openai.api_server \
+        --model <your_model> --swap-space 16 \
+        --disable-log-requests
+
+    (TGI backend)
+    ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
+
+On the client side, run:
+    python benchmarks/benchmark_serving.py \
+        --backend <backend> \
+        --model <your_model> \
+        --dataset-name sharegpt \
+        --dataset-path <path to dataset> \
+        --request-rate <request_rate> \  # By default is inf
+        --num-prompts <num_prompts>  # By default is 1000
+"""
+import argparse
+import asyncio
+import json
+import os
+import random
+import time
+import warnings
+from dataclasses import dataclass
+from datetime import datetime
+from typing import AsyncGenerator, List, Optional, Tuple
+
+import numpy as np
+from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
+ RequestFuncOutput)
+from tqdm.asyncio import tqdm
+from transformers import PreTrainedTokenizerBase
+
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+@dataclass
+class BenchmarkMetrics:
+    """Aggregated results of one serving benchmark run (see calculate_metrics)."""
+    completed: int  # number of requests that finished successfully
+    total_input: int  # prompt tokens summed over completed requests
+    total_output: int  # generated tokens summed over completed requests
+    request_throughput: float  # completed requests per second
+    input_throughput: float  # prompt tokens per second
+    output_throughput: float  # generated tokens per second
+    # Time-to-first-token statistics, in milliseconds.
+    mean_ttft_ms: float
+    median_ttft_ms: float
+    p99_ttft_ms: float
+    # Time-per-output-token statistics (excluding the first token), in ms.
+    mean_tpot_ms: float
+    median_tpot_ms: float
+    p99_tpot_ms: float
+
+
+def sample_sharegpt_requests(
+ dataset_path: str,
+ num_requests: int,
+ tokenizer: PreTrainedTokenizerBase,
+ fixed_output_len: Optional[int] = None,
+) -> List[Tuple[str, int, int]]:
+ if fixed_output_len is not None and fixed_output_len < 4:
+ raise ValueError("output_len too small")
+
+ # Load the dataset.
+ with open(dataset_path) as f:
+ dataset = json.load(f)
+ # Filter out the conversations with less than 2 turns.
+ dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+ # Only keep the first two turns of each conversation.
+ dataset = [(data["conversations"][0]["value"],
+ data["conversations"][1]["value"]) for data in dataset]
+
+ # Shuffle the dataset.
+ random.shuffle(dataset)
+
+ # Filter out sequences that are too long or too short
+ filtered_dataset: List[Tuple[str, int, int]] = []
+ for i in range(len(dataset)):
+ if len(filtered_dataset) == num_requests:
+ break
+
+ # Tokenize the prompts and completions.
+ prompt = dataset[i][0]
+ prompt_token_ids = tokenizer(prompt).input_ids
+ completion = dataset[i][1]
+ completion_token_ids = tokenizer(completion).input_ids
+ prompt_len = len(prompt_token_ids)
+ output_len = len(completion_token_ids
+ ) if fixed_output_len is None else fixed_output_len
+ if prompt_len < 4 or output_len < 4:
+ # Prune too short sequences.
+ continue
+ if prompt_len > 1024 or prompt_len + output_len > 2048:
+ # Prune too long sequences.
+ continue
+ filtered_dataset.append((prompt, prompt_len, output_len))
+
+ return filtered_dataset
+
+
+def sample_sonnet_requests(
+    dataset_path: str,
+    num_requests: int,
+    input_len: int,
+    output_len: int,
+    prefix_len: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, str, int, int]]:
+    """Build prompts from poem lines, sharing a fixed common prefix.
+
+    Returns (prompt, prompt_formatted, prompt_len, output_len) tuples where
+    `prompt_formatted` has the tokenizer's chat template applied. Roughly
+    the first `prefix_len` tokens of every prompt are the same poem lines
+    (useful for prefix-caching experiments); the rest are sampled per
+    request so that the total lands near `input_len` tokens.
+    """
+    assert (
+        input_len > prefix_len
+    ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
+
+    # Load the dataset.
+    with open(dataset_path) as f:
+        poem_lines = f.readlines()
+
+    # Tokenize the poem lines; the average line length (in tokens) converts
+    # token budgets into line counts below.
+    poem_token_ids = tokenizer(poem_lines).input_ids
+    average_poem_len = sum(
+        len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)
+
+    # Base prefix for all requests.
+    base_prompt = "Pick as many lines as you can from these poem lines:\n"
+    base_message = [{
+        "role": "user",
+        "content": base_prompt,
+    }]
+    base_prompt_formatted = tokenizer.apply_chat_template(
+        base_message, add_generation_prompt=True, tokenize=False)
+    # Token overhead of the chat template plus the base instruction; it is
+    # subtracted from the budgets so the final prompt lengths come out close
+    # to the requested values.
+    base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
+
+    assert (
+        input_len > base_prompt_offset
+    ), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
+    num_input_lines = round(
+        (input_len - base_prompt_offset) / average_poem_len)
+
+    # First approximately `prefix_len` number of tokens in the
+    # prompt are fixed poem lines.
+    assert (
+        prefix_len > base_prompt_offset
+    ), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
+
+    num_prefix_lines = round(
+        (prefix_len - base_prompt_offset) / average_poem_len)
+    prefix_lines = poem_lines[:num_prefix_lines]
+
+    # Sample the rest of lines per request.
+    sampled_requests: List[Tuple[str, int, int]] = []
+    for _ in range(num_requests):
+        sampled_lines = "".join(
+            prefix_lines +
+            random.sample(poem_lines, num_input_lines - num_prefix_lines))
+
+        prompt = f"{base_prompt}{sampled_lines}"
+        message = [
+            {
+                "role": "user",
+                "content": prompt,
+            },
+        ]
+        prompt_formatted = tokenizer.apply_chat_template(
+            message, add_generation_prompt=True, tokenize=False)
+        # prompt_len counts the templated prompt's tokens, since that is
+        # what actually gets sent for completion backends.
+        prompt_len = len(tokenizer(prompt_formatted).input_ids)
+        sampled_requests.append(
+            (prompt, prompt_formatted, prompt_len, output_len))
+
+    return sampled_requests
+
+
+async def get_request(
+ input_requests: List[Tuple[str, int, int]],
+ request_rate: float,
+) -> AsyncGenerator[Tuple[str, int, int], None]:
+ input_requests = iter(input_requests)
+ for request in input_requests:
+ yield request
+
+ if request_rate == float("inf"):
+ # If the request rate is infinity, then we don't need to wait.
+ continue
+ # Sample the request interval from the exponential distribution.
+ interval = np.random.exponential(1.0 / request_rate)
+ # The next request will be sent after the interval.
+ await asyncio.sleep(interval)
+
+
+def calculate_metrics(
+    input_requests: List[Tuple[str, int, int]],
+    outputs: List[RequestFuncOutput],
+    dur_s: float,
+    tokenizer: PreTrainedTokenizerBase,
+) -> Tuple[BenchmarkMetrics, List[int]]:
+    """Aggregate per-request outputs into a BenchmarkMetrics summary.
+
+    Returns the metrics plus a per-request list of generated-token counts
+    (0 for failed requests), index-aligned with `outputs`.
+    """
+    actual_output_lens = []
+    total_input = 0
+    completed = 0
+    tpots = []
+    ttfts = []
+    for i in range(len(outputs)):
+        if outputs[i].success:
+            # Re-tokenize the generated text to count the output tokens.
+            output_len = len(tokenizer(outputs[i].generated_text).input_ids)
+            actual_output_lens.append(output_len)
+            total_input += input_requests[i][1]
+            if output_len > 1:
+                # TPOT excludes the first token: decode time / decode steps.
+                tpots.append(
+                    (outputs[i].latency - outputs[i].ttft) / (output_len - 1))
+            ttfts.append(outputs[i].ttft)
+            completed += 1
+        else:
+            actual_output_lens.append(0)
+
+    # NOTE(review): if no request produces more than one token, `tpots` is
+    # empty and np.mean/median/percentile return nan with a RuntimeWarning
+    # -- confirm that is acceptable for the report.
+    metrics = BenchmarkMetrics(
+        completed=completed,
+        total_input=total_input,
+        total_output=sum(actual_output_lens),
+        request_throughput=completed / dur_s,
+        input_throughput=total_input / dur_s,
+        output_throughput=sum(actual_output_lens) / dur_s,
+        mean_ttft_ms=np.mean(ttfts or 0) *
+        1000,  # ttfts is empty if streaming is not supported by backend
+        median_ttft_ms=np.median(ttfts or 0) * 1000,
+        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
+        mean_tpot_ms=np.mean(tpots) * 1000,
+        median_tpot_ms=np.median(tpots) * 1000,
+        p99_tpot_ms=np.percentile(tpots, 99) * 1000,
+    )
+
+    return metrics, actual_output_lens
+
+
+async def benchmark(
+    backend: str,
+    api_url: str,
+    model_id: str,
+    tokenizer: PreTrainedTokenizerBase,
+    input_requests: List[Tuple[str, int, int]],
+    best_of: int,
+    use_beam_search: bool,
+    request_rate: float,
+    disable_tqdm: bool,
+):
+    """Send all `input_requests` to `api_url`, paced by `request_rate`,
+    then print a summary table and return the results as a dict.
+    """
+    if backend in ASYNC_REQUEST_FUNCS:
+        request_func = ASYNC_REQUEST_FUNCS.get(backend)
+    else:
+        raise ValueError(f"Unknown backend: {backend}")
+
+    print(f"Traffic request rate: {request_rate}")
+
+    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
+
+    benchmark_start_time = time.perf_counter()
+    tasks = []
+    # Each request is launched as a task the moment get_request yields it,
+    # so pacing controls arrival times while requests run concurrently.
+    async for request in get_request(input_requests, request_rate):
+        prompt, prompt_len, output_len = request
+        request_func_input = RequestFuncInput(
+            model=model_id,
+            prompt=prompt,
+            api_url=api_url,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            best_of=best_of,
+            use_beam_search=use_beam_search,
+        )
+        tasks.append(
+            asyncio.create_task(
+                request_func(request_func_input=request_func_input,
+                             pbar=pbar)))
+    # gather preserves order, keeping outputs aligned with input_requests.
+    outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
+
+    if not disable_tqdm:
+        pbar.close()
+
+    benchmark_duration = time.perf_counter() - benchmark_start_time
+
+    metrics, actual_output_lens = calculate_metrics(
+        input_requests=input_requests,
+        outputs=outputs,
+        dur_s=benchmark_duration,
+        tokenizer=tokenizer,
+    )
+
+    # Human-readable summary table.
+    print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
+    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
+    print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
+                                    benchmark_duration))
+    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
+    print("{:<40} {:<10}".format("Total generated tokens:",
+                                 metrics.total_output))
+    print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
+                                    metrics.request_throughput))
+    print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
+                                    metrics.input_throughput))
+    print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
+                                    metrics.output_throughput))
+    print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
+    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
+    print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
+                                    metrics.median_ttft_ms))
+    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
+    print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
+                               n=50,
+                               c='-'))
+    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
+    print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
+                                    metrics.median_tpot_ms))
+    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
+    print("=" * 50)
+
+    # Machine-readable result, merged into the saved JSON by main().
+    result = {
+        "duration": benchmark_duration,
+        "completed": metrics.completed,
+        "total_input_tokens": metrics.total_input,
+        "total_output_tokens": metrics.total_output,
+        "request_throughput": metrics.request_throughput,
+        "input_throughput": metrics.input_throughput,
+        "output_throughput": metrics.output_throughput,
+        "mean_ttft_ms": metrics.mean_ttft_ms,
+        "median_ttft_ms": metrics.median_ttft_ms,
+        "p99_ttft_ms": metrics.p99_ttft_ms,
+        "mean_tpot_ms": metrics.mean_tpot_ms,
+        "median_tpot_ms": metrics.median_tpot_ms,
+        "p99_tpot_ms": metrics.p99_tpot_ms,
+        "input_lens": [output.prompt_len for output in outputs],
+        "output_lens": actual_output_lens,
+        "ttfts": [output.ttft for output in outputs],
+        "itls": [output.itl for output in outputs],
+        "generated_texts": [output.generated_text for output in outputs],
+        "errors": [output.error for output in outputs],
+    }
+    return result
+
+
+def main(args: argparse.Namespace):
+    """Entry point: build the request set from the CLI args, run the async
+    benchmark, and optionally save the results as JSON."""
+    print(args)
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
+    backend = args.backend
+    model_id = args.model
+    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
+
+    if args.base_url is not None:
+        api_url = f"{args.base_url}{args.endpoint}"
+    else:
+        api_url = f"http://{args.host}:{args.port}{args.endpoint}"
+
+    tokenizer = get_tokenizer(tokenizer_id,
+                              trust_remote_code=args.trust_remote_code)
+
+    if args.dataset is not None:
+        # Legacy path: --dataset behaves like
+        # --dataset-name sharegpt --dataset-path <path>.
+        warnings.warn(
+            "The '--dataset' argument will be deprecated in the next "
+            "release. Please use '--dataset-name' and "
+            "'--dataset-path' in the future runs.",
+            stacklevel=2)
+        input_requests = sample_sharegpt_requests(
+            dataset_path=args.dataset,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
+        )
+
+    elif args.dataset_name == "sharegpt":
+        input_requests = sample_sharegpt_requests(
+            dataset_path=args.dataset_path,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
+        )
+
+    elif args.dataset_name == "sonnet":
+        # Do not format the prompt, pass to message directly
+        if args.backend == "openai-chat":
+            # Chat backend: keep the raw prompt (drop prompt_formatted).
+            input_requests = sample_sonnet_requests(
+                dataset_path=args.dataset_path,
+                num_requests=args.num_prompts,
+                input_len=args.sonnet_input_len,
+                output_len=args.sonnet_output_len,
+                prefix_len=args.sonnet_prefix_len,
+                tokenizer=tokenizer,
+            )
+            input_requests = [(prompt, prompt_len, output_len)
+                              for prompt, prompt_formatted, prompt_len,
+                              output_len in input_requests]
+        else:
+            # Completion backends: send the chat-template-formatted prompt,
+            # so the tokenizer must provide a chat template.
+            assert (
+                tokenizer.chat_template or tokenizer.default_chat_template
+            ), "Tokenizer/model must have chat template for sonnet dataset."
+            input_requests = sample_sonnet_requests(
+                dataset_path=args.dataset_path,
+                num_requests=args.num_prompts,
+                input_len=args.sonnet_input_len,
+                output_len=args.sonnet_output_len,
+                prefix_len=args.sonnet_prefix_len,
+                tokenizer=tokenizer,
+            )
+            input_requests = [(prompt_formatted, prompt_len, output_len)
+                              for prompt, prompt_formatted, prompt_len,
+                              output_len in input_requests]
+
+    else:
+        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+
+    benchmark_result = asyncio.run(
+        benchmark(
+            backend=backend,
+            api_url=api_url,
+            model_id=model_id,
+            tokenizer=tokenizer,
+            input_requests=input_requests,
+            best_of=args.best_of,
+            use_beam_search=args.use_beam_search,
+            request_rate=args.request_rate,
+            disable_tqdm=args.disable_tqdm,
+        ))
+
+    # Save config and results to json
+    if args.save_result:
+        result_json = {}
+
+        # Setup
+        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
+        result_json["date"] = current_dt
+        result_json["backend"] = backend
+        result_json["model_id"] = model_id
+        result_json["tokenizer_id"] = tokenizer_id
+        result_json["best_of"] = args.best_of
+        result_json["use_beam_search"] = args.use_beam_search
+        result_json["num_prompts"] = args.num_prompts
+
+        # Metadata: free-form KEY=VALUE pairs recorded alongside the run.
+        if args.metadata:
+            for item in args.metadata:
+                if "=" in item:
+                    kvstring = item.split("=")
+                    result_json[kvstring[0].strip()] = kvstring[1].strip()
+                else:
+                    raise ValueError(
+                        "Invalid metadata format. Please use KEY=VALUE format."
+                    )
+
+        # Traffic
+        result_json["request_rate"] = (
+            args.request_rate if args.request_rate < float("inf") else "inf")
+
+        # Merge with benchmark result
+        result_json = {**result_json, **benchmark_result}
+
+        # Save to file
+        base_model_id = model_id.split("/")[-1]
+        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"  #noqa
+        if args.result_dir:
+            file_name = os.path.join(args.result_dir, file_name)
+        with open(file_name, "w") as outfile:
+            json.dump(result_json, outfile)
+
+
+if __name__ == "__main__":
+    # Command-line interface for the serving benchmark; the parsed
+    # namespace is forwarded unchanged to main().
+    parser = argparse.ArgumentParser(
+        description="Benchmark the online serving throughput.")
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="vllm",
+        choices=list(ASYNC_REQUEST_FUNCS.keys()),
+    )
+    parser.add_argument(
+        "--base-url",
+        type=str,
+        default=None,
+        help="Server or API base url if not using http host and port.",
+    )
+    parser.add_argument("--host", type=str, default="localhost")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument(
+        "--endpoint",
+        type=str,
+        default="/v1/completions",
+        help="API endpoint.",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default=None,
+        help="Path to the ShareGPT dataset, will be deprecated in the "
+        "next release.",
+    )
+    parser.add_argument(
+        "--dataset-name",
+        type=str,
+        default="sharegpt",
+        choices=["sharegpt", "sonnet"],
+        help="Name of the dataset to benchmark on.",
+    )
+    parser.add_argument("--dataset-path",
+                        type=str,
+                        default=None,
+                        help="Path to the dataset.")
+    parser.add_argument(
+        "--model",
+        type=str,
+        required=True,
+        help="Name of the model.",
+    )
+    parser.add_argument(
+        "--tokenizer",
+        type=str,
+        help=
+        "Name or path of the tokenizer, if not using the default tokenizer.",
+    )
+    parser.add_argument(
+        "--best-of",
+        type=int,
+        default=1,
+        help="Generates `best_of` sequences per prompt and "
+        "returns the best one.",
+    )
+    parser.add_argument("--use-beam-search", action="store_true")
+    parser.add_argument(
+        "--num-prompts",
+        type=int,
+        default=1000,
+        help="Number of prompts to process.",
+    )
+    parser.add_argument(
+        "--sharegpt-output-len",
+        type=int,
+        default=None,
+        help="Output length for each request. Overrides the output length "
+        "from the ShareGPT dataset.")
+    parser.add_argument(
+        "--sonnet-input-len",
+        type=int,
+        default=550,
+        help=
+        "Number of input tokens per request, used only for sonnet dataset.",
+    )
+    parser.add_argument(
+        "--sonnet-output-len",
+        type=int,
+        default=150,
+        help=
+        "Number of output tokens per request, used only for sonnet dataset.",
+    )
+    parser.add_argument(
+        "--sonnet-prefix-len",
+        type=int,
+        default=200,
+        help=
+        "Number of prefix tokens per request, used only for sonnet dataset.",
+    )
+    parser.add_argument(
+        "--request-rate",
+        type=float,
+        default=float("inf"),
+        help="Number of requests per second. If this is inf, "
+        "then all the requests are sent at time 0. "
+        "Otherwise, we use Poisson process to synthesize "
+        "the request arrival times.",
+    )
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument(
+        "--trust-remote-code",
+        action="store_true",
+        help="Trust remote code from huggingface",
+    )
+    parser.add_argument(
+        "--disable-tqdm",
+        action="store_true",
+        help="Specify to disable tqdm progress bar.",
+    )
+    parser.add_argument(
+        "--save-result",
+        action="store_true",
+        help="Specify to save benchmark results to a json file",
+    )
+    parser.add_argument(
+        "--metadata",
+        metavar="KEY=VALUE",
+        nargs="*",
+        help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
+        "for metadata of this run to be saved in the result JSON file "
+        "for record keeping purposes.",
+    )
+    parser.add_argument(
+        "--result-dir",
+        type=str,
+        default=None,
+        help="Specify directory to save benchmark json results."
+        "If not specified, results are saved in the current directory.",
+    )
+
+    args = parser.parse_args()
+    main(args)
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
new file mode 100644
index 0000000..695d06e
--- /dev/null
+++ b/benchmarks/benchmark_throughput.py
@@ -0,0 +1,387 @@
+"""Benchmark offline inference throughput."""
+import argparse
+import json
+import random
+import time
+from typing import List, Optional, Tuple
+
+import torch
+from tqdm import tqdm
+from transformers import (AutoModelForCausalLM, AutoTokenizer,
+ PreTrainedTokenizerBase)
+
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
+
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    """Sample up to ``num_requests`` (prompt, prompt_len, output_len) tuples
    from a ShareGPT-style JSON dataset.

    Sequences shorter than 4 tokens on either side, prompts longer than 1024
    tokens, or pairs whose combined length exceeds 2048 tokens are discarded.
    When ``fixed_output_len`` is given it replaces the completion length.
    """
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset and keep only the first two turns of every
    # conversation that has at least two turns.
    with open(dataset_path) as f:
        raw_dataset = json.load(f)
    turn_pairs = [(entry["conversations"][0]["value"],
                   entry["conversations"][1]["value"]) for entry in raw_dataset
                  if len(entry["conversations"]) >= 2]

    # Shuffle so the sampled subset is random.
    random.shuffle(turn_pairs)

    sampled: List[Tuple[str, int, int]] = []
    for prompt, completion in turn_pairs:
        if len(sampled) == num_requests:
            break

        # Tokenize both sides to measure their lengths.
        prompt_len = len(tokenizer(prompt).input_ids)
        completion_len = len(tokenizer(completion).input_ids)
        output_len = (completion_len
                      if fixed_output_len is None else fixed_output_len)

        # Prune sequences that are too short or too long.
        if prompt_len < 4 or output_len < 4:
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            continue
        sampled.append((prompt, prompt_len, output_len))

    return sampled
+
+
def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    """Generate all requests with a vLLM engine and return the wall-clock
    generation time in seconds."""
    from vllm import LLM, SamplingParams

    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
    )

    # One SamplingParams per request; temperature is 0.0 under beam search
    # and 1.0 otherwise, and EOS is ignored so exactly output_len tokens
    # are generated.
    prompts = [prompt for prompt, _, _ in requests]
    sampling_params = [
        SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        ) for _, _, output_len in requests
    ]

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, use_tqdm=True)
    return time.perf_counter() - start
+
+
def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    """Generate all requests with a plain HuggingFace fp16 model on CUDA and
    return the wall-clock time in seconds (decoding time included).

    Requests are greedily packed into batches of at most ``max_batch_size``,
    bounded so that max(prompt_len) + max(output_len) of the batch stays
    within 2048 tokens.
    """
    # Beam search is not supported by this code path.
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()

    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            # Peek at the next request: if it still fits under the 2048-token
            # budget (padded prompt + longest generation), keep accumulating.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))

        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start
+
+
def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    """Generate all requests with a DeepSpeed-MII server and return the
    wall-clock generation time in seconds; the server is torn down after."""
    from mii import client, serve

    server = serve(model, tensor_parallel=tensor_parallel_size)
    prompts = [request[0] for request in requests]

    start = time.perf_counter()
    server.generate(prompts, max_new_tokens=output_len)
    elapsed = time.perf_counter() - start

    # Shut the MII server down before returning.
    client(model).terminate_server()
    return elapsed
+
+
def main(args: argparse.Namespace):
    """Build the request set, dispatch to the selected backend, and print
    request and token throughput."""
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        synthetic_prompt = "hi" * (args.input_len - 1)
        requests = [(synthetic_prompt, args.input_len, args.output_len)
                    ] * args.num_prompts
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests=requests,
            model=args.model,
            tokenizer=args.tokenizer,
            quantization=args.quantization,
            tensor_parallel_size=args.tensor_parallel_size,
            seed=args.seed,
            n=args.n,
            use_beam_search=args.use_beam_search,
            trust_remote_code=args.trust_remote_code,
            dtype=args.dtype,
            max_model_len=args.max_model_len,
            enforce_eager=args.enforce_eager,
            kv_cache_dtype=args.kv_cache_dtype,
            quantization_param_path=args.quantization_param_path,
            device=args.device,
            enable_prefix_caching=args.enable_prefix_caching,
            enable_chunked_prefill=args.enable_chunked_prefill,
            max_num_batched_tokens=args.max_num_batched_tokens,
            gpu_memory_utilization=args.gpu_memory_utilization,
            download_dir=args.download_dir)
    elif args.backend == "hf":
        # HF path runs single-GPU only.
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests=requests,
                              model=args.model,
                              tokenizer=tokenizer,
                              n=args.n,
                              use_beam_search=args.use_beam_search,
                              max_batch_size=args.hf_max_batch_size,
                              trust_remote_code=args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests=requests,
                               model=args.model,
                               tensor_parallel_size=args.tensor_parallel_size,
                               output_len=args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")

    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1.'
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    args = parser.parse_args()
    # Derive defaults: the model doubles as the tokenizer when unspecified.
    if args.tokenizer is None:
        args.tokenizer = args.model
    # Without a dataset, both lengths must be given explicitly; with one,
    # the input length always comes from the dataset.
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    # Reject option combinations that only make sense for another backend.
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)
diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py
new file mode 100644
index 0000000..5939294
--- /dev/null
+++ b/benchmarks/kernels/benchmark_aqlm.py
@@ -0,0 +1,302 @@
+import argparse
+import os
+import sys
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.aqlm import (
+ dequantize_weight, generic_dequantize_gemm, get_int_dtype,
+ optimized_dequantize_gemm)
+
# Pin all work to the first visible GPU so the benchmark is single-device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
def torch_mult(
    input: torch.Tensor,  # [..., in_features]
    weights: torch.Tensor,
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
) -> torch.Tensor:
    """Dense-matmul baseline: plain F.linear; ``scales`` is accepted only to
    match the signature of the dequant variants and is unused."""
    return F.linear(input, weights)
+
+
def dequant_out_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize AQLM codes to dense weights, then apply scales on the
    matmul output when there is no bias, or fold them into the weights when
    a bias is present."""
    weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    if bias is None:
        # Scale after the matmul: flatten the output to 2-D, broadcast the
        # scales over all rows, multiply in place, and restore the shape.
        output = F.linear(input, weights, bias)
        orig_shape = output.shape
        flattened_output = output.view(-1, output.size(-1))
        f_scales = scales.view(-1, scales.shape[0])
        b_scales = f_scales.expand(flattened_output.shape[0], -1)
        flattened_output *= b_scales
        return flattened_output.view(orig_shape)
    else:
        # With a bias, scaling must happen before the bias is added, so the
        # scales are folded into the weight rows instead.
        b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
            -1, weights.shape[1])
        weights *= b_scales
        return F.linear(input, weights, bias)
+
+
def dequant_weight_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize AQLM codes to dense weights, fold the scales into the
    weights, and run the linear layer."""
    dequantized = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)

    # Broadcast the per-output-group scales across all in_features columns
    # and scale the weight matrix in place before the matmul.
    row_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
        -1, dequantized.shape[1])
    dequantized *= row_scales
    return F.linear(input, dequantized, bias)
+
+
def dequant_no_scale(
    input: torch.Tensor,  # [..., in_features]
    codes: torch.IntTensor,  # [num_out_groups, num_in_groups, num_codebooks]
    codebooks: torch.
    Tensor,  # [num_codebooks, codebook_size, out_group_size, in_group_size]
    scales: torch.Tensor,  # [num_out_groups, 1, 1, 1]
    output_partition_sizes: torch.IntTensor,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    """Dequantize AQLM codes and run the linear layer with no scaling at
    all; ``scales`` is accepted only for signature compatibility."""
    return F.linear(input,
                    ops.aqlm_dequant(codes, codebooks,
                                     output_partition_sizes), bias)
+
+
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
# the generic pytorch version.
# Just visual comparison.
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
    """Print side-by-side weights from the generic PyTorch dequant and the
    CUDA ``aqlm_dequant`` kernel for eyeball comparison (no assertions)."""
    n = parts.sum().item()

    device = torch.device('cuda:0')

    # Codes are signed and span half the codebook in each direction.
    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    # Overwrite the first 16 codebook entries with recognizable values
    # (count * 10^book) so mismatches are easy to spot in the printout.
    count = 0
    for index in range(16):
        for i in range(8):
            for book in range(nbooks):
                codebooks[book, index, 0, i] = count * (10**book)
                count += 1

    print("codes shape", codes.shape)

    # Make the first (and mirrored last) codes predictable as well.
    for i in range(16):
        for book in range(nbooks):
            codes[0, i, book] = i
            codes[0, -i, book] = i

    weights = dequantize_weight(codes, codebooks, None)
    weights2 = ops.aqlm_dequant(codes, codebooks, parts)

    print("weights shape:", weights.shape)
    print("weights2 shape:", weights2.shape)

    print("weights are:", weights)
    print("weights2 are:", weights2)

    print("first 128 weights are", weights[0, 0:128].to(torch.int32))
    print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))

    print("last 128 weights are", weights[0, -128:])
    print("last 128 weights2 are:", weights2[0, -128:])
+
+
def main():
    """Entry point: either run the visual dequant tester (--test) or
    benchmark every AQLM GEMM variant over a grid of shapes, writing the
    results as CSV to ./aqlm_benchmark_{nbooks}x{bits}.csv."""
    parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")

    # Add arguments
    parser.add_argument("--nbooks",
                        type=int,
                        default=1,
                        help="Number of codebooks (default: 1)")
    parser.add_argument("--bits",
                        type=int,
                        default=16,
                        help="Number of bits per code element (default: 16)")
    # NOTE(review): type=bool means any non-empty string (even "False")
    # parses as True — confirm this is intended before relying on --test=x.
    parser.add_argument(
        "--test",
        type=bool,
        default=False,
        help="Run the decompression/dequant tester rather than benchmarking "
        "(default: False)")

    # Parse the arguments
    args = parser.parse_args()

    # Extract values
    nbooks = args.nbooks
    bits = args.bits

    if args.test:
        dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
        return

    # Otherwise, benchmark.
    methods = [
        ops.aqlm_gemm,
        dequant_out_scale,
        generic_dequantize_gemm,
        optimized_dequantize_gemm,
        dequant_weight_scale,
        torch_mult,
        dequant_no_scale,
    ]

    filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
    # BUGFIX: interpolate the output file name (the message previously
    # contained no placeholder and printed a literal).
    print(f"writing benchmarks to file {filename}")
    with open(filename, "w") as f:
        # Redirect stdout into the CSV file; progress lines below go to the
        # real terminal via sys.__stdout__.
        sys.stdout = f

        print('m | k | n | n parts', end='')
        for method in methods:
            print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
        print('')

        # These are reasonable prefill sizes.
        ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
                         (4096, (11008, 11008)), (11008, (4096, )))

        # reasonable ranges for m.
        for m in [
                1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
                128, 256, 512, 1024, 1536, 2048, 3072, 4096
        ]:
            print(f'{m}', file=sys.__stdout__)
            for ksp in ksandpartions:
                run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
                         methods)

    # Restore the real stdout.
    sys.stdout = sys.__stdout__
+
+
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
             methods):
    """Time every method on one (m, k, parts) shape and print a CSV row:
    ``m | k | n | parts`` followed by one best-latency column (µs) per
    method."""
    # I didn't see visible improvements from increasing these, but feel free :)
    num_warmup_trials = 1
    num_trials = 1

    num_calls = 100

    # warmup.
    for method in methods:
        for _ in range(num_warmup_trials):
            run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

    n = parts.sum().item()
    print(f'{m} | {k} | {n} | {parts.tolist()}', end='')

    for method in methods:
        best_time_us = 1e20
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                m=m,
                k=k,
                parts=parts,
                nbooks=nbooks,
                bits=bits,
                method=method,
            )

            kernel_dur_us = 1000 * kernel_dur_ms

            if kernel_dur_us < best_time_us:
                best_time_us = kernel_dur_us

        # BUGFIX: report the best time across trials; previously the last
        # trial's duration was printed and best_time_us was never used
        # (identical only because num_trials == 1).
        print(f' | {best_time_us:.0f}', end='')

    print('')
+
+
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
               nbooks: int, bits: int, method) -> float:
    """Run ``method`` ``num_calls`` times on random AQLM inputs and return
    the mean per-call duration in milliseconds, measured with CUDA events."""
    # Total output features across all partitions.
    n = parts.sum().item()

    device = torch.device('cuda:0')

    input = torch.randn((1, m, k), dtype=torch.float16, device=device)

    # Codes are signed: they span half the codebook in each direction.
    code_range = (1 << bits) // 2
    ingroups = 8

    codes = torch.randint(-code_range,
                          code_range,
                          size=(n, k // ingroups, nbooks),
                          dtype=get_int_dtype(bits),
                          device=device)

    codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
                            dtype=torch.float16,
                            device=device)

    scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)

    # for comparison to just a pytorch mult.
    weights = torch.randn((n, k), dtype=torch.float16, device=device)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()

    # torch_mult takes dense weights; all AQLM variants share one signature.
    if method is torch_mult:
        for i in range(num_calls):
            torch_mult(input, weights, scales)
    else:
        for i in range(num_calls):
            method(input, codes, codebooks, scales, parts, None)

    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms
+
+
if __name__ == "__main__":
    # main() returns None, so sys.exit() yields exit status 0 on success.
    sys.exit(main())
diff --git a/benchmarks/kernels/benchmark_mixtral_moe.py b/benchmarks/kernels/benchmark_mixtral_moe.py
new file mode 100644
index 0000000..5280b21
--- /dev/null
+++ b/benchmarks/kernels/benchmark_mixtral_moe.py
@@ -0,0 +1,215 @@
+import argparse
+import json
+import os
+import sys
+
+import torch
+import torch.nn.functional as F
+import triton
+from tqdm import tqdm
+
+from vllm.model_executor.layers.fused_moe import (fused_moe,
+ get_config_file_name)
+
# Pin all work to the first visible GPU so tuning runs on a single device.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+
def main(dtype: str):
    """Tune the fused_moe kernel across a sweep of batch sizes for the
    given dtype ("float8" or "float16")."""
    batch_sizes = (1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024,
                   1536, 2048, 3072, 4096)
    for batch_size in batch_sizes:
        run_grid(batch_size, method=fused_moe, dtype=dtype)
+
+
def run_grid(bs, method, dtype: str):
    """Exhaustively search Triton launch configs for ``method`` (fused_moe)
    at one batch size and merge the best config into the per-shape JSON
    config file.

    Model shape constants below correspond to a Mixtral-8x7B layer under
    tensor parallelism 2.
    """
    d_model = 4096
    num_total_experts = 8
    top_k = 2
    tp_size = 2
    model_intermediate_size = 14336
    num_layers = 32
    num_calls = 100

    num_warmup_trials = 1
    num_trials = 1

    # Cartesian product of all tunable launch parameters.
    configs = []

    for block_size_n in [32, 64, 128, 256]:
        for block_size_m in [16, 32, 64, 128, 256]:
            for block_size_k in [64, 128, 256]:
                for group_size_m in [1, 16, 32, 64]:
                    for num_warps in [4, 8]:
                        for num_stages in [2, 3, 4, 5]:
                            configs.append({
                                "BLOCK_SIZE_M": block_size_m,
                                "BLOCK_SIZE_N": block_size_n,
                                "BLOCK_SIZE_K": block_size_k,
                                "GROUP_SIZE_M": group_size_m,
                                "num_warps": num_warps,
                                "num_stages": num_stages,
                            })

    best_config = None
    best_time_us = 1e20

    print(f'{tp_size=} {bs=}')

    for config in tqdm(configs):
        # warmup
        try:
            for _ in range(num_warmup_trials):
                run_timing(
                    num_calls=num_calls,
                    bs=bs,
                    d_model=d_model,
                    num_total_experts=num_total_experts,
                    top_k=top_k,
                    tp_size=tp_size,
                    model_intermediate_size=model_intermediate_size,
                    method=method,
                    config=config,
                    dtype=dtype,
                )
        except triton.runtime.autotuner.OutOfResources:
            # Config exceeds the device's resources; skip it entirely.
            continue

        # trial
        for _ in range(num_trials):
            kernel_dur_ms = run_timing(
                num_calls=num_calls,
                bs=bs,
                d_model=d_model,
                num_total_experts=num_total_experts,
                top_k=top_k,
                tp_size=tp_size,
                model_intermediate_size=model_intermediate_size,
                method=method,
                config=config,
                dtype=dtype,
            )

            kernel_dur_us = 1000 * kernel_dur_ms
            model_dur_ms = kernel_dur_ms * num_layers

            if kernel_dur_us < best_time_us:
                best_config = config
                best_time_us = kernel_dur_us

            tqdm.write(
                f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
                f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
                f'{d_model=} {model_intermediate_size=} {num_layers=}')

    print("best_time_us", best_time_us)
    print("best_config", best_config)

    # holds Dict[str, Dict[str, int]]
    filename = get_config_file_name(num_total_experts,
                                    model_intermediate_size // tp_size,
                                    "float8" if dtype == "float8" else None)
    # BUGFIX: interpolate the config file name (the message previously
    # contained no placeholder and printed a literal).
    print(f"writing config to file {filename}")
    # Merge with any existing tuning results instead of overwriting them.
    existing_content = {}
    if os.path.exists(filename):
        with open(filename, "r") as f:
            existing_content = json.load(f)
    existing_content[str(bs)] = best_config
    with open(filename, "w") as f:
        json.dump(existing_content, f, indent=4)
        f.write("\n")
+
+
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
               top_k: int, tp_size: int, model_intermediate_size: int, method,
               config, dtype: str) -> float:
    """Run ``method`` (fused_moe) ``num_calls`` times on random MoE inputs
    with the given launch ``config`` and return the mean per-call duration
    in milliseconds, measured with CUDA events."""
    shard_intermediate_size = model_intermediate_size // tp_size

    hidden_states = torch.rand(
        (bs, d_model),
        device="cuda:0",
        dtype=torch.float16,
    )

    # w1 fuses the gate and up projections, hence the factor of 2.
    w1 = torch.rand(
        (num_total_experts, 2 * shard_intermediate_size, d_model),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w2 = torch.rand(
        (num_total_experts, d_model, shard_intermediate_size),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    w1_scale = None
    w2_scale = None
    a1_scale = None
    a2_scale = None

    # For float8, cast the weights and use unit scales for weights and
    # activations.
    if dtype == "float8":
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)
        w1_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        w2_scale = torch.ones(num_total_experts,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a1_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)
        a2_scale = torch.ones(1,
                              device=hidden_states.device,
                              dtype=torch.float32)

    # One softmax-normalized gating tensor per call.
    gating_output = F.softmax(torch.rand(
        (num_calls, bs, num_total_experts),
        device=hidden_states.device,
        dtype=torch.float32,
    ),
                              dim=-1)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for i in range(num_calls):
        hidden_states = method(
            hidden_states=hidden_states,
            w1=w1,
            w2=w2,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            a1_scale=a1_scale,
            a2_scale=a2_scale,
            gating_output=gating_output[i],
            topk=2,
            renormalize=True,
            inplace=True,
            override_config=config,
            use_fp8=dtype == "float8",
        )
    end_event.record()
    end_event.synchronize()

    dur_ms = start_event.elapsed_time(end_event) / num_calls
    return dur_ms
+
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog='benchmark_mixtral_moe',
        description='Benchmark and tune the fused_moe kernel',
    )
    parser.add_argument(
        '--dtype',
        type=str,
        # BUGFIX: the previous default 'auto' was not among the declared
        # choices (so the documented default could never be passed
        # explicitly). 'float16' preserves the effective behavior, since
        # every dtype other than 'float8' took the fp16 path.
        default='float16',
        choices=['float8', 'float16'],
        help='Data type used for fused_moe kernel computations',
    )
    args = parser.parse_args()
    # main() returns None, so the process exits with status 0 on success.
    sys.exit(main(args.dtype))
diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
new file mode 100644
index 0000000..ca7967c
--- /dev/null
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -0,0 +1,211 @@
+import argparse
+import random
+import time
+from typing import Optional
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
+
+NUM_BLOCKS = 1024
+PARTITION_SIZE = 512
+
+
@torch.inference_mode()
def main(
    version: str,
    num_seqs: int,
    seq_len: int,
    num_query_heads: int,
    num_kv_heads: int,
    head_size: int,
    use_alibi: bool,
    block_size: int,
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    """Benchmark one paged-attention kernel version ("v1" or "v2") on random
    inputs and print the mean per-iteration kernel latency in microseconds.

    With ``do_profile`` set, a single iteration is run inside a CUDA
    profiler range instead of the 100-iteration timing loop.
    """
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
    query = torch.empty(num_seqs,
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device=device)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device=device)

    # Every sequence uses the same length in this benchmark.
    seq_lens = [seq_len for _ in range(num_seqs)]
    max_seq_len = max(seq_lens)
    seq_lens = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Create the block tables with random physical block assignments.
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = []
    for _ in range(num_seqs):
        block_table = [
            random.randint(0, NUM_BLOCKS - 1)
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

    # Create the KV cache (a single layer's worth).
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel; v2 needs per-partition
    # scratch buffers for its split-sequence reduction.
    output = torch.empty_like(query)
    if version == "v2":
        num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_query_heads, num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)

    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        # Returns the mean wall-clock seconds per iteration.
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        start_time = time.perf_counter()

        # Using default kv_scale
        kv_scale = 1.0

        for _ in range(num_iters):
            if version == "v1":
                ops.paged_attention_v1(
                    output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
            elif version == "v2":
                ops.paged_attention_v2(
                    output,
                    exp_sums,
                    max_logits,
                    tmp_output,
                    query,
                    key_cache,
                    value_cache,
                    num_kv_heads,
                    scale,
                    block_tables,
                    seq_lens,
                    block_size,
                    max_seq_len,
                    alibi_slopes,
                    kv_cache_dtype,
                    kv_scale,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
        torch.cuda.synchronize()

        end_time = time.perf_counter()
        if profile:
            # BUGFIX: this previously called cudaProfilerStart() a second
            # time, so the profiler range opened above was never closed.
            torch.cuda.cudart().cudaProfilerStop()
        return (end_time - start_time) / num_iters

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
    if do_profile:
        latency = run_benchmark(num_iters=1, profile=True)
    else:
        latency = run_benchmark(num_iters=100, profile=False)
    print(f"Kernel running time: {latency * 1000000:.3f} us")
+
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the paged attention kernel.")
    parser.add_argument("--version",
                        type=str,
                        choices=["v1", "v2"],
                        default="v2")
    parser.add_argument("--batch-size", type=int, default=8)
    # NOTE(review): this flag uses an underscore while every other flag uses
    # dashes — confirm before renaming, since scripts may depend on it.
    parser.add_argument("--seq_len", type=int, default=4096)
    parser.add_argument("--num-query-heads", type=int, default=64)
    parser.add_argument("--num-kv-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--use-alibi", action="store_true")
    parser.add_argument("--dtype",
                        type=str,
                        choices=["half", "bfloat16", "float"],
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type. '
        'FP8_E5M2 (without scaling) is only supported on cuda version greater '
        'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
        'common inference criteria.')
    args = parser.parse_args()
    print(args)

    # Grouped-query attention requires the query head count to be a
    # multiple of the KV head count.
    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    main(
        version=args.version,
        num_seqs=args.batch_size,
        seq_len=args.seq_len,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )
diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py
new file mode 100644
index 0000000..9188e81
--- /dev/null
+++ b/benchmarks/kernels/benchmark_rope.py
@@ -0,0 +1,121 @@
+import argparse
+from itertools import accumulate
+from typing import Optional
+
+import nvtx
+import torch
+
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+
+def benchmark_rope_kernels_multi_lora(
+    is_neox_style: bool,
+    batch_size: int,
+    seq_len: int,
+    num_heads: int,
+    head_size: int,
+    rotary_dim: Optional[int],
+    dtype: torch.dtype,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: int = 10000,
+) -> None:
+    """Compare batched rotary embedding (one kernel call covering several
+    LoRA scaling factors) against one non-batched RoPE call per factor.
+
+    The two phases are wrapped in NVTX ranges ("non-batched" / "batched")
+    so they show up as separate regions in a profiler trace; the function
+    itself returns nothing.
+    """
+    torch.random.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    # rotary_dim defaults to the full head size when not given.
+    if rotary_dim is None:
+        rotary_dim = head_size
+    # Simulate serving 4 LoRAs, each with its own RoPE scaling factor.
+    scaling_factors = [1, 2, 4, 8]
+    # Batched RoPE can take multiple scaling factors at once.
+    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
+                            is_neox_style, {
+                                "type": "linear",
+                                "factor": tuple(scaling_factors)
+                            })
+    # Non-batched RoPE takes only one scaling factor; we create multiple
+    # instances to simulate the same behavior.
+    non_batched_ropes = []
+    for scaling_factor in scaling_factors:
+        non_batched_ropes.append(
+            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
+                     {
+                         "type": "linear",
+                         "factor": (scaling_factor, )
+                     }))
+
+    positions = torch.randint(0, max_position, (batch_size, seq_len))
+    query = torch.randn(batch_size,
+                        seq_len,
+                        num_heads * head_size,
+                        dtype=dtype)
+    key = torch.randn_like(query)
+
+    # Create query offsets for batched RoPE: we concat multiple kv caches
+    # together and each query needs to find the right kv cache of its type.
+    # Each cache spans max_position * scaling_factor * 2 slots, so offsets
+    # are the running prefix sums of those spans (first cache at offset 0).
+    offset_map = torch.tensor(
+        list(
+            accumulate([0] + [
+                max_position * scaling_factor * 2
+                for scaling_factor in scaling_factors[:-1]
+            ])))
+    query_types = torch.randint(0,
+                                len(scaling_factors), (batch_size, seq_len),
+                                device=device)
+    # Map each query's type to the offset of its kv cache.
+    query_offsets = offset_map[query_types]
+    # The kernel takes flattened offsets.
+    flatten_offsets = query_offsets.flatten()
+
+    # Batch queries of the same type together for the non-batched RoPE path.
+    queries = [query[query_types == i] for i in range(len(scaling_factors))]
+    keys = [key[query_types == i] for i in range(len(scaling_factors))]
+    packed_qkr = zip(queries, keys, non_batched_ropes)
+    # Synchronize before starting to time.
+    torch.cuda.synchronize()
+    with nvtx.annotate("non-batched", color="yellow"):
+        for q, k, r in packed_qkr:
+            r.forward(positions, q, k)
+    torch.cuda.synchronize()
+    with nvtx.annotate("batched", color="green"):
+        batched_rope.forward(positions, query, key, flatten_offsets)
+    torch.cuda.synchronize()
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="Benchmark the rotary embedding kernels.")
+ parser.add_argument("--is-neox-style", type=bool, default=True)
+ parser.add_argument("--batch-size", type=int, default=16)
+ parser.add_argument("--seq-len", type=int, default=512)
+ parser.add_argument("--num-heads", type=int, default=8)
+ parser.add_argument("--head-size",
+ type=int,
+ choices=[64, 80, 96, 112, 128, 256],
+ default=128)
+ parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
+ parser.add_argument("--dtype",
+ type=str,
+ choices=["bfloat16", "float"],
+ default="float")
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--device",
+ type=str,
+ choices=["cuda:0", "cuda:1"],
+ default="cuda:0")
+ args = parser.parse_args()
+ print(args)
+
+ benchmark_rope_kernels_multi_lora(
+ is_neox_style=args.is_neox_style,
+ batch_size=args.batch_size,
+ seq_len=args.seq_len,
+ num_heads=args.num_heads,
+ head_size=args.head_size,
+ rotary_dim=args.rotary_dim,
+ dtype=getattr(torch, args.dtype),
+ seed=args.seed,
+ device=args.device,
+ )
diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh
new file mode 100755
index 0000000..64d3c4f
--- /dev/null
+++ b/benchmarks/launch_tgi_server.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+PORT=8000
+MODEL=$1
+TOKENS=$2
+
+docker run --gpus all --shm-size 1g -p $PORT:80 \
+ -v $PWD/data:/data \
+ ghcr.io/huggingface/text-generation-inference:1.4.0 \
+ --model-id $MODEL \
+ --sharded false \
+ --max-input-length 1024 \
+ --max-total-tokens 2048 \
+ --max-best-of 5 \
+ --max-concurrent-requests 5000 \
+ --max-batch-total-tokens $TOKENS
diff --git a/benchmarks/sonnet.txt b/benchmarks/sonnet.txt
new file mode 100644
index 0000000..34c444e
--- /dev/null
+++ b/benchmarks/sonnet.txt
@@ -0,0 +1,518 @@
+FROM fairest creatures we desire increase,
+That thereby beauty's rose might never die,
+But as the riper should by time decease,
+His tender heir might bear his memory:
+But thou, contracted to thine own bright eyes,
+Feed'st thy light'st flame with self-substantial fuel,
+Making a famine where abundance lies,
+Thyself thy foe, to thy sweet self too cruel.
+Thou that art now the world's fresh ornament
+And only herald to the gaudy spring,
+Within thine own bud buriest thy content
+And, tender churl, makest waste in niggarding.
+Pity the world, or else this glutton be,
+To eat the world's due, by the grave and thee.
+When forty winters shall beseige thy brow,
+And dig deep trenches in thy beauty's field,
+Thy youth's proud livery, so gazed on now,
+Will be a tatter'd weed, of small worth held:
+Then being ask'd where all thy beauty lies,
+Where all the treasure of thy lusty days,
+To say, within thine own deep-sunken eyes,
+Were an all-eating shame and thriftless praise.
+How much more praise deserved thy beauty's use,
+If thou couldst answer 'This fair child of mine
+Shall sum my count and make my old excuse,'
+Proving his beauty by succession thine!
+This were to be new made when thou art old,
+And see thy blood warm when thou feel'st it cold.
+Look in thy glass, and tell the face thou viewest
+Now is the time that face should form another;
+Whose fresh repair if now thou not renewest,
+Thou dost beguile the world, unbless some mother.
+For where is she so fair whose unear'd womb
+Disdains the tillage of thy husbandry?
+Or who is he so fond will be the tomb
+Of his self-love, to stop posterity?
+Thou art thy mother's glass, and she in thee
+Calls back the lovely April of her prime:
+So thou through windows of thine age shall see
+Despite of wrinkles this thy golden time.
+But if thou live, remember'd not to be,
+Die single, and thine image dies with thee.
+Unthrifty loveliness, why dost thou spend
+Upon thyself thy beauty's legacy?
+Nature's bequest gives nothing but doth lend,
+And being frank she lends to those are free.
+Then, beauteous niggard, why dost thou abuse
+The bounteous largess given thee to give?
+Profitless usurer, why dost thou use
+So great a sum of sums, yet canst not live?
+For having traffic with thyself alone,
+Thou of thyself thy sweet self dost deceive.
+Then how, when nature calls thee to be gone,
+What acceptable audit canst thou leave?
+Thy unused beauty must be tomb'd with thee,
+Which, used, lives th' executor to be.
+Those hours, that with gentle work did frame
+The lovely gaze where every eye doth dwell,
+Will play the tyrants to the very same
+And that unfair which fairly doth excel:
+For never-resting time leads summer on
+To hideous winter and confounds him there;
+Sap cheque'd with frost and lusty leaves quite gone,
+Beauty o'ersnow'd and bareness every where:
+Then, were not summer's distillation left,
+A liquid prisoner pent in walls of glass,
+Beauty's effect with beauty were bereft,
+Nor it nor no remembrance what it was:
+But flowers distill'd though they with winter meet,
+Leese but their show; their substance still lives sweet.
+Then let not winter's ragged hand deface
+In thee thy summer, ere thou be distill'd:
+Make sweet some vial; treasure thou some place
+With beauty's treasure, ere it be self-kill'd.
+That use is not forbidden usury,
+Which happies those that pay the willing loan;
+That's for thyself to breed another thee,
+Or ten times happier, be it ten for one;
+Ten times thyself were happier than thou art,
+If ten of thine ten times refigured thee:
+Then what could death do, if thou shouldst depart,
+Leaving thee living in posterity?
+Be not self-will'd, for thou art much too fair
+To be death's conquest and make worms thine heir.
+Lo! in the orient when the gracious light
+Lifts up his burning head, each under eye
+Doth homage to his new-appearing sight,
+Serving with looks his sacred majesty;
+And having climb'd the steep-up heavenly hill,
+Resembling strong youth in his middle age,
+yet mortal looks adore his beauty still,
+Attending on his golden pilgrimage;
+But when from highmost pitch, with weary car,
+Like feeble age, he reeleth from the day,
+The eyes, 'fore duteous, now converted are
+From his low tract and look another way:
+So thou, thyself out-going in thy noon,
+Unlook'd on diest, unless thou get a son.
+Music to hear, why hear'st thou music sadly?
+Sweets with sweets war not, joy delights in joy.
+Why lovest thou that which thou receivest not gladly,
+Or else receivest with pleasure thine annoy?
+If the true concord of well-tuned sounds,
+By unions married, do offend thine ear,
+They do but sweetly chide thee, who confounds
+In singleness the parts that thou shouldst bear.
+Mark how one string, sweet husband to another,
+Strikes each in each by mutual ordering,
+Resembling sire and child and happy mother
+Who all in one, one pleasing note do sing:
+Whose speechless song, being many, seeming one,
+Sings this to thee: 'thou single wilt prove none.'
+Is it for fear to wet a widow's eye
+That thou consumest thyself in single life?
+Ah! if thou issueless shalt hap to die.
+The world will wail thee, like a makeless wife;
+The world will be thy widow and still weep
+That thou no form of thee hast left behind,
+When every private widow well may keep
+By children's eyes her husband's shape in mind.
+Look, what an unthrift in the world doth spend
+Shifts but his place, for still the world enjoys it;
+But beauty's waste hath in the world an end,
+And kept unused, the user so destroys it.
+No love toward others in that bosom sits
+That on himself such murderous shame commits.
+For shame! deny that thou bear'st love to any,
+Who for thyself art so unprovident.
+Grant, if thou wilt, thou art beloved of many,
+But that thou none lovest is most evident;
+For thou art so possess'd with murderous hate
+That 'gainst thyself thou stick'st not to conspire.
+Seeking that beauteous roof to ruinate
+Which to repair should be thy chief desire.
+O, change thy thought, that I may change my mind!
+Shall hate be fairer lodged than gentle love?
+Be, as thy presence is, gracious and kind,
+Or to thyself at least kind-hearted prove:
+Make thee another self, for love of me,
+That beauty still may live in thine or thee.
+As fast as thou shalt wane, so fast thou growest
+In one of thine, from that which thou departest;
+And that fresh blood which youngly thou bestowest
+Thou mayst call thine when thou from youth convertest.
+Herein lives wisdom, beauty and increase:
+Without this, folly, age and cold decay:
+If all were minded so, the times should cease
+And threescore year would make the world away.
+Let those whom Nature hath not made for store,
+Harsh featureless and rude, barrenly perish:
+Look, whom she best endow'd she gave the more;
+Which bounteous gift thou shouldst in bounty cherish:
+She carved thee for her seal, and meant thereby
+Thou shouldst print more, not let that copy die.
+When I do count the clock that tells the time,
+And see the brave day sunk in hideous night;
+When I behold the violet past prime,
+And sable curls all silver'd o'er with white;
+When lofty trees I see barren of leaves
+Which erst from heat did canopy the herd,
+And summer's green all girded up in sheaves
+Borne on the bier with white and bristly beard,
+Then of thy beauty do I question make,
+That thou among the wastes of time must go,
+Since sweets and beauties do themselves forsake
+And die as fast as they see others grow;
+And nothing 'gainst Time's scythe can make defence
+Save breed, to brave him when he takes thee hence.
+O, that you were yourself! but, love, you are
+No longer yours than you yourself here live:
+Against this coming end you should prepare,
+And your sweet semblance to some other give.
+So should that beauty which you hold in lease
+Find no determination: then you were
+Yourself again after yourself's decease,
+When your sweet issue your sweet form should bear.
+Who lets so fair a house fall to decay,
+Which husbandry in honour might uphold
+Against the stormy gusts of winter's day
+And barren rage of death's eternal cold?
+O, none but unthrifts! Dear my love, you know
+You had a father: let your son say so.
+Not from the stars do I my judgment pluck;
+And yet methinks I have astronomy,
+But not to tell of good or evil luck,
+Of plagues, of dearths, or seasons' quality;
+Nor can I fortune to brief minutes tell,
+Pointing to each his thunder, rain and wind,
+Or say with princes if it shall go well,
+By oft predict that I in heaven find:
+But from thine eyes my knowledge I derive,
+And, constant stars, in them I read such art
+As truth and beauty shall together thrive,
+If from thyself to store thou wouldst convert;
+Or else of thee this I prognosticate:
+Thy end is truth's and beauty's doom and date.
+When I consider every thing that grows
+Holds in perfection but a little moment,
+That this huge stage presenteth nought but shows
+Whereon the stars in secret influence comment;
+When I perceive that men as plants increase,
+Cheered and cheque'd even by the self-same sky,
+Vaunt in their youthful sap, at height decrease,
+And wear their brave state out of memory;
+Then the conceit of this inconstant stay
+Sets you most rich in youth before my sight,
+Where wasteful Time debateth with Decay,
+To change your day of youth to sullied night;
+And all in war with Time for love of you,
+As he takes from you, I engraft you new.
+But wherefore do not you a mightier way
+Make war upon this bloody tyrant, Time?
+And fortify yourself in your decay
+With means more blessed than my barren rhyme?
+Now stand you on the top of happy hours,
+And many maiden gardens yet unset
+With virtuous wish would bear your living flowers,
+Much liker than your painted counterfeit:
+So should the lines of life that life repair,
+Which this, Time's pencil, or my pupil pen,
+Neither in inward worth nor outward fair,
+Can make you live yourself in eyes of men.
+To give away yourself keeps yourself still,
+And you must live, drawn by your own sweet skill.
+Who will believe my verse in time to come,
+If it were fill'd with your most high deserts?
+Though yet, heaven knows, it is but as a tomb
+Which hides your life and shows not half your parts.
+If I could write the beauty of your eyes
+And in fresh numbers number all your graces,
+The age to come would say 'This poet lies:
+Such heavenly touches ne'er touch'd earthly faces.'
+So should my papers yellow'd with their age
+Be scorn'd like old men of less truth than tongue,
+And your true rights be term'd a poet's rage
+And stretched metre of an antique song:
+But were some child of yours alive that time,
+You should live twice; in it and in my rhyme.
+Shall I compare thee to a summer's day?
+Thou art more lovely and more temperate:
+Rough winds do shake the darling buds of May,
+And summer's lease hath all too short a date:
+Sometime too hot the eye of heaven shines,
+And often is his gold complexion dimm'd;
+And every fair from fair sometime declines,
+By chance or nature's changing course untrimm'd;
+But thy eternal summer shall not fade
+Nor lose possession of that fair thou owest;
+Nor shall Death brag thou wander'st in his shade,
+When in eternal lines to time thou growest:
+So long as men can breathe or eyes can see,
+So long lives this and this gives life to thee.
+Devouring Time, blunt thou the lion's paws,
+And make the earth devour her own sweet brood;
+Pluck the keen teeth from the fierce tiger's jaws,
+And burn the long-lived phoenix in her blood;
+Make glad and sorry seasons as thou fleets,
+And do whate'er thou wilt, swift-footed Time,
+To the wide world and all her fading sweets;
+But I forbid thee one most heinous crime:
+O, carve not with thy hours my love's fair brow,
+Nor draw no lines there with thine antique pen;
+Him in thy course untainted do allow
+For beauty's pattern to succeeding men.
+Yet, do thy worst, old Time: despite thy wrong,
+My love shall in my verse ever live young.
+A woman's face with Nature's own hand painted
+Hast thou, the master-mistress of my passion;
+A woman's gentle heart, but not acquainted
+With shifting change, as is false women's fashion;
+An eye more bright than theirs, less false in rolling,
+Gilding the object whereupon it gazeth;
+A man in hue, all 'hues' in his controlling,
+Much steals men's eyes and women's souls amazeth.
+And for a woman wert thou first created;
+Till Nature, as she wrought thee, fell a-doting,
+And by addition me of thee defeated,
+By adding one thing to my purpose nothing.
+But since she prick'd thee out for women's pleasure,
+Mine be thy love and thy love's use their treasure.
+So is it not with me as with that Muse
+Stirr'd by a painted beauty to his verse,
+Who heaven itself for ornament doth use
+And every fair with his fair doth rehearse
+Making a couplement of proud compare,
+With sun and moon, with earth and sea's rich gems,
+With April's first-born flowers, and all things rare
+That heaven's air in this huge rondure hems.
+O' let me, true in love, but truly write,
+And then believe me, my love is as fair
+As any mother's child, though not so bright
+As those gold candles fix'd in heaven's air:
+Let them say more than like of hearsay well;
+I will not praise that purpose not to sell.
+My glass shall not persuade me I am old,
+So long as youth and thou are of one date;
+But when in thee time's furrows I behold,
+Then look I death my days should expiate.
+For all that beauty that doth cover thee
+Is but the seemly raiment of my heart,
+Which in thy breast doth live, as thine in me:
+How can I then be elder than thou art?
+O, therefore, love, be of thyself so wary
+As I, not for myself, but for thee will;
+Bearing thy heart, which I will keep so chary
+As tender nurse her babe from faring ill.
+Presume not on thy heart when mine is slain;
+Thou gavest me thine, not to give back again.
+As an unperfect actor on the stage
+Who with his fear is put besides his part,
+Or some fierce thing replete with too much rage,
+Whose strength's abundance weakens his own heart.
+So I, for fear of trust, forget to say
+The perfect ceremony of love's rite,
+And in mine own love's strength seem to decay,
+O'ercharged with burden of mine own love's might.
+O, let my books be then the eloquence
+And dumb presagers of my speaking breast,
+Who plead for love and look for recompense
+More than that tongue that more hath more express'd.
+O, learn to read what silent love hath writ:
+To hear with eyes belongs to love's fine wit.
+Mine eye hath play'd the painter and hath stell'd
+Thy beauty's form in table of my heart;
+My body is the frame wherein 'tis held,
+And perspective it is the painter's art.
+For through the painter must you see his skill,
+To find where your true image pictured lies;
+Which in my bosom's shop is hanging still,
+That hath his windows glazed with thine eyes.
+Now see what good turns eyes for eyes have done:
+Mine eyes have drawn thy shape, and thine for me
+Are windows to my breast, where-through the sun
+Delights to peep, to gaze therein on thee;
+Yet eyes this cunning want to grace their art;
+They draw but what they see, know not the heart.
+Let those who are in favour with their stars
+Of public honour and proud titles boast,
+Whilst I, whom fortune of such triumph bars,
+Unlook'd for joy in that I honour most.
+Great princes' favourites their fair leaves spread
+But as the marigold at the sun's eye,
+And in themselves their pride lies buried,
+For at a frown they in their glory die.
+The painful warrior famoused for fight,
+After a thousand victories once foil'd,
+Is from the book of honour razed quite,
+And all the rest forgot for which he toil'd:
+Then happy I, that love and am beloved
+Where I may not remove nor be removed.
+Lord of my love, to whom in vassalage
+Thy merit hath my duty strongly knit,
+To thee I send this written embassage,
+To witness duty, not to show my wit:
+Duty so great, which wit so poor as mine
+May make seem bare, in wanting words to show it,
+But that I hope some good conceit of thine
+In thy soul's thought, all naked, will bestow it;
+Till whatsoever star that guides my moving
+Points on me graciously with fair aspect
+And puts apparel on my tatter'd loving,
+To show me worthy of thy sweet respect:
+Then may I dare to boast how I do love thee;
+Till then not show my head where thou mayst prove me.
+Weary with toil, I haste me to my bed,
+The dear repose for limbs with travel tired;
+But then begins a journey in my head,
+To work my mind, when body's work's expired:
+For then my thoughts, from far where I abide,
+Intend a zealous pilgrimage to thee,
+And keep my drooping eyelids open wide,
+Looking on darkness which the blind do see
+Save that my soul's imaginary sight
+Presents thy shadow to my sightless view,
+Which, like a jewel hung in ghastly night,
+Makes black night beauteous and her old face new.
+Lo! thus, by day my limbs, by night my mind,
+For thee and for myself no quiet find.
+How can I then return in happy plight,
+That am debarr'd the benefit of rest?
+When day's oppression is not eased by night,
+But day by night, and night by day, oppress'd?
+And each, though enemies to either's reign,
+Do in consent shake hands to torture me;
+The one by toil, the other to complain
+How far I toil, still farther off from thee.
+I tell the day, to please them thou art bright
+And dost him grace when clouds do blot the heaven:
+So flatter I the swart-complexion'd night,
+When sparkling stars twire not thou gild'st the even.
+But day doth daily draw my sorrows longer
+And night doth nightly make grief's strength seem stronger.
+When, in disgrace with fortune and men's eyes,
+I all alone beweep my outcast state
+And trouble deal heaven with my bootless cries
+And look upon myself and curse my fate,
+Wishing me like to one more rich in hope,
+Featured like him, like him with friends possess'd,
+Desiring this man's art and that man's scope,
+With what I most enjoy contented least;
+Yet in these thoughts myself almost despising,
+Haply I think on thee, and then my state,
+Like to the lark at break of day arising
+From sullen earth, sings hymns at heaven's gate;
+For thy sweet love remember'd such wealth brings
+That then I scorn to change my state with kings.
+When to the sessions of sweet silent thought
+I summon up remembrance of things past,
+I sigh the lack of many a thing I sought,
+And with old woes new wail my dear time's waste:
+Then can I drown an eye, unused to flow,
+For precious friends hid in death's dateless night,
+And weep afresh love's long since cancell'd woe,
+And moan the expense of many a vanish'd sight:
+Then can I grieve at grievances foregone,
+And heavily from woe to woe tell o'er
+The sad account of fore-bemoaned moan,
+Which I new pay as if not paid before.
+But if the while I think on thee, dear friend,
+All losses are restored and sorrows end.
+Thy bosom is endeared with all hearts,
+Which I by lacking have supposed dead,
+And there reigns love and all love's loving parts,
+And all those friends which I thought buried.
+How many a holy and obsequious tear
+Hath dear religious love stol'n from mine eye
+As interest of the dead, which now appear
+But things removed that hidden in thee lie!
+Thou art the grave where buried love doth live,
+Hung with the trophies of my lovers gone,
+Who all their parts of me to thee did give;
+That due of many now is thine alone:
+Their images I loved I view in thee,
+And thou, all they, hast all the all of me.
+If thou survive my well-contented day,
+When that churl Death my bones with dust shall cover,
+And shalt by fortune once more re-survey
+These poor rude lines of thy deceased lover,
+Compare them with the bettering of the time,
+And though they be outstripp'd by every pen,
+Reserve them for my love, not for their rhyme,
+Exceeded by the height of happier men.
+O, then vouchsafe me but this loving thought:
+'Had my friend's Muse grown with this growing age,
+A dearer birth than this his love had brought,
+To march in ranks of better equipage:
+But since he died and poets better prove,
+Theirs for their style I'll read, his for his love.'
+Full many a glorious morning have I seen
+Flatter the mountain-tops with sovereign eye,
+Kissing with golden face the meadows green,
+Gilding pale streams with heavenly alchemy;
+Anon permit the basest clouds to ride
+With ugly rack on his celestial face,
+And from the forlorn world his visage hide,
+Stealing unseen to west with this disgrace:
+Even so my sun one early morn did shine
+With all triumphant splendor on my brow;
+But out, alack! he was but one hour mine;
+The region cloud hath mask'd him from me now.
+Yet him for this my love no whit disdaineth;
+Suns of the world may stain when heaven's sun staineth.
+Why didst thou promise such a beauteous day,
+And make me travel forth without my cloak,
+To let base clouds o'ertake me in my way,
+Hiding thy bravery in their rotten smoke?
+'Tis not enough that through the cloud thou break,
+To dry the rain on my storm-beaten face,
+For no man well of such a salve can speak
+That heals the wound and cures not the disgrace:
+Nor can thy shame give physic to my grief;
+Though thou repent, yet I have still the loss:
+The offender's sorrow lends but weak relief
+To him that bears the strong offence's cross.
+Ah! but those tears are pearl which thy love sheds,
+And they are rich and ransom all ill deeds.
+No more be grieved at that which thou hast done:
+Roses have thorns, and silver fountains mud;
+Clouds and eclipses stain both moon and sun,
+And loathsome canker lives in sweetest bud.
+All men make faults, and even I in this,
+Authorizing thy trespass with compare,
+Myself corrupting, salving thy amiss,
+Excusing thy sins more than thy sins are;
+For to thy sensual fault I bring in sense--
+Thy adverse party is thy advocate--
+And 'gainst myself a lawful plea commence:
+Such civil war is in my love and hate
+That I an accessary needs must be
+To that sweet thief which sourly robs from me.
+Let me confess that we two must be twain,
+Although our undivided loves are one:
+So shall those blots that do with me remain
+Without thy help by me be borne alone.
+In our two loves there is but one respect,
+Though in our lives a separable spite,
+Which though it alter not love's sole effect,
+Yet doth it steal sweet hours from love's delight.
+I may not evermore acknowledge thee,
+Lest my bewailed guilt should do thee shame,
+Nor thou with public kindness honour me,
+Unless thou take that honour from thy name:
+But do not so; I love thee in such sort
+As, thou being mine, mine is thy good report.
+As a decrepit father takes delight
+To see his active child do deeds of youth,
+So I, made lame by fortune's dearest spite,
+Take all my comfort of thy worth and truth.
+For whether beauty, birth, or wealth, or wit,
+Or any of these all, or all, or more,
+Entitled in thy parts do crowned sit,
+I make my love engrafted to this store:
+So then I am not lame, poor, nor despised,
+Whilst that this shadow doth such substance give
+That I in thy abundance am sufficed
+And by a part of all thy glory live.
+Look, what is best, that best I wish in thee:
+This wish I have; then ten times happy me!
\ No newline at end of file
diff --git a/build_musa.sh b/build_musa.sh
new file mode 100644
index 0000000..b831c83
--- /dev/null
+++ b/build_musa.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+
+set -x
+set -e
+
+pip install -r requirements-build.txt
+pip install -r requirements-musa.txt
+
+export VLLM_TARGET_DEVICE=musa
+export CMAKE_BUILD_TYPE=Debug
+export VERBOSE=1
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+
+rm -rf build
+rm -rf dist
+rm -rf vllm.egg-info
+pip uninstall -y vllm
+
+python setup.py bdist_wheel
+pip install dist/*
\ No newline at end of file
diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake
new file mode 100644
index 0000000..0cf3776
--- /dev/null
+++ b/cmake/cpu_extension.cmake
@@ -0,0 +1,90 @@
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+#
+# Define environment variables for special configurations
+#
+# VLLM_CPU_AVX512BF16=1 forces AVX512-BF16 codegen even when the host CPU
+# does not advertise the avx512_bf16 flag (useful for cross-compilation,
+# see the warning emitted below).
+if(DEFINED ENV{VLLM_CPU_AVX512BF16})
+  set(ENABLE_AVX512BF16 ON)
+endif()
+
+include_directories("${CMAKE_SOURCE_DIR}/csrc")
+
+#
+# Check the compile flags
+#
+list(APPEND CXX_COMPILE_FLAGS
+  "-fopenmp"
+  "-DVLLM_CPU_EXTENSION")
+
+# Read the host CPU feature flags; find_isa() below searches this text.
+# NOTE(review): Linux-only — /proc/cpuinfo does not exist elsewhere.
+execute_process(COMMAND cat /proc/cpuinfo
+  RESULT_VARIABLE CPUINFO_RET
+  OUTPUT_VARIABLE CPUINFO)
+
+if (NOT CPUINFO_RET EQUAL 0)
+  message(FATAL_ERROR "Failed to check CPU features via /proc/cpuinfo")
+endif()
+
+#
+# Set OUT (in the caller's scope) to ON when the substring TARGET occurs in
+# the haystack CPUINFO, OFF otherwise.
+#
+# The expansions are quoted so the multi-line /proc/cpuinfo text (which may
+# contain semicolons after list-flattening) is passed to string(FIND) as a
+# single argument instead of being word-split.
+#
+function (find_isa CPUINFO TARGET OUT)
+  string(FIND "${CPUINFO}" "${TARGET}" ISA_FOUND)
+  if(NOT ISA_FOUND EQUAL -1)
+    set(${OUT} ON PARENT_SCOPE)
+  else()
+    set(${OUT} OFF PARENT_SCOPE)
+  endif()
+endfunction()
+
+find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
+
+if (AVX512_FOUND)
+  # Baseline AVX-512 subsets required by the CPU kernels.
+  list(APPEND CXX_COMPILE_FLAGS
+    "-mavx512f"
+    "-mavx512vl"
+    "-mavx512bw"
+    "-mavx512dq")
+
+  # AVX512-BF16 is optional: enable it when the CPU advertises it (or the
+  # user forced it via VLLM_CPU_AVX512BF16) AND the compiler is gcc >= 12.3;
+  # otherwise warn and build without it.
+  find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
+  if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
+    if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
+        CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
+      list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
+    else()
+      message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
+    endif()
+  else()
+    message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
+  endif()
+else()
+  # Plain AVX-512 is a hard requirement for the CPU backend.
+  message(FATAL_ERROR "vLLM CPU backend requires AVX512 ISA support.")
+endif()
+
+message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
+
+
+#
+# Define extension targets
+#
+
+#
+# _C extension
+#
+set(VLLM_EXT_SRC
+  "csrc/cpu/activation.cpp"
+  "csrc/cpu/attention.cpp"
+  "csrc/cpu/cache.cpp"
+  "csrc/cpu/layernorm.cpp"
+  "csrc/cpu/pos_encoding.cpp"
+  "csrc/cpu/pybind.cpp")
+
+# Build the CPU kernels as the `_C` python extension module installed into
+# the `vllm` package (helper presumably defined in cmake/utils.cmake —
+# its definition is not in this file).
+define_gpu_extension_target(
+  _C
+  DESTINATION vllm
+  LANGUAGE CXX
+  SOURCES ${VLLM_EXT_SRC}
+  COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
+  WITH_SOABI
+)
+
+# `default` is the umbrella target the build driver invokes; it just pulls
+# in the extension module.
+add_custom_target(default)
+message(STATUS "Enabling C extension.")
+add_dependencies(default _C)
+
diff --git a/cmake/hipify.py b/cmake/hipify.py
new file mode 100755
index 0000000..340e41c
--- /dev/null
+++ b/cmake/hipify.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+#
+# A command line tool for running pytorch's hipify preprocessor on CUDA
+# source files.
+#
+# See https://github.com/ROCm/hipify_torch
+# and /utils/hipify/hipify_python.py
+#
+
+import argparse
+import os
+import shutil
+
+from torch.utils.hipify.hipify_python import hipify
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ # Project directory where all the source + include files live.
+ parser.add_argument(
+ "-p",
+ "--project_dir",
+ help="The project directory.",
+ )
+
+ # Directory where hipified files are written.
+ parser.add_argument(
+ "-o",
+ "--output_dir",
+ help="The output directory.",
+ )
+
+ # Source files to convert.
+ parser.add_argument("sources",
+ help="Source files to hipify.",
+ nargs="*",
+ default=[])
+
+ args = parser.parse_args()
+
+ # Limit include scope to project_dir only
+ includes = [os.path.join(args.project_dir, '*')]
+
+ # Get absolute path for all source files.
+ extra_files = [os.path.abspath(s) for s in args.sources]
+
+ # Copy sources from project directory to output directory.
+ # The directory might already exist to hold object files so we ignore that.
+ shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)
+
+ hipify_result = hipify(project_directory=args.project_dir,
+ output_directory=args.output_dir,
+ header_include_dirs=[],
+ includes=includes,
+ extra_files=extra_files,
+ show_detailed=True,
+ is_pytorch_extension=True,
+ hipify_extra_files_only=True)
+
+ hipified_sources = []
+ for source in args.sources:
+ s_abs = os.path.abspath(source)
+ hipified_s_abs = (hipify_result[s_abs].hipified_path if
+ (s_abs in hipify_result
+ and hipify_result[s_abs].hipified_path is not None)
+ else s_abs)
+ hipified_sources.append(hipified_s_abs)
+
+ assert (len(hipified_sources) == len(args.sources))
+
+ # Print hipified source files.
+ print("\n".join(hipified_sources))
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
new file mode 100644
index 0000000..7c71673
--- /dev/null
+++ b/cmake/utils.cmake
@@ -0,0 +1,354 @@
+#
+# Attempt to find the python package that uses the same python executable as
+# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
+#
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
  # Resolve symlinks (e.g. venv shims) so torch and cmake agree on the path.
  # Fix: quote the path expansions — unquoted, a path containing spaces is
  # split into multiple arguments and file(REAL_PATH) errors out.
  file(REAL_PATH "${EXECUTABLE}" EXECUTABLE)
  set(Python_EXECUTABLE "${EXECUTABLE}")
  find_package(Python COMPONENTS Interpreter Development.Module)
  if (NOT Python_FOUND)
    message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
  endif()
  # Only major.minor matters for the supported-version check.
  set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
  set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
  if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
    message(FATAL_ERROR
      "Python version (${_VER}) is not one of the supported versions: "
      "${_SUPPORTED_VERSIONS_LIST}.")
  endif()
  message(STATUS "Found python matching: ${EXECUTABLE}.")
endmacro()
+
+#
+# Run `EXPR` in python. The standard output of python is stored in `OUT` and
+# has trailing whitespace stripped. If an error is encountered when running
+# python, a fatal message `ERR_MSG` is issued.
+#
function (run_python OUT EXPR ERR_MSG)
  # Run `EXPR` with the interpreter selected by find_python_from_executable.
  execute_process(
    COMMAND
    "${Python_EXECUTABLE}" "-c" "${EXPR}"
    OUTPUT_VARIABLE PYTHON_OUT
    RESULT_VARIABLE PYTHON_ERROR_CODE
    ERROR_VARIABLE PYTHON_STDERR
    OUTPUT_STRIP_TRAILING_WHITESPACE)

  if(NOT PYTHON_ERROR_CODE EQUAL 0)
    # Surface python's stderr so configure-time failures are actionable.
    message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
  endif()
  # Fix: quote the expansion. Unquoted, an empty stdout leaves `OUT` unset in
  # the caller's scope instead of defined-but-empty; for non-empty output the
  # quoted form is equivalent (semicolons still act as list separators).
  set(${OUT} "${PYTHON_OUT}" PARENT_SCOPE)
endfunction()
+
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
# Note: `run_python` issues a FATAL_ERROR if the import or expression fails,
# so `_PREFIX_PATH` is always defined past this call.
macro (append_cmake_prefix_path PKG EXPR)
  run_python(_PREFIX_PATH
    "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
  list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
+
+#
+# Add a target named `hipify${NAME}` that runs the hipify preprocessor on a set
+# of CUDA source files. The names of the corresponding "hipified" sources are
+# stored in `OUT_SRCS`.
+#
function (hipify_sources_target OUT_SRCS NAME ORIG_SRCS)
  #
  # Split into C++ and non-C++ (i.e. CUDA) sources.
  #
  set(SRCS ${ORIG_SRCS})
  set(CXX_SRCS ${ORIG_SRCS})
  # Fix: the original regex "\.(cc)|(cpp)$" is wrong twice over — a single
  # `\.` collapses to `.` during CMake string parsing (matching ANY character,
  # so e.g. "accel.cu" matched via "cc"), and the `$` anchor applied only to
  # the `(cpp)` branch of the alternation. `\\.(cc|cpp)$` matches a literal
  # dot followed by a cc/cpp extension at end-of-string only.
  list(FILTER SRCS EXCLUDE REGEX "\\.(cc|cpp)$")
  list(FILTER CXX_SRCS INCLUDE REGEX "\\.(cc|cpp)$")

  #
  # Generate ROCm/HIP source file names from CUDA file names.
  # Since HIP files are generated code, they will appear in the build area
  # `CMAKE_CURRENT_BINARY_DIR` directory rather than the original csrc dir.
  #
  set(HIP_SRCS)
  foreach (SRC ${SRCS})
    # Same escaping fix: match a literal ".cu" suffix only.
    string(REGEX REPLACE "\\.cu$" ".hip" SRC ${SRC})
    # Hipify also renames "cuda" path/file components to "hip".
    string(REGEX REPLACE "cuda" "hip" SRC ${SRC})
    list(APPEND HIP_SRCS "${CMAKE_CURRENT_BINARY_DIR}/${SRC}")
  endforeach()

  set(CSRC_BUILD_DIR ${CMAKE_CURRENT_BINARY_DIR}/csrc)
  # NOTE: add_custom_target re-runs hipify on every build; BYPRODUCTS lets
  # Ninja track and clean the generated .hip files.
  add_custom_target(
    hipify${NAME}
    COMMAND ${CMAKE_SOURCE_DIR}/cmake/hipify.py -p ${CMAKE_SOURCE_DIR}/csrc -o ${CSRC_BUILD_DIR} ${SRCS}
    DEPENDS ${CMAKE_SOURCE_DIR}/cmake/hipify.py ${SRCS}
    BYPRODUCTS ${HIP_SRCS}
    COMMENT "Running hipify on ${NAME} extension source files.")

  # Swap out original extension sources with hipified sources.
  list(APPEND HIP_SRCS ${CXX_SRCS})
  set(${OUT_SRCS} ${HIP_SRCS} PARENT_SCOPE)
endfunction()
+
+#
+# Get additional GPU compiler flags from torch.
+#
function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
  if (${GPU_LANG} STREQUAL "CUDA")
    #
    # Get common NVCC flags from torch.
    #
    run_python(GPU_FLAGS
      "from torch.utils.cpp_extension import COMMON_NVCC_FLAGS; print(';'.join(COMMON_NVCC_FLAGS))"
      "Failed to determine torch nvcc compiler flags")

    # FP8 E5M2 support requires CUDA 11.8+.
    if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8)
      list(APPEND GPU_FLAGS "-DENABLE_FP8_E5M2")
    endif()
    # CUDA 12+ no longer needs the half/bfloat16 operator opt-outs.
    if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0)
      list(REMOVE_ITEM GPU_FLAGS
        "-D__CUDA_NO_HALF_OPERATORS__"
        "-D__CUDA_NO_HALF_CONVERSIONS__"
        "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
        "-D__CUDA_NO_HALF2_OPERATORS__")
    endif()

  elseif(${GPU_LANG} STREQUAL "HIP")
    #
    # Get common HIP/HIPCC flags from torch.
    #
    # Fix: the error message said "nvcc" in the HIP branch (copy/paste).
    run_python(GPU_FLAGS
      "import torch.utils.cpp_extension as t; print(';'.join(t.COMMON_HIP_FLAGS + t.COMMON_HIPCC_FLAGS))"
      "Failed to determine torch hipcc compiler flags")

    list(APPEND GPU_FLAGS
      "-DUSE_ROCM"
      "-DENABLE_FP8_E4M3"
      "-U__HIP_NO_HALF_CONVERSIONS__"
      "-U__HIP_NO_HALF_OPERATORS__"
      "-fno-gpu-rdc")

  endif()
  set(${OUT_GPU_FLAGS} ${GPU_FLAGS} PARENT_SCOPE)
endfunction()
+
# Macro for converting a `gencode` version number to a cmake version number.
# The last digit becomes the minor version, e.g. "90" -> "9.0", "120" -> "12.0".
# Note: `\(`/`\)` collapse to literal parens during string parsing, so the
# effective regex is `([0-9]+)([0-9])` with two capture groups.
macro(string_to_ver OUT_VER IN_STR)
  string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
endmacro()
+
+#
+# Override the GPU architectures detected by cmake/torch and filter them by
+# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
+# `GPU_ARCHES`.
+#
+# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
+#
macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
  set(_GPU_SUPPORTED_ARCHES_LIST ${GPU_SUPPORTED_ARCHES} ${ARGN})
  message(STATUS "${GPU_LANG} supported arches: ${_GPU_SUPPORTED_ARCHES_LIST}")

  if (${GPU_LANG} STREQUAL "HIP")
    #
    # `GPU_ARCHES` controls the `--offload-arch` flags.
    # `CMAKE_HIP_ARCHITECTURES` is set up by torch and can be controlled
    # via the `PYTORCH_ROCM_ARCH` env variable.
    #

    #
    # Find the intersection of the supported + detected architectures to
    # set the module architecture flags.
    #
    set(${GPU_ARCHES})
    foreach (_ARCH ${CMAKE_HIP_ARCHITECTURES})
      if (_ARCH IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
        list(APPEND ${GPU_ARCHES} ${_ARCH})
      endif()
    endforeach()

    if(NOT ${GPU_ARCHES})
      message(FATAL_ERROR
        "None of the detected ROCm architectures: ${CMAKE_HIP_ARCHITECTURES} is"
        " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
    endif()

  elseif(${GPU_LANG} STREQUAL "CUDA")
    #
    # Setup/process CUDA arch flags.
    #
    # The torch cmake setup hardcodes the detected architecture flags in
    # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
    # can't modified on a per-target basis, e.g. for the `punica` extension.
    # So, all the `-gencode` flags need to be extracted and removed from
    # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
    # Since it's not possible to use `target_compiler_options` for adding target
    # specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
    # must be used instead. This requires repackaging the architecture flags
    # into a format that cmake expects for `CUDA_ARCHITECTURES`.
    #
    # This is a bit fragile in that it depends on torch using `-gencode` as opposed
    # to one of the other nvcc options to specify architectures.
    #
    # Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
    # detected architectures.
    #
    message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

    # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
    string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
      ${CMAKE_CUDA_FLAGS})

    # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
    # and passed back via the `CUDA_ARCHITECTURES` property.
    string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
      ${CMAKE_CUDA_FLAGS})

    # If this error is triggered, it might mean that torch has changed how it sets
    # up nvcc architecture code generation flags.
    if (NOT _CUDA_ARCH_FLAGS)
      message(FATAL_ERROR
        "Could not find any architecture related code generation flags in "
        "CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
    endif()

    message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
    message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")

    # Initialize the architecture lists to empty.
    set(${GPU_ARCHES})

    # Process each `gencode` flag.
    foreach(_ARCH ${_CUDA_ARCH_FLAGS})
      # For each flag, extract the version number and whether it refers to PTX
      # or native code.
      # Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
      # for that match.

      string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
      if (_COMPUTE)
        set(_COMPUTE ${CMAKE_MATCH_1})
      endif()

      string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
      if (_SM)
        set(_SM ${CMAKE_MATCH_1})
      endif()

      string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
      if (_CODE)
        set(_CODE ${CMAKE_MATCH_1})
      endif()

      # Make sure the virtual architecture can be matched.
      if (NOT _COMPUTE)
        message(FATAL_ERROR
          "Could not determine virtual architecture from: ${_ARCH}.")
      endif()

      # One of sm_ or compute_ must exist.
      if ((NOT _SM) AND (NOT _CODE))
        message(FATAL_ERROR
          "Could not determine a codegen architecture from: ${_ARCH}.")
      endif()

      if (_SM)
        # -real suffix let CMake to only generate elf code for the kernels.
        # we want this, otherwise the added ptx (default) will increase binary size.
        set(_VIRT "-real")
        set(_CODE_ARCH ${_SM})
      else()
        # -virtual suffix let CMake to generate ptx code for the kernels.
        set(_VIRT "-virtual")
        set(_CODE_ARCH ${_CODE})
      endif()

      # Check if the current version is in the supported arch list.
      string_to_ver(_CODE_VER ${_CODE_ARCH})
      if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
        # Fix: the original referenced undefined `${_VER}` here; the version
        # being discarded is `${_CODE_VER}`.
        message(STATUS "discarding unsupported CUDA arch ${_CODE_VER}.")
        continue()
      endif()

      # Add it to the arch list.
      list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
    endforeach()
  endif()
  message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
endmacro()
+
+#
+# Define a target named `GPU_MOD_NAME` for a single extension. The
+# arguments are:
+#
+# DESTINATION - Module destination directory.
+# LANGUAGE - The GPU language for this module, e.g CUDA, HIP,
+# etc.
+# SOURCES - List of source files relative to CMakeLists.txt
+# directory.
+#
+# Optional arguments:
+#
+# ARCHITECTURES - A list of target GPU architectures in cmake
+# format.
+# Refer `CMAKE_CUDA_ARCHITECTURES` documentation
+# and `CMAKE_HIP_ARCHITECTURES` for more info.
+# ARCHITECTURES will use cmake's defaults if
+# not provided.
+# COMPILE_FLAGS - Extra compiler flags passed to NVCC/hip.
+# INCLUDE_DIRECTORIES - Extra include directories.
+# LIBRARIES - Extra link libraries.
+# WITH_SOABI - Generate library with python SOABI suffix name.
+#
+# Note: optimization level/debug info is set via cmake build type.
+#
function (define_gpu_extension_target GPU_MOD_NAME)
  cmake_parse_arguments(PARSE_ARGV 1
    GPU
    "WITH_SOABI"
    "DESTINATION;LANGUAGE"
    "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")

  # Add hipify preprocessing step when building with HIP/ROCm.
  if (GPU_LANGUAGE STREQUAL "HIP")
    hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
  endif()

  if (GPU_WITH_SOABI)
    set(GPU_WITH_SOABI WITH_SOABI)
  else()
    set(GPU_WITH_SOABI)
  endif()

  Python_add_library(${GPU_MOD_NAME} MODULE "${GPU_SOURCES}" ${GPU_WITH_SOABI})

  if (GPU_LANGUAGE STREQUAL "HIP")
    # Make this target dependent on the hipify preprocessor step.
    add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
  endif()

  if (GPU_ARCHITECTURES)
    set_target_properties(${GPU_MOD_NAME} PROPERTIES
      ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
  endif()

  set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)

  # Fix: the generator expression lost its condition — `$<$:...>` is not a
  # valid genex and fails at generate time. Gate the extra flags on the
  # module's GPU language so they are not handed to the host C++ compiler.
  target_compile_options(${GPU_MOD_NAME} PRIVATE
    $<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)

  target_compile_definitions(${GPU_MOD_NAME} PRIVATE
    "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")

  target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
    ${GPU_INCLUDE_DIRECTORIES})

  target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY}
    ${GPU_LIBRARIES})

  # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
  # dependencies that are not necessary and may not be installed.
  if (GPU_LANGUAGE STREQUAL "CUDA")
    target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
      ${CUDA_LIBRARIES})
  else()
    target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
  endif()

  install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION})
endfunction()
diff --git a/collect_env.py b/collect_env.py
new file mode 100644
index 0000000..1ecfeb8
--- /dev/null
+++ b/collect_env.py
@@ -0,0 +1,721 @@
+# ruff: noqa
+# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
+
+# Unlike the rest of the PyTorch this file must be python2 compliant.
+# This script outputs relevant system environment info
+# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
+import datetime
+import locale
+import os
+import re
+import subprocess
+import sys
+from collections import namedtuple
+
+try:
+ import torch
+ TORCH_AVAILABLE = True
+except (ImportError, NameError, AttributeError, OSError):
+ TORCH_AVAILABLE = False
+
# System Environment Information
# Immutable record of everything rendered into the report template
# (env_info_fmt); pretty_str() consumes it via _asdict().
SystemEnv = namedtuple(
    'SystemEnv',
    [
        'torch_version',
        'is_debug_build',
        'cuda_compiled_version',
        'gcc_version',
        'clang_version',
        'cmake_version',
        'os',
        'libc_version',
        'python_version',
        'python_platform',
        'is_cuda_available',
        'cuda_runtime_version',
        'cuda_module_loading',
        'nvidia_driver_version',
        'nvidia_gpu_models',
        'cudnn_version',
        'pip_version',  # 'pip' or 'pip3'
        'pip_packages',
        'conda_packages',
        'hip_compiled_version',
        'hip_runtime_version',
        'miopen_runtime_version',
        'caching_allocator_config',
        'is_xnnpack_available',
        'cpu_info',
        'rocm_version',  # vllm specific field
        'neuron_sdk_version',  # vllm specific field
        'vllm_version',  # vllm specific field
        'vllm_build_flags',  # vllm specific field
        'gpu_topo',  # vllm specific field
    ])

# Substring filters applied to `conda list` output (see get_conda_packages).
DEFAULT_CONDA_PATTERNS = {
    "torch",
    "numpy",
    "cudatoolkit",
    "soumith",
    "mkl",
    "magma",
    "triton",
    "optree",
    "nccl",
}

# Substring filters applied to `pip list` output (see get_pip_packages).
DEFAULT_PIP_PATTERNS = {
    "torch",
    "numpy",
    "mypy",
    "flake8",
    "triton",
    "optree",
    "onnx",
    "nccl",
}
+
+
def run(command):
    """Return (return-code, stdout, stderr)."""
    # A string command goes through the shell; an argv list does not.
    use_shell = type(command) is str
    proc = subprocess.Popen(command,
                            stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE,
                            shell=use_shell)
    raw_output, raw_err = proc.communicate()
    # Windows consoles use the OEM code page; elsewhere the locale default.
    encoding = 'oem' if get_platform() == 'win32' else (
        locale.getpreferredencoding())
    return (proc.returncode,
            raw_output.decode(encoding).strip(),
            raw_err.decode(encoding).strip())
+
+
def run_and_read_all(run_lambda, command):
    """Run command using run_lambda; reads and returns entire output if rc is 0."""
    rc, out, _ = run_lambda(command)
    return out if rc == 0 else None


def run_and_parse_first_match(run_lambda, command, regex):
    """Run command using run_lambda, returns the first regex match if it exists."""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    found = re.search(regex, out)
    return found.group(1) if found is not None else None


def run_and_return_first_line(run_lambda, command):
    """Run command using run_lambda and returns first line if output is not empty."""
    rc, out, _ = run_lambda(command)
    if rc != 0:
        return None
    return out.split('\n')[0]


def get_conda_packages(run_lambda, patterns=None):
    """Filtered `conda list` output, or None when conda is unavailable."""
    if patterns is None:
        patterns = DEFAULT_CONDA_PATTERNS
    conda = os.environ.get('CONDA_EXE', 'conda')
    out = run_and_read_all(run_lambda, "{} list".format(conda))
    if out is None:
        return out

    # Drop comment lines and keep only lines mentioning a pattern.
    interesting = [
        line for line in out.splitlines()
        if not line.startswith("#") and any(name in line for name in patterns)
    ]
    return "\n".join(interesting)
+
+
def get_gcc_version(run_lambda):
    """Version line reported by `gcc --version`."""
    return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')


def get_clang_version(run_lambda):
    """Version string reported by `clang --version`."""
    return run_and_parse_first_match(run_lambda, 'clang --version',
                                     r'clang version (.*)')


def get_cmake_version(run_lambda):
    """Version string reported by `cmake --version`."""
    return run_and_parse_first_match(run_lambda, 'cmake --version',
                                     r'cmake (.*)')


def get_nvidia_driver_version(run_lambda):
    """NVIDIA driver version via kextstat (macOS) or nvidia-smi (elsewhere)."""
    if get_platform() == 'darwin':
        cmd = 'kextstat | grep -i cuda'
        return run_and_parse_first_match(run_lambda, cmd,
                                         r'com[.]nvidia[.]CUDA [(](.*?)[)]')
    smi = get_nvidia_smi()
    return run_and_parse_first_match(run_lambda, smi,
                                     r'Driver Version: (.*?) ')
+
+
def get_gpu_info(run_lambda):
    """GPU model name(s): torch-reported on ROCm/macOS, `nvidia-smi -L` otherwise."""
    is_rocm = (TORCH_AVAILABLE and hasattr(torch.version, 'hip')
               and torch.version.hip is not None)
    if get_platform() == 'darwin' or is_rocm:
        if TORCH_AVAILABLE and torch.cuda.is_available():
            if torch.version.hip is not None:
                prop = torch.cuda.get_device_properties(0)
                # gcnArchName only exists on newer PyTorch ROCm builds.
                if hasattr(prop, "gcnArchName"):
                    gcnArch = " ({})".format(prop.gcnArchName)
                else:
                    gcnArch = "NoGCNArchNameOnOldPyTorch"
            else:
                gcnArch = ""
            return torch.cuda.get_device_name(None) + gcnArch
        return None
    smi = get_nvidia_smi()
    rc, out, _ = run_lambda(smi + ' -L')
    if rc != 0:
        return None
    # Anonymize GPUs by removing their UUID
    uuid_regex = re.compile(r' \(UUID: .+?\)')
    return re.sub(uuid_regex, '', out)


def get_running_cuda_version(run_lambda):
    """CUDA toolkit version parsed from `nvcc --version`."""
    return run_and_parse_first_match(run_lambda, 'nvcc --version',
                                     r'release .+ V(.*)')
+
+
def get_cudnn_version(run_lambda):
    """Return a list of libcudnn.so; it's hard to tell which one is being used."""
    if get_platform() == 'win32':
        # Search the CUDA toolkit's bin directory for cudnn DLLs.
        system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
        cuda_path = os.environ.get('CUDA_PATH', "%CUDA_PATH%")
        where_cmd = os.path.join(system_root, 'System32', 'where')
        cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
    elif get_platform() == 'darwin':
        # CUDA libraries and drivers can be found in /usr/local/cuda/. See
        # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install
        # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac
        # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
        cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*'
    else:
        # Ask the dynamic-linker cache; the final `rev|cut|rev` keeps the
        # last whitespace-separated field (the library path).
        cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
    rc, out, _ = run_lambda(cudnn_cmd)
    # find will return 1 if there are permission errors or if not found
    if len(out) == 0 or (rc != 1 and rc != 0):
        # Fall back to an explicitly configured library path, if it exists.
        l = os.environ.get('CUDNN_LIBRARY')
        if l is not None and os.path.isfile(l):
            return os.path.realpath(l)
        return None
    files_set = set()
    for fn in out.split('\n'):
        fn = os.path.realpath(fn)  # eliminate symbolic links
        if os.path.isfile(fn):
            files_set.add(fn)
    if not files_set:
        return None
    # Alphabetize the result because the order is non-deterministic otherwise
    files = sorted(files_set)
    if len(files) == 1:
        return files[0]
    result = '\n'.join(files)
    return 'Probably one of the following:\n{}'.format(result)
+
+
def get_nvidia_smi():
    """Locate nvidia-smi; probes System32 and the legacy NVSMI dir on Windows."""
    # Note: nvidia-smi is currently available only on Windows and Linux
    smi = 'nvidia-smi'
    if get_platform() != 'win32':
        return smi
    system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
    program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files')
    # Prefer the modern System32 location over the legacy install path.
    candidates = [
        os.path.join(system_root, 'System32', smi),
        os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi),
    ]
    for candidate_smi in candidates:
        if os.path.exists(candidate_smi):
            return '"{}"'.format(candidate_smi)
    return smi
+
+
def get_rocm_version(run_lambda):
    """Returns the ROCm version if available, otherwise 'N/A'."""
    return run_and_parse_first_match(run_lambda, 'hipcc --version',
                                     r'HIP version: (\S+)')


def get_neuron_sdk_version(run_lambda):
    """Result of `neuron-ls` when it succeeds, otherwise 'N/A'."""
    try:
        result = run_lambda(["neuron-ls"])
        # result is a (rc, stdout, stderr) tuple; rc 0 means success.
        return result if result[0] == 0 else 'N/A'
    except Exception:
        return 'N/A'
+
+
def get_vllm_version():
    """Installed vLLM version string, or 'N/A' when vLLM is not importable."""
    try:
        import vllm
    except ImportError:
        return 'N/A'
    return vllm.__version__


def summarize_vllm_build_flags():
    """One-line summary of build-relevant environment flags."""
    cuda_archs = os.environ.get('TORCH_CUDA_ARCH_LIST', 'Not Set')
    rocm = 'Enabled' if os.environ.get('ROCM_HOME') else 'Disabled'
    neuron = 'Enabled' if os.environ.get('NEURON_CORES') else 'Disabled'
    return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(cuda_archs, rocm,
                                                         neuron)


def get_gpu_topo(run_lambda):
    """`nvidia-smi topo -m` output on Linux; None elsewhere."""
    if get_platform() == 'linux':
        return run_and_read_all(run_lambda, 'nvidia-smi topo -m')
    return None
+
+
+# example outputs of CPU infos
+# * linux
+# Architecture: x86_64
+# CPU op-mode(s): 32-bit, 64-bit
+# Address sizes: 46 bits physical, 48 bits virtual
+# Byte Order: Little Endian
+# CPU(s): 128
+# On-line CPU(s) list: 0-127
+# Vendor ID: GenuineIntel
+# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
+# CPU family: 6
+# Model: 106
+# Thread(s) per core: 2
+# Core(s) per socket: 32
+# Socket(s): 2
+# Stepping: 6
+# BogoMIPS: 5799.78
+# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
+# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
+# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
+# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
+# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
+# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
+# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
+# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
+# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
+# Virtualization features:
+# Hypervisor vendor: KVM
+# Virtualization type: full
+# Caches (sum of all):
+# L1d: 3 MiB (64 instances)
+# L1i: 2 MiB (64 instances)
+# L2: 80 MiB (64 instances)
+# L3: 108 MiB (2 instances)
+# NUMA:
+# NUMA node(s): 2
+# NUMA node0 CPU(s): 0-31,64-95
+# NUMA node1 CPU(s): 32-63,96-127
+# Vulnerabilities:
+# Itlb multihit: Not affected
+# L1tf: Not affected
+# Mds: Not affected
+# Meltdown: Not affected
+# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
+# Retbleed: Not affected
+# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
+# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
+# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
+# Srbds: Not affected
+# Tsx async abort: Not affected
+# * win32
+# Architecture=9
+# CurrentClockSpeed=2900
+# DeviceID=CPU0
+# Family=179
+# L2CacheSize=40960
+# L2CacheSpeed=
+# Manufacturer=GenuineIntel
+# MaxClockSpeed=2900
+# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
+# ProcessorType=3
+# Revision=27142
+#
+# Architecture=9
+# CurrentClockSpeed=2900
+# DeviceID=CPU1
+# Family=179
+# L2CacheSize=40960
+# L2CacheSpeed=
+# Manufacturer=GenuineIntel
+# MaxClockSpeed=2900
+# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
+# ProcessorType=3
+# Revision=27142
+
+
def get_cpu_info(run_lambda):
    """Raw CPU description via lscpu / wmic / sysctl depending on platform."""
    rc, out, err = 0, '', ''
    platform_name = get_platform()
    if platform_name == 'linux':
        rc, out, err = run_lambda('lscpu')
    elif platform_name == 'win32':
        rc, out, err = run_lambda(
            'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
            CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE'
        )
    elif platform_name == 'darwin':
        rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
    # On success report stdout; otherwise surface stderr.
    return out if rc == 0 else err


def get_platform():
    """Normalized platform name: linux/win32/cygwin/darwin, else raw sys.platform."""
    for name in ('linux', 'win32', 'cygwin', 'darwin'):
        if sys.platform.startswith(name):
            return name
    return sys.platform
+
+
def get_mac_version(run_lambda):
    """macOS product version, e.g. '13.4'."""
    return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion',
                                     r'(.*)')


def get_windows_version(run_lambda):
    """Windows edition caption via wmic, filtered through findstr."""
    system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
    wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic')
    findstr_cmd = os.path.join(system_root, 'System32', 'findstr')
    return run_and_read_all(
        run_lambda,
        '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))


def get_lsb_version(run_lambda):
    """Distribution description from `lsb_release -a`."""
    return run_and_parse_first_match(run_lambda, 'lsb_release -a',
                                     r'Description:\t(.*)')


def check_release_file(run_lambda):
    """PRETTY_NAME from /etc/*-release files."""
    return run_and_parse_first_match(run_lambda, 'cat /etc/*-release',
                                     r'PRETTY_NAME="(.*)"')
+
+
def get_os(run_lambda):
    """Human-readable OS description for the current platform."""
    from platform import machine
    platform = get_platform()

    if platform in ('win32', 'cygwin'):
        return get_windows_version(run_lambda)

    if platform == 'darwin':
        version = get_mac_version(run_lambda)
        if version is None:
            return None
        return 'macOS {} ({})'.format(version, machine())

    if platform == 'linux':
        # Prefer lsb_release, then /etc/*-release, then the bare platform name.
        for probe in (get_lsb_version, check_release_file):
            desc = probe(run_lambda)
            if desc is not None:
                return '{} ({})'.format(desc, machine())
        return '{} ({})'.format(platform, machine())

    # Unknown platform
    return platform
+
+
def get_python_platform():
    """Full platform identifier as reported by the platform module."""
    import platform
    return platform.platform()


def get_libc_version():
    """glibc identifier on Linux ('glibc-<ver>'), 'N/A' elsewhere."""
    import platform
    if get_platform() != 'linux':
        return 'N/A'
    return '-'.join(platform.libc_ver())
+
+
def get_pip_packages(run_lambda, patterns=None):
    """Return (pip_version_label, filtered `pip list` output).

    Note: will also find conda-installed pytorch and numpy packages.
    """
    if patterns is None:
        patterns = DEFAULT_PIP_PATTERNS

    # People generally have `pip` as `pip` or `pip3`
    # But here it is invoked as `python -mpip`
    def run_with_pip(pip):
        rc, out, _ = run_lambda(pip + ["list", "--format=freeze"])
        if rc != 0 or out is None:
            # Fix: the original passed a possibly-None result to
            # .splitlines(), crashing whenever `pip list` failed. Report an
            # empty package list instead.
            return ""
        return "\n".join(line for line in out.splitlines()
                         if any(name in line for name in patterns))

    pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
    out = run_with_pip([sys.executable, '-mpip'])

    return pip_version, out
+
+
def get_cachingallocator_config():
    """Value of PYTORCH_CUDA_ALLOC_CONF ('' when unset)."""
    return os.environ.get('PYTORCH_CUDA_ALLOC_CONF', '')


def get_cuda_module_loading_config():
    """CUDA_MODULE_LOADING after CUDA init, or 'N/A' without a usable GPU."""
    if TORCH_AVAILABLE and torch.cuda.is_available():
        # torch populates CUDA_MODULE_LOADING lazily; force initialization.
        torch.cuda.init()
        return os.environ.get('CUDA_MODULE_LOADING', '')
    return "N/A"


def is_xnnpack_available():
    """'True'/'False' for XNNPACK support, 'N/A' when torch is missing."""
    if not TORCH_AVAILABLE:
        return "N/A"
    import torch.backends.xnnpack
    return str(torch.backends.xnnpack.enabled)  # type: ignore[attr-defined]
+
+
def get_env_info():
    # Collect every probe's result into a single SystemEnv record.
    # Shelling out is funnelled through `run` so output decoding is uniform.
    run_lambda = run
    pip_version, pip_list_output = get_pip_packages(run_lambda)

    if TORCH_AVAILABLE:
        version_str = torch.__version__
        debug_mode_str = str(torch.version.debug)
        cuda_available_str = str(torch.cuda.is_available())
        cuda_version_str = torch.version.cuda
        if not hasattr(torch.version,
                       'hip') or torch.version.hip is None:  # cuda version
            hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
        else:  # HIP version

            def get_version_or_na(cfg, prefix):
                # Pick the last whitespace token of the first matching line.
                _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
                return _lst[0] if _lst else 'N/A'

            # torch._C._show_config() embeds HIP/MIOpen versions as text lines.
            cfg = torch._C._show_config().split('\n')
            hip_runtime_version = get_version_or_na(cfg, 'HIP Runtime')
            miopen_runtime_version = get_version_or_na(cfg, 'MIOpen')
            cuda_version_str = 'N/A'
            hip_compiled_version = torch.version.hip
    else:
        # No torch at all: mark every torch-derived field as unavailable.
        version_str = debug_mode_str = cuda_available_str = cuda_version_str = 'N/A'
        hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'

    sys_version = sys.version.replace("\n", " ")

    conda_packages = get_conda_packages(run_lambda)

    rocm_version = get_rocm_version(run_lambda)
    neuron_sdk_version = get_neuron_sdk_version(run_lambda)
    vllm_version = get_vllm_version()
    vllm_build_flags = summarize_vllm_build_flags()
    gpu_topo = get_gpu_topo(run_lambda)

    return SystemEnv(
        torch_version=version_str,
        is_debug_build=debug_mode_str,
        python_version='{} ({}-bit runtime)'.format(
            sys_version,
            sys.maxsize.bit_length() + 1),
        python_platform=get_python_platform(),
        is_cuda_available=cuda_available_str,
        cuda_compiled_version=cuda_version_str,
        cuda_runtime_version=get_running_cuda_version(run_lambda),
        cuda_module_loading=get_cuda_module_loading_config(),
        nvidia_gpu_models=get_gpu_info(run_lambda),
        nvidia_driver_version=get_nvidia_driver_version(run_lambda),
        cudnn_version=get_cudnn_version(run_lambda),
        hip_compiled_version=hip_compiled_version,
        hip_runtime_version=hip_runtime_version,
        miopen_runtime_version=miopen_runtime_version,
        pip_version=pip_version,
        pip_packages=pip_list_output,
        conda_packages=conda_packages,
        os=get_os(run_lambda),
        libc_version=get_libc_version(),
        gcc_version=get_gcc_version(run_lambda),
        clang_version=get_clang_version(run_lambda),
        cmake_version=get_cmake_version(run_lambda),
        caching_allocator_config=get_cachingallocator_config(),
        is_xnnpack_available=is_xnnpack_available(),
        cpu_info=get_cpu_info(run_lambda),
        rocm_version=rocm_version,
        neuron_sdk_version=neuron_sdk_version,
        vllm_version=vllm_version,
        vllm_build_flags=vllm_build_flags,
        gpu_topo=gpu_topo,
    )
+
+
+env_info_fmt = """
+PyTorch version: {torch_version}
+Is debug build: {is_debug_build}
+CUDA used to build PyTorch: {cuda_compiled_version}
+ROCM used to build PyTorch: {hip_compiled_version}
+
+OS: {os}
+GCC version: {gcc_version}
+Clang version: {clang_version}
+CMake version: {cmake_version}
+Libc version: {libc_version}
+
+Python version: {python_version}
+Python platform: {python_platform}
+Is CUDA available: {is_cuda_available}
+CUDA runtime version: {cuda_runtime_version}
+CUDA_MODULE_LOADING set to: {cuda_module_loading}
+GPU models and configuration: {nvidia_gpu_models}
+Nvidia driver version: {nvidia_driver_version}
+cuDNN version: {cudnn_version}
+HIP runtime version: {hip_runtime_version}
+MIOpen runtime version: {miopen_runtime_version}
+Is XNNPACK available: {is_xnnpack_available}
+
+CPU:
+{cpu_info}
+
+Versions of relevant libraries:
+{pip_packages}
+{conda_packages}
+""".strip()
+
+env_info_fmt += """
+ROCM Version: {rocm_version}
+Neuron SDK Version: {neuron_sdk_version}
+vLLM Version: {vllm_version}
+vLLM Build Flags:
+{vllm_build_flags}
+GPU Topology:
+{gpu_topo}
+""".strip()
+
+
def pretty_str(envinfo):
    """Render a SystemEnv record through env_info_fmt, normalizing values."""

    def replace_nones(dct, replacement='Could not collect'):
        # Mutates in place; returns dct for chaining.
        for key, value in dct.items():
            if value is None:
                dct[key] = replacement
        return dct

    def replace_bools(dct, true='Yes', false='No'):
        for key, value in dct.items():
            if value is True:
                dct[key] = true
            elif value is False:
                dct[key] = false
        return dct

    def prepend(text, tag='[prepend]'):
        return '\n'.join(tag + line for line in text.split('\n'))

    def replace_if_empty(text, replacement='No relevant packages'):
        if text is not None and len(text) == 0:
            return replacement
        return text

    def maybe_start_on_next_line(string):
        # If `string` is multiline, surround it with newlines.
        if string is not None and len(string.split('\n')) > 1:
            return '\n{}\n'.format(string)
        return string

    mutable_dict = envinfo._asdict()

    # Multi-line GPU model lists read better starting on their own line.
    mutable_dict['nvidia_gpu_models'] = \
        maybe_start_on_next_line(envinfo.nvidia_gpu_models)

    # If the machine doesn't have CUDA, report some fields as 'No CUDA'
    dynamic_cuda_fields = [
        'cuda_runtime_version',
        'nvidia_gpu_models',
        'nvidia_driver_version',
    ]
    all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
    no_dynamic_cuda = all(mutable_dict[field] is None
                          for field in dynamic_cuda_fields)
    if TORCH_AVAILABLE and not torch.cuda.is_available() and no_dynamic_cuda:
        for field in all_cuda_fields:
            mutable_dict[field] = 'No CUDA'
        if envinfo.cuda_compiled_version is None:
            mutable_dict['cuda_compiled_version'] = 'None'

    # Normalize booleans, then fill in anything still missing.
    mutable_dict = replace_bools(mutable_dict)
    mutable_dict = replace_nones(mutable_dict)

    # '' means the package listing ran but matched nothing relevant.
    mutable_dict['pip_packages'] = replace_if_empty(
        mutable_dict['pip_packages'])
    mutable_dict['conda_packages'] = replace_if_empty(
        mutable_dict['conda_packages'])

    # Tag conda and pip packages with a prefix
    # If they were previously None, they'll show up as ie '[conda] Could not collect'
    if mutable_dict['pip_packages']:
        mutable_dict['pip_packages'] = prepend(
            mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version))
    if mutable_dict['conda_packages']:
        mutable_dict['conda_packages'] = prepend(
            mutable_dict['conda_packages'], '[conda] ')
    mutable_dict['cpu_info'] = envinfo.cpu_info
    return env_info_fmt.format(**mutable_dict)
+
+
def get_pretty_env_info():
    # Convenience wrapper: collect everything, then render the report.
    return pretty_str(get_env_info())


def main():
    # Entry point: print the environment report and, when torch's crash
    # handler is present, point the user at the most recent minidump.
    print("Collecting environment information...")
    output = get_pretty_env_info()
    print(output)

    if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(
            torch.utils, '_crash_handler'):
        minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
        if sys.platform == "linux" and os.path.exists(minidump_dir):
            dumps = [
                os.path.join(minidump_dir, dump)
                for dump in os.listdir(minidump_dir)
            ]
            # The most recently created minidump is assumed to be relevant.
            latest = max(dumps, key=os.path.getctime)
            ctime = os.path.getctime(latest)
            creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
                '%Y-%m-%d %H:%M:%S')
            msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
                "if this is related to your bug please include it when you file a report ***"
            print(msg, file=sys.stderr)


if __name__ == '__main__':
    main()
diff --git a/csrc_musa/activation_kernels.mu b/csrc_musa/activation_kernels.mu
new file mode 100644
index 0000000..c21a02d
--- /dev/null
+++ b/csrc_musa/activation_kernels.mu
@@ -0,0 +1,161 @@
+#include "torch_musa/csrc/aten/musa/MUSAContext.h"
+#include <torch/extension.h>
+#include "torch_musa/csrc/core/MUSAGuard.h"
+
+#include <cmath>
+
+#include "musa_compat.h"
+#include "dispatch_utils.h"
+
+namespace vllm {
+
+// Activation and gating kernel template.
+// ACT_FN is applied to the first half ("gate") and multiplied by the second half ("up").
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void act_and_mul_kernel(
+  scalar_t* __restrict__ out, // [..., d]
+  const scalar_t* __restrict__ input, // [..., 2, d]
+  const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x) * y;
+  }
+}
+
+template<typename T>
+__device__ __forceinline__ T silu_kernel(const T& x) {
+  // x * sigmoid(x)
+  return (T) (((float) x) / (1.0f + expf((float) -x)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'none' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
+  const float f = (float) x;
+  constexpr float ALPHA = M_SQRT1_2;
+  return (T) (f * 0.5f * (1.0f + ::erf(f * ALPHA)));
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
+  // Equivalent to PyTorch GELU with 'tanh' approximation.
+  // Refer to:
+  // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
+  const float f = (float) x;
+  constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
+  constexpr float KAPPA = 0.044715;
+  float x_cube = f * f * f;
+  float inner = BETA * (f + KAPPA * x_cube);
+  return (T) (0.5f * f * (1.0f + ::tanhf(inner)));
+}
+
+} // namespace vllm
+
+// Launch activation and gating kernel.
+// Restored: template arguments and the <<<...>>> launch configuration were stripped.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
+  int d = input.size(-1) / 2; \
+  int64_t num_tokens = input.numel() / input.size(-1); \
+  dim3 grid(num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const at::musa::OptionalMUSAGuard device_guard(device_of(input)); \
+  const musaStream_t stream = at::musa::getCurrentMUSAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES( \
+    input.scalar_type(), \
+    "act_and_mul_kernel", \
+    [&] { \
+      vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), \
+        input.data_ptr<scalar_t>(), \
+        d); \
+    });
+
+void silu_and_mul(
+  torch::Tensor& out, // [..., d], written in place with silu(gate) * up
+  torch::Tensor& input) // [..., 2 * d], gate and up halves concatenated on the last dim
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+}
+
+void gelu_and_mul(
+  torch::Tensor& out, // [..., d], written in place with gelu(gate) * up
+  torch::Tensor& input) // [..., 2 * d], gate and up halves concatenated on the last dim
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+}
+
+void gelu_tanh_and_mul(
+  torch::Tensor& out, // [..., d], written in place with tanh-approx gelu(gate) * up
+  torch::Tensor& input) // [..., 2 * d], gate and up halves concatenated on the last dim
+{
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+}
+
+namespace vllm {
+
+// Element-wise activation kernel template (no gating; out[i] = ACT_FN(in[i])).
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+  scalar_t* __restrict__ out, // [..., d]
+  const scalar_t* __restrict__ input, // [..., d]
+  const int d) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x);
+  }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+// Restored: template arguments and the <<<...>>> launch configuration were stripped.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
+  int d = input.size(-1); \
+  int64_t num_tokens = input.numel() / d; \
+  dim3 grid(num_tokens); \
+  dim3 block(std::min(d, 1024)); \
+  const at::musa::OptionalMUSAGuard device_guard(device_of(input)); \
+  const musaStream_t stream = at::musa::getCurrentMUSAStream(); \
+  VLLM_DISPATCH_FLOATING_TYPES( \
+    input.scalar_type(), \
+    "activation_kernel", \
+    [&] { \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>( \
+        out.data_ptr<scalar_t>(), \
+        input.data_ptr<scalar_t>(), \
+        d); \
+    });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+  const float x3 = (float) (x * x * x);
+  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+  const float f = (float) x;
+  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+  torch::Tensor& out, // [..., d], written in place
+  torch::Tensor& input) // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+  torch::Tensor& out, // [..., d], written in place
+  torch::Tensor& input) // [..., d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
diff --git a/csrc_musa/attention/attention_dtypes.h b/csrc_musa/attention/attention_dtypes.h
new file mode 100644
index 0000000..02cb8f1
--- /dev/null
+++ b/csrc_musa/attention/attention_dtypes.h
@@ -0,0 +1,7 @@
+#pragma once
+
+#include "attention_generic.muh"
+#include "dtype_float16.muh"
+#include "dtype_float32.muh"
+#include "dtype_bfloat16.muh"
+#include "dtype_fp8.muh"
diff --git a/csrc_musa/attention/attention_generic.muh b/csrc_musa/attention/attention_generic.muh
new file mode 100644
index 0000000..dfec157
--- /dev/null
+++ b/csrc_musa/attention/attention_generic.muh
@@ -0,0 +1,65 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <stdint.h>
+
+namespace vllm {
+
+// A vector type to store Q, K, V elements.
+template<typename T, int VEC_SIZE>
+struct Vec {};
+
+// A vector type to store FP32 accumulators.
+template<typename T>
+struct FloatVec {};
+
+// Template vector operations.
+template<typename Acc, typename A, typename B>
+inline __device__ Acc mul(A a, B b);
+
+template<typename T>
+inline __device__ float sum(T v);
+
+template<typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<T, T, T>(a, b));
+}
+
+template<typename A, typename T>
+inline __device__ float dot(T a, T b) {
+  return sum(mul<A, T, T>(a, b));
+}
+
+template<typename T>
+inline __device__ void zero(T& dst) {
+  constexpr int WORDS = sizeof(T) / 4;
+  union {
+    T raw;
+    uint32_t words[WORDS];
+  } tmp;
+
+#pragma unroll
+  for (int ii = 0; ii < WORDS; ++ii) {
+    tmp.words[ii] = 0u;
+  }
+  dst = tmp.raw;
+}
+
+} // namespace vllm
diff --git a/csrc_musa/attention/attention_kernels.mu b/csrc_musa/attention/attention_kernels.mu
new file mode 100644
index 0000000..6bade47
--- /dev/null
+++ b/csrc_musa/attention/attention_kernels.mu
@@ -0,0 +1,981 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/extension.h>
+#include "torch_musa/csrc/aten/musa/MUSAContext.h"
+#include "torch_musa/csrc/core/MUSAGuard.h"
+
+#include "attention_dtypes.h"
+#include "attention_utils.muh"
+
+#if defined(ENABLE_FP8_E5M2)
+#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#elif defined(ENABLE_FP8_E4M3)
+#include "../quantization/fp8/amd_detail/quant_utils.cuh"
+#endif
+
+#include <algorithm>
+
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+  typedef __hip_bfloat16 __mt_bfloat16;
+#endif
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
+
+namespace vllm {
+
+// Utility function for attention softmax.
+template<int NUM_WARPS>
+inline __device__ float block_sum(float* red_smem, float sum) {
+  // Decompose the thread index into warp / lane.
+  int warp = threadIdx.x / WARP_SIZE;
+  int lane = threadIdx.x % WARP_SIZE;
+
+  // Compute the sum per warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Warp leaders store the data to shared memory.
+  if (lane == 0) {
+    red_smem[warp] = sum;
+  }
+
+  // Make sure the data is in shared memory.
+  __syncthreads();
+
+  // The warps compute the final sums.
+  if (lane < NUM_WARPS) {
+    sum = red_smem[lane];
+  }
+
+  // Parallel reduction inside the warp.
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
+  }
+
+  // Broadcast to other threads.
+  return VLLM_SHFL_SYNC(sum, 0);
+}
+
+// TODO(woosuk): Merge the last two dimensions of the grid.
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  bool IS_FP8_KV_CACHE,
+  int PARTITION_SIZE = 0> // Zero means no partitioning.
+__device__ void paged_attention_kernel(
+  float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads, // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ seq_lens, // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float kv_scale) {
+  const int seq_idx = blockIdx.y;
+  const int partition_idx = blockIdx.z;
+  const int max_num_partitions = gridDim.z;
+  constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
+  const int seq_len = seq_lens[seq_idx];
+  if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
+    // No work to do. Terminate the thread block.
+    return;
+  }
+
+  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+  const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
+
+  // [start_block_idx, end_block_idx) is the range of blocks to process.
+  const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
+  const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
+  const int num_blocks = end_block_idx - start_block_idx;
+
+  // [start_token_idx, end_token_idx) is the range of tokens to process.
+  const int start_token_idx = start_block_idx * BLOCK_SIZE;
+  const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
+  const int num_tokens = end_token_idx - start_token_idx;
+
+  constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS
+  assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
+  constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int thread_idx = threadIdx.x;
+  const int warp_idx = thread_idx / WARP_SIZE;
+  const int lane = thread_idx % WARP_SIZE;
+
+  const int head_idx = blockIdx.x;
+  const int num_heads = gridDim.x;
+  const int num_queries_per_kv = num_heads / num_kv_heads;
+  const int kv_head_idx = head_idx / num_queries_per_kv;
+  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
+
+  // A vector type to store a part of a key or a query.
+  // The vector size is configured in such a way that the threads in a thread group
+  // fetch or compute 16 bytes at a time.
+  // For example, if the size of a thread group is 4 and the data type is half,
+  // then the vector size is 16 / (4 * sizeof(half)) == 2.
+  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
+  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
+  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
+#endif
+
+  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
+  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
+
+  const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
+  const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
+
+  // Load the query to registers.
+  // Each thread in a thread group has a different part of the query.
+  // For example, if the the thread group size is 4, then the first thread in the group
+  // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ...
+  // th vectors of the query, and so on.
+  // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous.
+  const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
+  __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
+#pragma unroll
+  for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) {
+    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
+    q_vecs[thread_group_offset][i] = *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
+  }
+  __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs
+
+  // Memory planning.
+  extern __shared__ char shared_mem[];
+  // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
+  float* logits = reinterpret_cast<float*>(shared_mem);
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // x == THREAD_GROUP_SIZE * VEC_SIZE
+  // Each thread group fetches x elements from the key at a time.
+  constexpr int x = 16 / sizeof(cache_t);
+  float qk_max = -FLT_MAX;
+
+  // Iterate over the key blocks.
+  // Each warp fetches a block of keys for each iteration.
+  // Each thread group in a warp fetches a key from the block, and computes
+  // dot product with the query.
+  const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+
+    // Load a key to registers.
+    // Each thread in a thread group has a different part of the key.
+    // For example, if the the thread group size is 4, then the first thread in the group
+    // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th
+    // vectors of the key, and so on.
+    for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
+      const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
+      const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+      K_vec k_vecs[NUM_VECS_PER_THREAD];
+
+#pragma unroll
+      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
+        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
+                                       + kv_head_idx * kv_head_stride
+                                       + physical_block_offset * x;
+        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
+        const int offset1 = (vec_idx * VEC_SIZE) / x;
+        const int offset2 = (vec_idx * VEC_SIZE) % x;
+        if constexpr (IS_FP8_KV_CACHE) {
+#if defined(ENABLE_FP8_E5M2)
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          // Vector conversion from Quant_vec to K_vec.
+          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
+#elif defined(ENABLE_FP8_E4M3)
+          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+          // Vector conversion from Quant_vec to K_vec. Use scaled_vec_conversion to convert FP8_E4M3 quantized k
+          // cache vec to k vec in higher precision (FP16, BFloat16, etc.)
+          k_vecs[j] = fp8_e4m3::scaled_vec_conversion<K_vec, Quant_vec>(k_vec_quant, kv_scale);
+#else
+          assert(false);
+#endif
+        } else {
+          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
+        }
+      }
+
+      // Compute dot product.
+      // This includes a reduction across the threads in the same thread group.
+      float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
+      // Add the ALiBi bias if slopes are given.
+      qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
+
+      if (thread_group_offset == 0) {
+        // Store the partial reductions to shared memory.
+        // NOTE(woosuk): It is required to zero out the masked logits.
+        const bool mask = token_idx >= seq_len;
+        logits[token_idx - start_token_idx] = mask ? 0.f : qk;
+        // Update the max value.
+        qk_max = mask ? qk_max : fmaxf(qk_max, qk);
+      }
+    }
+  }
+
+  // Perform reduction across the threads in the same warp to get the
+  // max qk value for each "warp" (not across the thread block yet).
+  // The 0-th thread of each thread group already has its max qk value.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = qk_max;
+  }
+  __syncthreads();
+
+  // TODO(woosuk): Refactor this part.
+  // Get the max qk value for the sequence.
+  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
+  }
+  // Broadcast the max qk value to all threads.
+  qk_max = VLLM_SHFL_SYNC(qk_max, 0);
+
+  // Get the sum of the exp values.
+  float exp_sum = 0.f;
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    float val = __expf(logits[i] - qk_max);
+    logits[i] = val;
+    exp_sum += val;
+  }
+  exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
+
+  // Compute softmax.
+  const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
+  for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
+    logits[i] *= inv_sum;
+  }
+  __syncthreads();
+
+  // If partitioning is enabled, store the max logit and exp_sum.
+  if (USE_PARTITIONING && thread_idx == 0) {
+    float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions
+                                       + partition_idx;
+    *max_logits_ptr = qk_max;
+    float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                   + head_idx * max_num_partitions
+                                   + partition_idx;
+    *exp_sums_ptr = exp_sum;
+  }
+
+  // Each thread will fetch 16 bytes from the value cache at a time.
+  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
+  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
+  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
+#endif
+  using Float_L_vec = typename FloatVec<L_vec>::Type;
+
+  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
+  constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
+  constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
+
+  // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
+  float accs[NUM_ROWS_PER_THREAD];
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    accs[i] = 0.f;
+  }
+
+  scalar_t zero_value;
+  zero(zero_value);
+  for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) {
+    // NOTE(woosuk): The block number is stored in int32. However, we cast it to int64
+    // because int32 can lead to overflow when this variable is multiplied by large numbers
+    // (e.g., kv_block_stride).
+    const int64_t physical_block_number = static_cast<int64_t>(block_table[block_idx]);
+    const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
+    const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
+    L_vec logits_vec;
+    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));
+
+    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
+                                   + kv_head_idx * kv_head_stride;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE) {
+        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
+        V_vec v_vec;
+        if constexpr (IS_FP8_KV_CACHE) {
+#if defined(ENABLE_FP8_E5M2)
+          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec.
+          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
+#elif defined(ENABLE_FP8_E4M3)
+          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
+          // Vector conversion from V_quant_vec to V_vec. Use scaled_vec_conversion to convert
+          // FP8_E4M3 quantized v cache vec to v vec in higher precision (FP16, BFloat16, etc.)
+          v_vec = fp8_e4m3::scaled_vec_conversion<V_vec, V_quant_vec>(v_quant_vec, kv_scale);
+#else
+          assert(false);
+#endif
+        } else {
+          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
+        }
+        if (block_idx == num_seq_blocks - 1) {
+          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
+          // we should explicitly zero out the values since they may contain NaNs.
+          // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
+          scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
+#pragma unroll
+          for (int j = 0; j < V_VEC_SIZE; j++) {
+            v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
+          }
+        }
+        accs[i] += dot(logits_vec, v_vec);
+      }
+    }
+  }
+
+  // Perform reduction within each warp.
+#pragma unroll
+  for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+    float acc = accs[i];
+#pragma unroll
+    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
+      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
+    }
+    accs[i] = acc;
+  }
+
+  // NOTE(woosuk): A barrier is required because the shared memory space for logits
+  // is reused for the output.
+  __syncthreads();
+
+  // Perform reduction across warps.
+  float* out_smem = reinterpret_cast<float*>(shared_mem);
+#pragma unroll
+  for (int i = NUM_WARPS; i > 1; i /= 2) {
+    int mid = i / 2;
+    // Upper warps write to shared memory.
+    if (warp_idx >= mid && warp_idx < i) {
+      float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          dst[row_idx] = accs[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    // Lower warps update the output.
+    if (warp_idx < mid) {
+      const float* src = &out_smem[warp_idx * HEAD_SIZE];
+#pragma unroll
+      for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+        const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+        if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+          accs[i] += src[row_idx];
+        }
+      }
+    }
+    __syncthreads();
+  }
+
+  // Write the final output.
+  if (warp_idx == 0) {
+    scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                            + head_idx * max_num_partitions * HEAD_SIZE
+                            + partition_idx * HEAD_SIZE;
+#pragma unroll
+    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
+      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
+      if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
+        from_float(*(out_ptr + row_idx), accs[i]);
+      }
+    }
+  }
+}
+
+// Grid: (num_heads, num_seqs, 1).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  bool IS_FP8_KV_CACHE>
+__global__ void paged_attention_v1_kernel(
+  scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
+  const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads, // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ seq_lens, // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float kv_scale) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>(
+    /* exp_sums */ nullptr, /* max_logits */ nullptr,
+    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens,
+    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, kv_scale);
+}
+
+// Grid: (num_heads, num_seqs, max_num_partitions).
+template<
+  typename scalar_t,
+  typename cache_t,
+  int HEAD_SIZE,
+  int BLOCK_SIZE,
+  int NUM_THREADS,
+  bool IS_FP8_KV_CACHE,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_kernel(
+  float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+  float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
+  scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+  const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
+  const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
+  const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
+  const int num_kv_heads, // [num_heads]
+  const float scale,
+  const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
+  const int* __restrict__ seq_lens, // [num_seqs]
+  const int max_num_blocks_per_seq,
+  const float* __restrict__ alibi_slopes, // [num_heads]
+  const int q_stride,
+  const int kv_block_stride,
+  const int kv_head_stride,
+  const float kv_scale) {
+  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE, PARTITION_SIZE>(
+    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
+    block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes,
+    q_stride, kv_block_stride, kv_head_stride, kv_scale);
+}
+
+// Grid: (num_heads, num_seqs).
+template<
+  typename scalar_t,
+  int HEAD_SIZE,
+  int NUM_THREADS,
+  int PARTITION_SIZE>
+__global__ void paged_attention_v2_reduce_kernel(
+  scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
+  const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
+  const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
+  const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
+  const int* __restrict__ seq_lens, // [num_seqs]
+  const int max_num_partitions) {
+  const int num_heads = gridDim.x;
+  const int head_idx = blockIdx.x;
+  const int seq_idx = blockIdx.y;
+  const int seq_len = seq_lens[seq_idx];
+  const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
+  if (num_partitions == 1) {
+    // No need to reduce. Only copy tmp_out to out.
+    scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+    const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                          + head_idx * max_num_partitions * HEAD_SIZE;
+    for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
+      out_ptr[i] = tmp_out_ptr[i];
+    }
+    // Terminate the thread block.
+    return;
+  }
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  const int warp_idx = threadIdx.x / WARP_SIZE;
+  const int lane = threadIdx.x % WARP_SIZE;
+
+  // Size: 2 * num_partitions.
+  extern __shared__ char shared_mem[];
+  // Workspace for reduction.
+  __shared__ float red_smem[2 * NUM_WARPS];
+
+  // Load max logits to shared memory.
+  float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
+  const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions
+                                           + head_idx * max_num_partitions;
+  float max_logit = -FLT_MAX;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    const float l = max_logits_ptr[i];
+    shared_max_logits[i] = l;
+    max_logit = fmaxf(max_logit, l);
+  }
+  __syncthreads();
+
+  // Get the global max logit.
+  // Reduce within the warp.
+#pragma unroll
+  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  if (lane == 0) {
+    red_smem[warp_idx] = max_logit;
+  }
+  __syncthreads();
+  // Reduce across warps.
+  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
+#pragma unroll
+  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
+  }
+  // Broadcast the max value to all threads.
+  max_logit = VLLM_SHFL_SYNC(max_logit, 0);
+
+  // Load rescaled exp sums to shared memory.
+  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
+  const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions
+                                       + head_idx * max_num_partitions;
+  float global_exp_sum = 0.0f;
+  for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
+    float l = shared_max_logits[i];
+    float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
+    global_exp_sum += rescaled_exp_sum;
+    shared_exp_sums[i] = rescaled_exp_sum;
+  }
+  __syncthreads();
+  global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
+  const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
+
+  // Aggregate tmp_out to out.
+  const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE
+                                        + head_idx * max_num_partitions * HEAD_SIZE;
+  scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
+#pragma unroll
+  for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
+    float acc = 0.0f;
+    for (int j = 0; j < num_partitions; ++j) {
+      acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum;
+    }
+    from_float(out_ptr[i], acc);
+  }
+}
+
+} // namespace vllm
+
+// Restored: template arguments and the <<<...>>> launch configuration were stripped.
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
+  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
+    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE>), shared_mem_size); \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_KV_CACHE><<<grid, block, shared_mem_size, stream>>>( \
+    out_ptr, \
+    query_ptr, \
+    key_cache_ptr, \
+    value_cache_ptr, \
+    num_kv_heads, \
+    scale, \
+    block_tables_ptr, \
+    seq_lens_ptr, \
+    max_num_blocks_per_seq, \
+    alibi_slopes_ptr, \
+    q_stride, \
+    kv_block_stride, \
+    kv_head_stride, \
+    kv_scale);
+
+// TODO(woosuk): Tune NUM_THREADS.
+//
+// Host-side launcher for the single-pass (v1) paged attention kernel:
+// derives grid/block dimensions, dynamic shared-memory size and raw device
+// pointers from the input tensors, then dispatches on the runtime head_size.
+// NOTE(review): template argument lists (`reinterpret_cast<T*>`,
+// `data_ptr<int>()`, `c10::optional<torch::Tensor>`) restored after
+// angle-bracket stripping.
+template<
+  typename T,
+  typename CACHE_T,
+  int BLOCK_SIZE,
+  bool IS_FP8_KV_CACHE,
+  int NUM_THREADS = 128>
+void paged_attention_v1_launcher(
+  torch::Tensor& out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  int num_kv_heads,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& seq_lens,
+  int max_seq_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  float kv_scale) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  int logits_size = padded_max_seq_len * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
+  // Keep that in sync with the logic here!
+  int shared_mem_size = std::max(logits_size, outputs_size);
+
+  dim3 grid(num_heads, num_seqs, 1);
+  dim3 block(NUM_THREADS);
+  const at::musa::OptionalMUSAGuard device_guard(device_of(query));
+  const musaStream_t stream = at::musa::getCurrentMUSAStream();
+  switch (head_size) {
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V1(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V1(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V1(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V1(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V1(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V1(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+// Forward the runtime arguments to the templated v1 launcher.
+// NOTE(review): the launcher's template argument list was lost to
+// angle-bracket stripping; restored.
+#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)       \
+  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
+    out,                                                                \
+    query,                                                              \
+    key_cache,                                                          \
+    value_cache,                                                        \
+    num_kv_heads,                                                       \
+    scale,                                                              \
+    block_tables,                                                       \
+    seq_lens,                                                           \
+    max_seq_len,                                                        \
+    alibi_slopes,                                                       \
+    kv_scale);
+
+// Dispatch the runtime `block_size` to the compile-time BLOCK_SIZE template
+// parameter of paged_attention_v1_launcher.
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE)    \
+  switch (block_size) {                                             \
+    case 8:                                                         \
+      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE);             \
+      break;                                                        \
+    case 16:                                                        \
+      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE);            \
+      break;                                                        \
+    case 32:                                                        \
+      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE);            \
+      break;                                                        \
+    default:                                                        \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
+      break;                                                        \
+  }
+
+// Entry point bound to Python: single-pass paged attention (v1).
+// Dispatches on kv-cache dtype string and query dtype to the templated
+// launcher. NOTE(review): `c10::optional<torch::Tensor>` restored after
+// angle-bracket stripping.
+void paged_attention_v1(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  int num_kv_heads,               // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& seq_lens,        // [num_seqs]
+  int block_size,
+  int max_seq_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype,
+  float kv_scale) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, __mt_bfloat16, false);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else if (kv_cache_dtype == "fp8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V1_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, uint8_t, true);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
+
+// Launch the two-pass (v2) paged attention for one compile-time HEAD_SIZE:
+// first the partitioned attention kernel, then the cross-partition reduce.
+// NOTE(review): both kernels' template argument lists and <<<...>>>
+// execution configurations restored after angle-bracket stripping.
+#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
+  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
+  IS_FP8_KV_CACHE, PARTITION_SIZE>                                                            \
+  <<<grid, block, shared_mem_size, stream>>>(                                                 \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    query_ptr,                                                                                \
+    key_cache_ptr,                                                                            \
+    value_cache_ptr,                                                                          \
+    num_kv_heads,                                                                             \
+    scale,                                                                                    \
+    block_tables_ptr,                                                                         \
+    seq_lens_ptr,                                                                             \
+    max_num_blocks_per_seq,                                                                   \
+    alibi_slopes_ptr,                                                                         \
+    q_stride,                                                                                 \
+    kv_block_stride,                                                                          \
+    kv_head_stride,                                                                           \
+    kv_scale);                                                                                \
+  vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, PARTITION_SIZE>           \
+  <<<reduce_grid, block, reduce_shared_mem_size, stream>>>(                                   \
+    out_ptr,                                                                                  \
+    exp_sums_ptr,                                                                             \
+    max_logits_ptr,                                                                           \
+    tmp_out_ptr,                                                                              \
+    seq_lens_ptr,                                                                             \
+    max_num_partitions);
+
+// Host-side launcher for the two-pass (v2) paged attention kernels.
+// Splits each sequence into PARTITION_SIZE-token partitions (grid.z), then
+// reduces partial results across partitions in a second kernel.
+// NOTE(review): template argument lists restored after angle-bracket
+// stripping (same pattern as the v1 launcher).
+template<
+  typename T,
+  typename CACHE_T,
+  int BLOCK_SIZE,
+  bool IS_FP8_KV_CACHE,
+  int NUM_THREADS = 128,
+  int PARTITION_SIZE = 512>
+void paged_attention_v2_launcher(
+  torch::Tensor& out,
+  torch::Tensor& exp_sums,
+  torch::Tensor& max_logits,
+  torch::Tensor& tmp_out,
+  torch::Tensor& query,
+  torch::Tensor& key_cache,
+  torch::Tensor& value_cache,
+  int num_kv_heads,
+  float scale,
+  torch::Tensor& block_tables,
+  torch::Tensor& seq_lens,
+  int max_seq_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  float kv_scale) {
+  int num_seqs = query.size(0);
+  int num_heads = query.size(1);
+  int head_size = query.size(2);
+  int max_num_blocks_per_seq = block_tables.size(1);
+  int q_stride = query.stride(0);
+  int kv_block_stride = key_cache.stride(0);
+  int kv_head_stride = key_cache.stride(1);
+
+  int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
+  assert(head_size % thread_group_size == 0);
+
+  // NOTE: alibi_slopes is optional.
+  const float* alibi_slopes_ptr = alibi_slopes ?
+    reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
+    : nullptr;
+
+  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
+  float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
+  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
+  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
+  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
+  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
+  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
+  int* block_tables_ptr = block_tables.data_ptr<int>();
+  int* seq_lens_ptr = seq_lens.data_ptr<int>();
+
+  constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
+  int logits_size = PARTITION_SIZE * sizeof(float);
+  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+
+  // For paged attention v2 kernel.
+  dim3 grid(num_heads, num_seqs, max_num_partitions);
+  int shared_mem_size = std::max(logits_size, outputs_size);
+  // For paged attention v2 reduce kernel.
+  dim3 reduce_grid(num_heads, num_seqs);
+  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+
+  dim3 block(NUM_THREADS);
+  const at::musa::OptionalMUSAGuard device_guard(device_of(query));
+  const musaStream_t stream = at::musa::getCurrentMUSAStream();
+  switch (head_size) {
+    // NOTE(woosuk): To reduce the compilation time, we only compile for the
+    // head sizes that we use in the model. However, we can easily extend this
+    // to support any head size which is a multiple of 16.
+    case 64:
+      LAUNCH_PAGED_ATTENTION_V2(64);
+      break;
+    case 80:
+      LAUNCH_PAGED_ATTENTION_V2(80);
+      break;
+    case 96:
+      LAUNCH_PAGED_ATTENTION_V2(96);
+      break;
+    case 112:
+      LAUNCH_PAGED_ATTENTION_V2(112);
+      break;
+    case 128:
+      LAUNCH_PAGED_ATTENTION_V2(128);
+      break;
+    case 256:
+      LAUNCH_PAGED_ATTENTION_V2(256);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size: ", head_size);
+      break;
+  }
+}
+
+// Forward the runtime arguments to the templated v2 launcher.
+// NOTE(review): the launcher's template argument list was lost to
+// angle-bracket stripping; restored.
+#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE)       \
+  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE>( \
+    out,                                                                \
+    exp_sums,                                                           \
+    max_logits,                                                         \
+    tmp_out,                                                            \
+    query,                                                              \
+    key_cache,                                                          \
+    value_cache,                                                        \
+    num_kv_heads,                                                       \
+    scale,                                                              \
+    block_tables,                                                       \
+    seq_lens,                                                           \
+    max_seq_len,                                                        \
+    alibi_slopes,                                                       \
+    kv_scale);
+
+// Dispatch the runtime `block_size` to the compile-time BLOCK_SIZE template
+// parameter of paged_attention_v2_launcher.
+// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
+// 1, 2, 4, 64, 128, 256.
+#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_KV_CACHE)    \
+  switch (block_size) {                                             \
+    case 8:                                                         \
+      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_KV_CACHE);             \
+      break;                                                        \
+    case 16:                                                        \
+      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_KV_CACHE);            \
+      break;                                                        \
+    case 32:                                                        \
+      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_KV_CACHE);            \
+      break;                                                        \
+    default:                                                        \
+      TORCH_CHECK(false, "Unsupported block size: ", block_size);   \
+      break;                                                        \
+  }
+
+// Entry point bound to Python: two-pass paged attention (v2), for long
+// sequences where a single thread block per head is insufficient.
+// NOTE(review): `c10::optional<torch::Tensor>` restored after
+// angle-bracket stripping.
+void paged_attention_v2(
+  torch::Tensor& out,             // [num_seqs, num_heads, head_size]
+  torch::Tensor& exp_sums,        // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& max_logits,      // [num_seqs, num_heads, max_num_partitions]
+  torch::Tensor& tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
+  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
+  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
+  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
+  int num_kv_heads,               // [num_heads]
+  float scale,
+  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
+  torch::Tensor& seq_lens,        // [num_seqs]
+  int block_size,
+  int max_seq_len,
+  const c10::optional<torch::Tensor>& alibi_slopes,
+  const std::string& kv_cache_dtype,
+  float kv_scale) {
+  if (kv_cache_dtype == "auto") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, __mt_bfloat16, false);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else if (kv_cache_dtype == "fp8") {
+    if (query.dtype() == at::ScalarType::Float) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::Half) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
+    } else if (query.dtype() == at::ScalarType::BFloat16) {
+      CALL_V2_LAUNCHER_BLOCK_SIZE(__mt_bfloat16, uint8_t, true);
+    } else {
+      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
+}
+
+#undef WARP_SIZE
+#undef MAX
+#undef MIN
+#undef DIVIDE_ROUND_UP
diff --git a/csrc_musa/attention/attention_utils.muh b/csrc_musa/attention/attention_utils.muh
new file mode 100644
index 0000000..8993a64
--- /dev/null
+++ b/csrc_musa/attention/attention_utils.muh
@@ -0,0 +1,57 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "../musa_compat.h"
+#include "attention_dtypes.h"
+
+// NOTE(review): stripped system-header names restored.
+#include <float.h>
+#include <type_traits>
+
+namespace vllm {
+
+// Q*K^T operation.
+// NOTE(review): template headers and explicit template arguments restored
+// after angle-bracket stripping (matches upstream vLLM attention_utils).
+template<int THREAD_GROUP_SIZE, typename Vec, int N>
+inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
+  using A_vec = typename FloatVec<Vec>::Type;
+  // Compute the parallel products for Q*K^T (treat vector lanes separately).
+  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
+#pragma unroll
+  for (int ii = 1; ii < N; ++ii) {
+    qk_vec = fma(q[ii], k[ii], qk_vec);
+  }
+
+  // Finalize the reduction across lanes.
+  float qk = sum(qk_vec);
+#pragma unroll
+  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
+    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
+  }
+  return qk;
+}
+
+template<typename T, int THREAD_GROUP_SIZE>
+struct Qk_dot {
+  template<typename Vec, int N>
+  static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) {
+    return qk_dot_<THREAD_GROUP_SIZE>(q, k);
+  }
+};
+
+} // namespace vllm
diff --git a/csrc_musa/attention/dtype_bfloat16.muh b/csrc_musa/attention/dtype_bfloat16.muh
new file mode 100644
index 0000000..5526476
--- /dev/null
+++ b/csrc_musa/attention/dtype_bfloat16.muh
@@ -0,0 +1,452 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.muh"
+#include "dtype_float32.muh"
+
+#ifndef USE_ROCM
+  // NOTE(review): stripped header names restored; presumably the MUSA
+  // bf16/fp16 intrinsic headers — confirm against the MUSA toolkit.
+  #include <musa_bf16.h>
+  #include <musa_fp16.h>
+#else
+  #include <hip/hip_bf16.h>
+  #include <hip/hip_bfloat16.h>
+
+  typedef __hip_bfloat162 __mt_bfloat162;
+  typedef __hip_bfloat16 __mt_bfloat16;
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// Define custom BF16 vector data types.
+struct bf16_4_t {
+  __mt_bfloat162 x;
+  __mt_bfloat162 y;
+};
+
+struct bf16_8_t {
+  __mt_bfloat162 x;
+  __mt_bfloat162 y;
+  __mt_bfloat162 z;
+  __mt_bfloat162 w;
+};
+
+// BF16 vector types for Q, K, V.
+template<>
+struct Vec<__mt_bfloat16, 1> {
+  using Type = __mt_bfloat16;
+};
+template<>
+struct Vec<__mt_bfloat16, 2> {
+  using Type = __mt_bfloat162;
+};
+template<>
+struct Vec<__mt_bfloat16, 4> {
+  using Type = bf16_4_t;
+};
+template<>
+struct Vec<__mt_bfloat16, 8> {
+  using Type = bf16_8_t;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+// NOTE(review): the bf16_4_t/bf16_8_t specialization arguments were lost to
+// angle-bracket stripping; restored.
+template<>
+struct FloatVec<__mt_bfloat16> {
+  using Type = float;
+};
+template<>
+struct FloatVec<__mt_bfloat162> {
+  using Type = float2;
+};
+template<>
+struct FloatVec<bf16_4_t> {
+  using Type = Float4_;
+};
+template<>
+struct FloatVec<bf16_8_t> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+
+// Widen a packed pair of bf16 values to float2. The arch guard traps on
+// targets below arch 800, where the bf16 intrinsics are unavailable.
+inline __device__ float2 bf1622float2(const __mt_bfloat162 val) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __bfloat1622float2(val);
+#endif
+}
+
+// Broadcast one bf16 value into both lanes of a bf16x2.
+inline __device__ __mt_bfloat162 bf162bf162(const __mt_bfloat16 val) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __bfloat162bfloat162(val);
+#endif
+}
+
+// Vector addition.
+// Elementwise add overloads for bf16 scalars, bf16x2 pairs, and the custom
+// 4-/8-wide bf16 vectors; mixed overloads accumulate into fp32 vectors.
+inline __device__ __mt_bfloat16 add(__mt_bfloat16 a, __mt_bfloat16 b) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  #ifndef USE_ROCM
+    return a + b;
+  #else
+    return __hadd(a, b);
+  #endif
+#endif
+}
+
+inline __device__ __mt_bfloat162 add(__mt_bfloat162 a, __mt_bfloat162 b) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __hadd2(a, b);
+#endif
+}
+
+inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+// bf16 + fp32 accumulator variants: widen the bf16 operand, then add in fp32.
+inline __device__ float2 add(__mt_bfloat162 a, float2 fb) {
+  float2 fa = bf1622float2(a);
+  return add(fa, fb);
+}
+
+inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) {
+  Float4_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  return fc;
+}
+
+inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) {
+  Float8_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  fc.z = add(a.z, fb.z);
+  fc.w = add(a.w, fb.w);
+  return fc;
+}
+
+// Vector multiplication.
+// Specializations of the generic mul<Acc, A, B>() template for bf16 types.
+// NOTE(review): explicit template argument lists beginning with `float2`
+// were lost to angle-bracket stripping; restored to match upstream vLLM
+// dtype_bfloat16 (`<__mt_...>` spans survived because they start with '_').
+template<>
+inline __device__ __mt_bfloat16 mul(__mt_bfloat16 a, __mt_bfloat16 b) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __hmul(a, b);
+#endif
+}
+
+template<>
+inline __device__ __mt_bfloat162 mul(__mt_bfloat162 a, __mt_bfloat162 b) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __hmul2(a, b);
+#endif
+}
+
+template<>
+inline __device__ __mt_bfloat162 mul(__mt_bfloat16 a, __mt_bfloat162 b) {
+  return mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(bf162bf162(a), b);
+}
+
+template<>
+inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) {
+  bf16_4_t c;
+  c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
+  c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
+  return c;
+}
+
+template<>
+inline __device__ bf16_4_t mul(__mt_bfloat16 a, bf16_4_t b) {
+  __mt_bfloat162 s = bf162bf162(a);
+  bf16_4_t c;
+  c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.x);
+  c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.y);
+  return c;
+}
+
+template<>
+inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) {
+  bf16_8_t c;
+  c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
+  c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
+  c.z = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.z, b.z);
+  c.w = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(a.w, b.w);
+  return c;
+}
+
+template<>
+inline __device__ bf16_8_t mul(__mt_bfloat16 a, bf16_8_t b) {
+  __mt_bfloat162 s = bf162bf162(a);
+  bf16_8_t c;
+  c.x = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.x);
+  c.y = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.y);
+  c.z = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.z);
+  c.w = mul<__mt_bfloat162, __mt_bfloat162, __mt_bfloat162>(s, b.w);
+  return c;
+}
+
+// fp32-accumulating variants: widen to float, multiply in fp32.
+template<>
+inline __device__ float mul(__mt_bfloat16 a, __mt_bfloat16 b) {
+  float fa = __bfloat162float(a);
+  float fb = __bfloat162float(b);
+  return fa * fb;
+}
+
+template<>
+inline __device__ float2 mul(__mt_bfloat162 a, __mt_bfloat162 b) {
+  float2 fa = bf1622float2(a);
+  float2 fb = bf1622float2(b);
+  return mul<float2, float2, float2>(fa, fb);
+}
+
+template<>
+inline __device__ float2 mul(__mt_bfloat16 a, __mt_bfloat162 b) {
+  return mul<float2, __mt_bfloat162, __mt_bfloat162>(bf162bf162(a), b);
+}
+
+template<>
+inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
+  Float4_ fc;
+  fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
+  fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
+  return fc;
+}
+
+template<>
+inline __device__ Float4_ mul(__mt_bfloat16 a, bf16_4_t b) {
+  __mt_bfloat162 s = bf162bf162(a);
+  Float4_ fc;
+  fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.x);
+  fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.y);
+  return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
+  Float8_ fc;
+  fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.x, b.x);
+  fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.y, b.y);
+  fc.z = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.z, b.z);
+  fc.w = mul<float2, __mt_bfloat162, __mt_bfloat162>(a.w, b.w);
+  return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(__mt_bfloat16 a, bf16_8_t b) {
+  __mt_bfloat162 s = bf162bf162(a);
+  Float8_ fc;
+  fc.x = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.x);
+  fc.y = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.y);
+  fc.z = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.z);
+  fc.w = mul<float2, __mt_bfloat162, __mt_bfloat162>(s, b.w);
+  return fc;
+}
+
+// Vector fused multiply-add.
+// fma overloads for bf16x2 and the 4-/8-wide bf16 vectors; the Float*_
+// overloads widen to fp32 and accumulate there.
+inline __device__ __mt_bfloat162 fma(__mt_bfloat162 a, __mt_bfloat162 b, __mt_bfloat162 c) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __hfma2(a, b, c);
+#endif
+}
+
+inline __device__ __mt_bfloat162 fma(__mt_bfloat16 a, __mt_bfloat162 b, __mt_bfloat162 c) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  return __hfma2(bf162bf162(a), b, c);
+#endif
+}
+
+inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
+  bf16_4_t d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ bf16_4_t fma(__mt_bfloat16 a, bf16_4_t b, bf16_4_t c) {
+  __mt_bfloat162 s = bf162bf162(a);
+  bf16_4_t d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  return d;
+}
+
+inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) {
+  bf16_8_t d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ bf16_8_t fma(__mt_bfloat16 a, bf16_8_t b, bf16_8_t c) {
+  __mt_bfloat162 s = bf162bf162(a);
+  bf16_8_t d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  d.z = fma(s, b.z, c.z);
+  d.w = fma(s, b.w, c.w);
+  return d;
+}
+
+// fp32-accumulating variants.
+inline __device__ float fma(__mt_bfloat16 a, __mt_bfloat16 b, float fc) {
+  return __bfloat162float(a) * __bfloat162float(b) + fc;
+}
+
+inline __device__ float2 fma(__mt_bfloat162 a, __mt_bfloat162 b, float2 fc) {
+  float2 fa = bf1622float2(a);
+  float2 fb = bf1622float2(b);
+  return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(__mt_bfloat16 a, __mt_bfloat162 b, float2 fc) {
+  return fma(bf162bf162(a), b, fc);
+}
+
+inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) {
+  Float4_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float4_ fma(__mt_bfloat16 a, bf16_4_t b, Float4_ fc) {
+  __mt_bfloat162 s = bf162bf162(a);
+  Float4_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) {
+  Float8_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  fd.z = fma(a.z, b.z, fc.z);
+  fd.w = fma(a.w, b.w, fc.w);
+  return fd;
+}
+
+inline __device__ Float8_ fma(__mt_bfloat16 a, bf16_8_t b, Float8_ fc) {
+  __mt_bfloat162 s = bf162bf162(a);
+  Float8_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  fd.z = fma(s, b.z, fc.z);
+  fd.w = fma(s, b.w, fc.w);
+  return fd;
+}
+
+// Vector sum.
+// Horizontal reduction of a bf16 vector to a single float.
+template<>
+inline __device__ float sum(__mt_bfloat16 v) {
+  return __bfloat162float(v);
+}
+
+template<>
+inline __device__ float sum(__mt_bfloat162 v) {
+  float2 vf = bf1622float2(v);
+  return vf.x + vf.y;
+}
+
+template<>
+inline __device__ float sum(bf16_4_t v) {
+  return sum(v.x) + sum(v.y);
+}
+
+template<>
+inline __device__ float sum(bf16_8_t v) {
+  return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w);
+}
+
+// From float32 to bfloat16 (round-to-nearest).
+inline __device__ void from_float(__mt_bfloat16& dst, float src) {
+  dst = __float2bfloat16(src);
+}
+
+inline __device__ void from_float(__mt_bfloat162& dst, float2 src) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  dst = __float22bfloat162_rn(src);
+#endif
+}
+
+inline __device__ void from_float(bf16_4_t& dst, Float4_ src) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  dst.x = __float22bfloat162_rn(src.x);
+  dst.y = __float22bfloat162_rn(src.y);
+#endif
+}
+
+inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  dst.x = __float22bfloat162_rn(src.x);
+  dst.y = __float22bfloat162_rn(src.y);
+  dst.z = __float22bfloat162_rn(src.z);
+  dst.w = __float22bfloat162_rn(src.w);
+#endif
+}
+
+// From bfloat16 to float32.
+inline __device__ float to_float(__mt_bfloat16 u) {
+  return __bfloat162float(u);
+}
+
+// Zero-out a variable.
+inline __device__ void zero(__mt_bfloat16& dst) {
+#if defined(__MUSA_ARCH__) && __MUSA_ARCH__ < 800
+  assert(false);
+#else
+  // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2.
+  dst = __ushort_as_bfloat16((unsigned short)0x0000U);
+#endif
+}
+
+} // namespace vllm
diff --git a/csrc_musa/attention/dtype_float16.muh b/csrc_musa/attention/dtype_float16.muh
new file mode 100644
index 0000000..ba4de2b
--- /dev/null
+++ b/csrc_musa/attention/dtype_float16.muh
@@ -0,0 +1,503 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.muh"
+#include "dtype_float32.muh"
+
+#ifdef USE_ROCM
+  // NOTE(review): stripped header name restored (upstream vLLM includes the
+  // HIP fp16 intrinsics here).
+  #include <hip/hip_fp16.h>
+#endif
+
+#include <stdint.h>
+
+namespace vllm {
+
+// FP16 vector types for Q, K, V.
+// FP16 values are carried in raw uint16_t/uint32_t/uint2/uint4 registers and
+// operated on with inline asm below.
+// NOTE(review): every specialization argument list here was lost to
+// angle-bracket stripping; restored from upstream vLLM dtype_float16.
+template<>
+struct Vec<uint16_t, 1> {
+  using Type = uint16_t;
+};
+template<>
+struct Vec<uint16_t, 2> {
+  using Type = uint32_t;
+};
+template<>
+struct Vec<uint16_t, 4> {
+  using Type = uint2;
+};
+template<>
+struct Vec<uint16_t, 8> {
+  using Type = uint4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec<uint16_t> {
+  using Type = float;
+};
+template<>
+struct FloatVec<uint32_t> {
+  using Type = float2;
+};
+template<>
+struct FloatVec<uint2> {
+  using Type = Float4_;
+};
+template<>
+struct FloatVec<uint4> {
+  using Type = Float8_;
+};
+
+// Utility functions for type conversions.
+// fp16 payloads live in integer registers; conversions use PTX-style asm on
+// the non-ROCm path and GCN asm / unions on ROCm.
+
+// Duplicate one fp16 value into both halves of a 32-bit register.
+inline __device__ uint32_t h0_h0(uint16_t a) {
+#ifndef USE_ROCM
+  uint32_t b;
+  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
+  return b;
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u16[0] = a;
+  tmp.u16[1] = a;
+  return tmp.u32;
+#endif
+}
+
+inline __device__ float half_to_float(uint16_t h) {
+  float f;
+#ifndef USE_ROCM
+  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+#else
+  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
+#endif
+  return f;
+}
+
+// Unpack an fp16 pair into a float2.
+inline __device__ float2 half2_to_float2(uint32_t v) {
+#ifndef USE_ROCM
+  uint16_t lo, hi;
+  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
+  return make_float2(half_to_float(lo), half_to_float(hi));
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u32 = v;
+  float2 ret;
+  ret.x = half_to_float(tmp.u16[0]);
+  ret.y = half_to_float(tmp.u16[1]);
+  return ret;
+#endif
+}
+
+// Round-to-nearest float -> fp16.
+inline __device__ uint16_t float_to_half(float f) {
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+#ifndef USE_ROCM
+  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#else
+  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
+#endif
+  return tmp.u16[0];
+}
+
+// Pack a float2 into an fp16 pair; uses the fused f16x2 convert when the
+// target architecture supports it.
+inline __device__ uint32_t float2_to_half2(float2 f) {
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+#ifndef USE_ROCM
+  #if defined(__MUSA_ARCH__) && __MUSA_ARCH__ >= 800
+    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
+  #else
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
+    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
+  #endif
+#else
+  tmp.u16[0] = float_to_half(f.x);
+  tmp.u16[1] = float_to_half(f.y);
+#endif
+  return tmp.u32;
+}
+
+// Vector addition.
+// Elementwise fp16 add overloads on raw integer registers; Float*_
+// overloads widen to fp32 and accumulate there.
+inline __device__ uint16_t add(uint16_t a, uint16_t b) {
+  uint16_t c;
+#ifndef USE_ROCM
+  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+  return c;
+}
+
+inline __device__ uint32_t add(uint32_t a, uint32_t b) {
+  uint32_t c;
+#ifndef USE_ROCM
+  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+  return c;
+}
+
+inline __device__ uint2 add(uint2 a, uint2 b) {
+  uint2 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  return c;
+}
+
+inline __device__ uint4 add(uint4 a, uint4 b) {
+  uint4 c;
+  c.x = add(a.x, b.x);
+  c.y = add(a.y, b.y);
+  c.z = add(a.z, b.z);
+  c.w = add(a.w, b.w);
+  return c;
+}
+
+inline __device__ float2 add(uint32_t a, float2 fb) {
+  float2 fa = half2_to_float2(a);
+  return add(fa, fb);
+}
+
+inline __device__ Float4_ add(uint2 a, Float4_ fb) {
+  Float4_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  return fc;
+}
+
+inline __device__ Float8_ add(uint4 a, Float8_ fb) {
+  Float8_ fc;
+  fc.x = add(a.x, fb.x);
+  fc.y = add(a.y, fb.y);
+  fc.z = add(a.z, fb.z);
+  fc.w = add(a.w, fb.w);
+  return fc;
+}
+
+// Vector multiplication.
+// Specializations of the generic mul<Acc, A, B>() template for fp16 carried
+// in integer registers. NOTE(review): explicit template argument lists
+// (`<uint32_t, ...>`, `<float2, ...>`) were lost to angle-bracket stripping;
+// restored from upstream vLLM dtype_float16.
+template<>
+inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
+  uint16_t c;
+#ifndef USE_ROCM
+  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+  return c;
+}
+
+template<>
+inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
+  uint32_t c;
+#ifndef USE_ROCM
+  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
+  return c;
+}
+
+template<>
+inline __device__ uint32_t mul(uint16_t a, uint32_t b) {
+  return mul<uint32_t, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template<>
+inline __device__ uint2 mul(uint2 a, uint2 b) {
+  uint2 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  return c;
+}
+
+template<>
+inline __device__ uint2 mul(uint16_t a, uint2 b) {
+  uint32_t s = h0_h0(a);
+  uint2 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+  return c;
+}
+
+template<>
+inline __device__ uint4 mul(uint4 a, uint4 b) {
+  uint4 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(a.x, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(a.y, b.y);
+  c.z = mul<uint32_t, uint32_t, uint32_t>(a.z, b.z);
+  c.w = mul<uint32_t, uint32_t, uint32_t>(a.w, b.w);
+  return c;
+}
+
+template<>
+inline __device__ uint4 mul(uint16_t a, uint4 b) {
+  uint32_t s = h0_h0(a);
+  uint4 c;
+  c.x = mul<uint32_t, uint32_t, uint32_t>(s, b.x);
+  c.y = mul<uint32_t, uint32_t, uint32_t>(s, b.y);
+  c.z = mul<uint32_t, uint32_t, uint32_t>(s, b.z);
+  c.w = mul<uint32_t, uint32_t, uint32_t>(s, b.w);
+  return c;
+}
+
+// fp32-accumulating variants: widen to float, multiply in fp32.
+template<>
+inline __device__ float mul(uint16_t a, uint16_t b) {
+  float fa = half_to_float(a);
+  float fb = half_to_float(b);
+  return fa * fb;
+}
+
+template<>
+inline __device__ float2 mul(uint32_t a, uint32_t b) {
+  float2 fa = half2_to_float2(a);
+  float2 fb = half2_to_float2(b);
+  return mul<float2, float2, float2>(fa, fb);
+}
+
+template<>
+inline __device__ float2 mul(uint16_t a, uint32_t b) {
+  return mul<float2, uint32_t, uint32_t>(h0_h0(a), b);
+}
+
+template<>
+inline __device__ Float4_ mul(uint2 a, uint2 b) {
+  Float4_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+  return fc;
+}
+
+template<>
+inline __device__ Float4_ mul(uint16_t a, uint2 b) {
+  uint32_t s = h0_h0(a);
+  Float4_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+  return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(uint4 a, uint4 b) {
+  Float8_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(a.x, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(a.y, b.y);
+  fc.z = mul<float2, uint32_t, uint32_t>(a.z, b.z);
+  fc.w = mul<float2, uint32_t, uint32_t>(a.w, b.w);
+  return fc;
+}
+
+template<>
+inline __device__ Float8_ mul(uint16_t a, uint4 b) {
+  uint32_t s = h0_h0(a);
+  Float8_ fc;
+  fc.x = mul<float2, uint32_t, uint32_t>(s, b.x);
+  fc.y = mul<float2, uint32_t, uint32_t>(s, b.y);
+  fc.z = mul<float2, uint32_t, uint32_t>(s, b.z);
+  fc.w = mul<float2, uint32_t, uint32_t>(s, b.w);
+  return fc;
+}
+
+// Vector fused multiply-add.
+// fp16 fma overloads on integer registers; fp32-accumulating variants widen
+// first and accumulate in float.
+inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
+  uint32_t d;
+#ifndef USE_ROCM
+  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+#else
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+#endif
+  return d;
+}
+
+inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
+  return fma(h0_h0(a), b, c);
+}
+
+inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) {
+  uint2 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  return d;
+}
+
+inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) {
+  uint32_t s = h0_h0(a);
+  uint2 d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  return d;
+}
+
+inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) {
+  uint4 d;
+  d.x = fma(a.x, b.x, c.x);
+  d.y = fma(a.y, b.y, c.y);
+  d.z = fma(a.z, b.z, c.z);
+  d.w = fma(a.w, b.w, c.w);
+  return d;
+}
+
+inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) {
+  uint32_t s = h0_h0(a);
+  uint4 d;
+  d.x = fma(s, b.x, c.x);
+  d.y = fma(s, b.y, c.y);
+  d.z = fma(s, b.z, c.z);
+  d.w = fma(s, b.w, c.w);
+  return d;
+}
+
+// fp32-accumulating variants.
+inline __device__ float fma(uint16_t a, uint16_t b, float fc) {
+  float fa = half_to_float(a);
+  float fb = half_to_float(b);
+  return fa * fb + fc;
+}
+
+inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) {
+  float2 fa = half2_to_float2(a);
+  float2 fb = half2_to_float2(b);
+  return fma(fa, fb, fc);
+}
+
+inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) {
+  return fma(h0_h0(a), b, fc);
+}
+
+inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) {
+  Float4_ fd;
+  fd.x = fma(a.x, b.x, fc.x);
+  fd.y = fma(a.y, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) {
+  uint32_t s = h0_h0(a);
+  Float4_ fd;
+  fd.x = fma(s, b.x, fc.x);
+  fd.y = fma(s, b.y, fc.y);
+  return fd;
+}
+
+inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) {
+ Float8_ fd;
+ fd.x = fma(a.x, b.x, fc.x);
+ fd.y = fma(a.y, b.y, fc.y);
+ fd.z = fma(a.z, b.z, fc.z);
+ fd.w = fma(a.w, b.w, fc.w);
+ return fd;
+}
+
+inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) {
+ uint32_t s = h0_h0(a);
+ Float8_ fd;
+ fd.x = fma(s, b.x, fc.x);
+ fd.y = fma(s, b.y, fc.y);
+ fd.z = fma(s, b.z, fc.z);
+ fd.w = fma(s, b.w, fc.w);
+ return fd;
+}
+
+// Vector sum.
+// Horizontal sum of all half lanes, returned as fp32. Wider vectors first
+// reduce pairwise in fp16 with the packed add(), then finish through the
+// uint32_t (half2) case. Note the pairwise fp16 reduction can lose
+// precision relative to summing in fp32.
+// NOTE(review): template-argument lists of these specializations (e.g.
+// `sum<uint16_t>`) appear stripped by patch extraction.
+template<>
+inline __device__ float sum(uint16_t v) {
+ return half_to_float(v);
+}
+
+template<>
+inline __device__ float sum(uint32_t v) {
+ float2 tmp = half2_to_float2(v);
+ return tmp.x + tmp.y;
+}
+
+template<>
+inline __device__ float sum(uint2 v) {
+ uint32_t c = add(v.x, v.y);
+ return sum(c);
+}
+
+template<>
+inline __device__ float sum(uint4 v) {
+ uint32_t c = add(v.x, v.y);
+ c = add(c, v.z);
+ c = add(c, v.w);
+ return sum(c);
+}
+
+// From float32 to float16.
+// Writers: convert fp32 scalars/vectors into the packed-half containers,
+// one float2 -> half2 (uint32_t) lane pair at a time.
+inline __device__ void from_float(uint16_t& dst, float src) {
+ dst = float_to_half(src);
+}
+
+inline __device__ void from_float(uint32_t& dst, float2 src) {
+ dst = float2_to_half2(src);
+}
+
+inline __device__ void from_float(uint2& dst, Float4_ src) {
+ dst.x = float2_to_half2(src.x);
+ dst.y = float2_to_half2(src.y);
+}
+
+inline __device__ void from_float(uint4& dst, Float8_ src) {
+ dst.x = float2_to_half2(src.x);
+ dst.y = float2_to_half2(src.y);
+ dst.z = float2_to_half2(src.z);
+ dst.w = float2_to_half2(src.w);
+}
+
+// From float16 to float32.
+// Readers: widen packed halves back to fp32 scalars/vectors.
+inline __device__ float to_float(uint16_t u) {
+ return half_to_float(u);
+}
+
+inline __device__ float2 to_float(uint32_t u) {
+ return half2_to_float2(u);
+}
+
+inline __device__ Float4_ to_float(uint2 u) {
+ Float4_ tmp;
+ tmp.x = half2_to_float2(u.x);
+ tmp.y = half2_to_float2(u.y);
+ return tmp;
+}
+
+inline __device__ Float8_ to_float(uint4 u) {
+ Float8_ tmp;
+ tmp.x = half2_to_float2(u.x);
+ tmp.y = half2_to_float2(u.y);
+ tmp.z = half2_to_float2(u.z);
+ tmp.w = half2_to_float2(u.w);
+ return tmp;
+}
+
+// Zero-out a variable (scalar half case; bit pattern 0 is +0.0h).
+inline __device__ void zero(uint16_t& dst) {
+ dst = uint16_t(0);
+}
+
+} // namespace vllm
diff --git a/csrc_musa/attention/dtype_float32.muh b/csrc_musa/attention/dtype_float32.muh
new file mode 100644
index 0000000..7eaffc3
--- /dev/null
+++ b/csrc_musa/attention/dtype_float32.muh
@@ -0,0 +1,274 @@
+/*
+ * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
+ * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
+ * Copyright (c) 2024 - 2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
+ * Copyright (c) 2023, The vLLM team.
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "attention_generic.muh"
+
+#include
+
+namespace vllm {
+
+// Define custom FP32 vector data types.
+// Float4_ packs 4 floats as two float2 halves; Float8_ packs 8 floats as
+// four float2 quarters. These are the fp32 accumulator shapes used by the
+// mul/fma/sum helpers in the attention dtype headers.
+struct Float4_ {
+ float2 x;
+ float2 y;
+};
+
+struct Float8_ {
+ float2 x;
+ float2 y;
+ float2 z;
+ float2 w;
+};
+
+// FP32 vector types for Q, K, V.
+// FP32 vector types for Q, K, V.
+// NOTE(review): the template-argument lists of these specializations
+// (upstream they read `Vec<float, 1/2/4>` and `FloatVec<float/float2/
+// float4>`) appear to have been stripped during patch extraction; as
+// written the specializations name no arguments -- restore from the
+// upstream vLLM dtype_float32.cuh.
+template<>
+struct Vec {
+ using Type = float;
+};
+template<>
+struct Vec {
+ using Type = float2;
+};
+template<>
+struct Vec {
+ using Type = float4;
+};
+
+// FP32 accumulator vector types corresponding to Vec.
+template<>
+struct FloatVec {
+ using Type = float;
+};
+template<>
+struct FloatVec {
+ using Type = float2;
+};
+template<>
+struct FloatVec {
+ using Type = float4;
+};
+
+// Vector addition.
+inline __device__ float add(float a, float b) {
+ return a + b;
+}
+
+inline __device__ float2 add(float2 a, float2 b) {
+ float2 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ return c;
+}
+
+inline __device__ float4 add(float4 a, float4 b) {
+ float4 c;
+ c.x = add(a.x, b.x);
+ c.y = add(a.y, b.y);
+ c.z = add(a.z, b.z);
+ c.w = add(a.w, b.w);
+ return c;
+}
+
+// Vector multiplication.
+template<>
+inline __device__ float mul(float a, float b) {
+ return a * b;
+}
+
+template<>
+inline __device__ float2 mul(float2 a, float2 b) {
+ float2 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ return c;
+}
+
+template<>
+inline __device__ float2 mul(float a, float2 b) {
+ float2 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ return c;
+}
+
+template<>
+inline __device__ float4 mul(float4 a, float4 b) {
+ float4 c;
+ c.x = a.x * b.x;
+ c.y = a.y * b.y;
+ c.z = a.z * b.z;
+ c.w = a.w * b.w;
+ return c;
+}
+
+template<>
+inline __device__ float4 mul(float a, float4 b) {
+ float4 c;
+ c.x = a * b.x;
+ c.y = a * b.y;
+ c.z = a * b.z;
+ c.w = a * b.w;
+ return c;
+}
+
+// Vector fused multiply-add.
+inline __device__ float fma(float a, float b, float c) {
+ return a * b + c;
+}
+
+inline __device__ float2 fma(float2 a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ return d;
+}
+
+inline __device__ float2 fma(float a, float2 b, float2 c) {
+ float2 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+}
+
+inline __device__ float4 fma(float4 a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a.x, b.x, c.x);
+ d.y = fma(a.y, b.y, c.y);
+ d.z = fma(a.z, b.z, c.z);
+ d.w = fma(a.w, b.w, c.w);
+ return d;
+}
+
+inline __device__ float4 fma(float a, float4 b, float4 c) {
+ float4 d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+}
+
+inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) {
+ Float4_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ return d;
+}
+
+inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) {
+ Float8_ d;
+ d.x = fma(a, b.x, c.x);
+ d.y = fma(a, b.y, c.y);
+ d.z = fma(a, b.z, c.z);
+ d.w = fma(a, b.w, c.w);
+ return d;
+}
+
+// Vector sum.
+template<>
+inline __device__ float sum(float v) {
+ return v;
+}
+
+template<>
+inline __device__ float sum(float2 v) {
+ return v.x + v.y;
+}
+
+template<>
+inline __device__ float sum(float4 v) {
+ return v.x + v.y + v.z + v.w;
+}
+
+template<>
+inline __device__ float sum(Float4_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y;
+}
+
+template<>
+inline __device__ float sum(Float8_ v) {
+ return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
+}
+
+// Vector dot product.
+inline __device__ float dot(float a, float b) {
+ return a * b;
+}
+
+inline __device__ float dot(float2 a, float2 b) {
+ float2 c = mul(a, b);
+ return c.x + c.y;
+}
+
+inline __device__ float dot(Float4_ a, Float4_ b) {
+ float2 acc = mul(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ return acc.x + acc.y;
+}
+
+inline __device__ float dot(Float8_ a, Float8_ b) {
+ float2 acc = mul(a.x, b.x);
+ acc = fma(a.y, b.y, acc);
+ acc = fma(a.z, b.z, acc);
+ acc = fma(a.w, b.w, acc);
+ return acc.x + acc.y;
+}
+
+// From float to float.
+inline __device__ void from_float(float& dst, float src) {
+ dst = src;
+}
+
+inline __device__ void from_float(float2& dst, float2 src) {
+ dst = src;
+}
+
+inline __device__ void from_float(float4& dst, float4 src) {
+ dst = src;
+}
+
+// From float to float.
+inline __device__ float to_float(float u) {
+ return u;
+}
+
+inline __device__ float2 to_float(float2 u) {
+ return u;
+}
+
+inline __device__ float4 to_float(float4 u) {
+ return u;
+}
+
+inline __device__ Float4_ to_float(Float4_ u) {
+ return u;
+}
+
+inline __device__ Float8_ to_float(Float8_ u) {
+ return u;
+}
+
+// Zero-out a variable.
+inline __device__ void zero(float& dst) {
+ dst = 0.f;
+}
+
+} // namespace vllm
diff --git a/csrc_musa/attention/dtype_fp8.muh b/csrc_musa/attention/dtype_fp8.muh
new file mode 100644
index 0000000..845ad85
--- /dev/null
+++ b/csrc_musa/attention/dtype_fp8.muh
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "attention_generic.muh"
+
+#include
+#ifdef ENABLE_FP8_E5M2
+#include
+#endif
+
+namespace vllm {
+#if defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
+// fp8 vector types for quantization of kv cache.
+// Quantized cache elements are raw bytes: uint8_t = 1 fp8 value,
+// uint16_t = 2, uint32_t = 4, uint2 = 8.
+// NOTE(review): the Vec<> template-argument lists (upstream
+// `Vec<uint8_t, 1/2/4/8>`) appear stripped by patch extraction.
+
+template<>
+struct Vec {
+ using Type = uint8_t;
+};
+
+template<>
+struct Vec {
+ using Type = uint16_t;
+};
+
+template<>
+struct Vec {
+ using Type = uint32_t;
+};
+
+template<>
+struct Vec {
+ using Type = uint2;
+};
+#endif // defined(ENABLE_FP8_E5M2) || defined(ENABLE_FP8_E4M3)
+
+} // namespace vllm
diff --git a/csrc_musa/cache.h b/csrc_musa/cache.h
new file mode 100644
index 0000000..4c142ce
--- /dev/null
+++ b/csrc_musa/cache.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include
+
+#include
+#include
+
+// Public declarations for the MUSA KV-cache ops implemented in
+// cache_kernels.mu.
+// NOTE(review): the template-argument lists of std::map / std::vector
+// (upstream e.g. `std::map<int64_t, int64_t>`,
+// `std::vector<torch::Tensor>`) appear stripped by patch extraction.
+
+// Copy whole cache blocks from src to dst per the block-number mapping;
+// supports device<->device (same GPU) and device<->host transfers.
+void swap_blocks(
+ torch::Tensor& src,
+ torch::Tensor& dst,
+ const std::map& block_mapping);
+
+// Batched src->dst block copies applied to the key and value caches of
+// every layer.
+void copy_blocks(
+ std::vector& key_caches,
+ std::vector& value_caches,
+ const std::map>& block_mapping);
+
+// Scatter per-token key/value vectors into the paged caches at the slots
+// named by slot_mapping; kv_cache_dtype selects optional fp8 quantization
+// scaled by kv_scale.
+void reshape_and_cache(
+ torch::Tensor& key,
+ torch::Tensor& value,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ torch::Tensor& slot_mapping,
+ const std::string& kv_cache_dtype,
+ const float kv_scale);
+
+// Same as reshape_and_cache, but for the flash-attention cache layout
+// (see cache_kernels.mu comments); "auto" dtype only.
+void reshape_and_cache_flash(
+ torch::Tensor& key,
+ torch::Tensor& value,
+ torch::Tensor& key_cache,
+ torch::Tensor& value_cache,
+ torch::Tensor& slot_mapping,
+ const std::string& kv_cache_dtype);
+
+// Just for unittest: fp8 <-> float/half/bf16 cache conversion.
+void convert_fp8(
+ torch::Tensor& src_cache,
+ torch::Tensor& dst_cache);
diff --git a/csrc_musa/cache_kernels.mu b/csrc_musa/cache_kernels.mu
new file mode 100644
index 0000000..727d7a0
--- /dev/null
+++ b/csrc_musa/cache_kernels.mu
@@ -0,0 +1,419 @@
+#include
+#include "torch_musa/csrc/aten/musa/MUSAContext.h"
+#include "torch_musa/csrc/core/MUSAGuard.h"
+
+#include "musa_compat.h"
+#include "dispatch_utils.h"
+#if defined(ENABLE_FP8_E5M2)
+#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
+#elif defined(ENABLE_FP8_E4M3)
+#include "quantization/fp8/amd_detail/quant_utils.cuh"
+#endif
+
+#include
+#include
+#include
+#include
+
+#ifdef USE_ROCM
+ #include
+ typedef __hip_bfloat16 __mt_bfloat16;
+#endif
+
+// Copy whole cache blocks between `src` and `dst` per `block_mapping`
+// (src block number -> dst block number). One musaMemcpyAsync per mapped
+// block on the current MUSA stream; D2D requires both tensors on the same
+// GPU, otherwise D2H / H2D is used.
+// NOTE(review): std::map and static_cast template arguments appear
+// stripped by patch extraction (e.g. `static_cast<char*>`).
+void swap_blocks(
+ torch::Tensor& src,
+ torch::Tensor& dst,
+ const std::map& block_mapping) {
+ torch::Device src_device = src.device();
+ torch::Device dst_device = dst.device();
+ musaMemcpyKind memcpy_type;
+ // NOTE(review): is_cuda() is presumably how torch_musa reports MUSA
+ // devices here -- confirm against the torch_musa device-type mapping.
+ if (src_device.is_cuda() && dst_device.is_cuda()) {
+ TORCH_CHECK(
+ src_device.index() == dst_device.index(),
+ "src and dst must be on the same GPU");
+ memcpy_type = musaMemcpyDeviceToDevice;
+ } else if (src_device.is_cuda() && dst_device.is_cpu()) {
+ memcpy_type = musaMemcpyDeviceToHost;
+ } else if (src_device.is_cpu() && dst_device.is_cuda()) {
+ memcpy_type = musaMemcpyHostToDevice;
+ } else {
+ TORCH_CHECK(false, "Invalid device combination");
+ }
+
+ char *src_ptr = static_cast(src.data_ptr());
+ char *dst_ptr = static_cast(dst.data_ptr());
+
+ // Bytes per cache block = element size * numel of one block slice.
+ const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
+ const at::musa::OptionalMUSAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
+ const musaStream_t stream = at::musa::getCurrentMUSAStream();
+ // NOTE(woosuk): This can be slow if the number of blocks is large.
+ for (const auto& pair : block_mapping) {
+ int64_t src_block_number = pair.first;
+ int64_t dst_block_number = pair.second;
+ int64_t src_offset = src_block_number * block_size_in_bytes;
+ int64_t dst_offset = dst_block_number * block_size_in_bytes;
+ musaMemcpyAsync(
+ dst_ptr + dst_offset,
+ src_ptr + src_offset,
+ block_size_in_bytes,
+ memcpy_type,
+ stream);
+ }
+}
+
+namespace vllm {
+
+// Grid: (num_layers, num_pairs)
+// Each thread block handles one (layer, src->dst pair): threads stride
+// over the numel_per_block elements of the block and copy them within the
+// layer's key cache and again within its value cache. Per-layer base
+// pointers arrive as int64_t and are reinterpreted to scalar_t*;
+// block_mapping is flattened [src0, dst0, src1, dst1, ...].
+// NOTE(review): the `template` parameter list (upstream
+// `template<typename scalar_t>`) appears stripped by patch extraction.
+template
+__global__ void copy_blocks_kernel(
+ int64_t* key_cache_ptrs,
+ int64_t* value_cache_ptrs,
+ const int64_t* __restrict__ block_mapping,
+ const int numel_per_block) {
+ const int layer_idx = blockIdx.x;
+ const int pair_idx = blockIdx.y;
+
+ scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]);
+ scalar_t* value_cache = reinterpret_cast(value_cache_ptrs[layer_idx]);
+ int64_t src_block_number = block_mapping[2 * pair_idx];
+ int64_t dst_block_number = block_mapping[2 * pair_idx + 1];
+
+ const int64_t src_block_offset = src_block_number * numel_per_block;
+ const int64_t dst_block_offset = dst_block_number * numel_per_block;
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+ int64_t src_offset = src_block_offset + i;
+ int64_t dst_offset = dst_block_offset + i;
+ key_cache[dst_offset] = key_cache[src_offset];
+ }
+ for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) {
+ int64_t src_offset = src_block_offset + i;
+ int64_t dst_offset = dst_block_offset + i;
+ value_cache[dst_offset] = value_cache[src_offset];
+ }
+}
+
+} // namespace vllm
+
+// Host launcher for copy_blocks_kernel: gathers per-layer cache pointers
+// and the flattened (src, dst) pair list on the host, copies both to the
+// cache device, then launches one grid over (num_layers, num_pairs).
+void copy_blocks(
+ std::vector& key_caches,
+ std::vector& value_caches,
+ const std::map>& block_mapping) {
+ int num_layers = key_caches.size();
+ // NOTE(review): int vs size_t comparison; also no check that each
+ // layer's key/value caches share a device.
+ TORCH_CHECK(num_layers == value_caches.size());
+ if (num_layers == 0) {
+ return;
+ }
+ torch::Device cache_device = key_caches[0].device();
+ TORCH_CHECK(cache_device.is_cuda());
+
+ // Create data structures for the kernel.
+ // Create an array of pointers to the key and value caches.
+ // NOTE(review): these are C99-style variable-length arrays -- not
+ // standard C++ (rejected by MSVC); prefer std::vector<int64_t>.
+ int64_t key_cache_ptrs[num_layers];
+ int64_t value_cache_ptrs[num_layers];
+ for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) {
+ key_cache_ptrs[layer_idx] = reinterpret_cast(key_caches[layer_idx].data_ptr());
+ value_cache_ptrs[layer_idx] = reinterpret_cast(value_caches[layer_idx].data_ptr());
+ }
+ // Create block mapping array, flattened as [src0, dst0, src1, dst1, ...].
+ std::vector block_mapping_vec;
+ for (const auto& pair : block_mapping) {
+ int64_t src_block_number = pair.first;
+ for (int64_t dst_block_number : pair.second) {
+ block_mapping_vec.push_back(src_block_number);
+ block_mapping_vec.push_back(dst_block_number);
+ }
+ }
+ int64_t* block_mapping_array = block_mapping_vec.data();
+ int num_pairs = block_mapping_vec.size() / 2;
+
+ // Move the data structures to the GPU.
+ // NOTE: This synchronizes the CPU and GPU. from_blob() only wraps the
+ // host memory; .to(cache_device) performs the actual copy.
+ torch::Tensor key_cache_ptrs_tensor = torch::from_blob(
+ key_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+ torch::Tensor value_cache_ptrs_tensor = torch::from_blob(
+ value_cache_ptrs, {num_layers}, torch::kInt64).to(cache_device);
+ torch::Tensor block_mapping_tensor = torch::from_blob(
+ block_mapping_array, {2 * num_pairs}, torch::kInt64).to(cache_device);
+
+ // Launch the kernel.
+ const int numel_per_block = key_caches[0][0].numel();
+ dim3 grid(num_layers, num_pairs);
+ dim3 block(std::min(1024, numel_per_block));
+ const at::musa::OptionalMUSAGuard device_guard(cache_device);
+ const musaStream_t stream = at::musa::getCurrentMUSAStream();
+ // NOTE(review): the kernel launch configuration inside <<<...>>> and the
+ // data_ptr template arguments appear stripped by patch extraction.
+ VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
+ key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
+ vllm::copy_blocks_kernel<<>>(
+ key_cache_ptrs_tensor.data_ptr(),
+ value_cache_ptrs_tensor.data_ptr(),
+ block_mapping_tensor.data_ptr(),
+ numel_per_block);
+ }));
+}
+
+namespace vllm {
+
+// One thread block per token: scatters that token's key/value vectors
+// into the paged caches at slot_mapping[token_idx]. The key cache is
+// tiled [num_blocks, num_heads, head_size/x, block_size, x]; the value
+// cache is [num_blocks, num_heads, head_size, block_size]. When
+// is_fp8_kv_cache is set, elements are quantized to fp8 on the way in
+// (E5M2 unscaled or E4M3 scaled by kv_scale, per build flags).
+// NOTE(review): the template parameter list (upstream `template<typename
+// scalar_t, typename cache_t, bool is_fp8_kv_cache>`) and the
+// vec_conversion<> template arguments appear stripped by patch
+// extraction.
+template
+__global__ void reshape_and_cache_kernel(
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
+ cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
+ const int64_t* __restrict__ slot_mapping, // [num_tokens]
+ const int key_stride,
+ const int value_stride,
+ const int num_heads,
+ const int head_size,
+ const int block_size,
+ const int x,
+ const float kv_scale) {
+ const int64_t token_idx = blockIdx.x;
+ const int64_t slot_idx = slot_mapping[token_idx];
+ if (slot_idx < 0) {
+ // Padding token that should be ignored.
+ return;
+ }
+
+ // Slot -> (cache block, offset within block).
+ const int64_t block_idx = slot_idx / block_size;
+ const int64_t block_offset = slot_idx % block_size;
+
+ const int n = num_heads * head_size;
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ const int64_t src_key_idx = token_idx * key_stride + i;
+ const int64_t src_value_idx = token_idx * value_stride + i;
+
+ const int head_idx = i / head_size;
+ const int head_offset = i % head_size;
+ // Key cache splits head_size into (head_size/x) groups of x.
+ const int x_idx = head_offset / x;
+ const int x_offset = head_offset % x;
+
+ const int64_t tgt_key_idx = block_idx * num_heads * (head_size / x) * block_size * x
+ + head_idx * (head_size / x) * block_size * x
+ + x_idx * block_size * x
+ + block_offset * x
+ + x_offset;
+ const int64_t tgt_value_idx = block_idx * num_heads * head_size * block_size
+ + head_idx * head_size * block_size
+ + head_offset * block_size
+ + block_offset;
+ scalar_t tgt_key = key[src_key_idx];
+ scalar_t tgt_value = value[src_value_idx];
+ if constexpr (is_fp8_kv_cache) {
+#if defined(ENABLE_FP8_E5M2)
+ key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_key);
+ value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion(tgt_value);
+#elif defined(ENABLE_FP8_E4M3)
+ key_cache[tgt_key_idx] = fp8_e4m3::scaled_vec_conversion(tgt_key, kv_scale);
+ value_cache[tgt_value_idx] = fp8_e4m3::scaled_vec_conversion(tgt_value, kv_scale);
+#else
+ // fp8 path requested but no fp8 backend compiled in.
+ assert(false);
+#endif
+ } else {
+ key_cache[tgt_key_idx] = tgt_key;
+ value_cache[tgt_value_idx] = tgt_value;
+ }
+ }
+}
+
+// Flash-layout variant: caches are [num_blocks, block_size, num_heads,
+// head_size], so key and value share one target index. One thread block
+// per token; no fp8 support here.
+// NOTE(review): template parameter list (upstream `template<typename
+// scalar_t>`) appears stripped by patch extraction.
+template
+__global__ void reshape_and_cache_flash_kernel(
+ const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
+ const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
+ scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads, head_size]
+ scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads, head_size]
+ const int64_t* __restrict__ slot_mapping, // [num_tokens]
+ const int block_stride,
+ const int key_stride,
+ const int value_stride,
+ const int num_heads,
+ const int head_size,
+ const int block_size) {
+ const int64_t token_idx = blockIdx.x;
+ const int64_t slot_idx = slot_mapping[token_idx];
+ // NOTE: slot_idx can be -1 if the token is padded
+ if (slot_idx < 0) {
+ return;
+ }
+ const int64_t block_idx = slot_idx / block_size;
+ const int64_t block_offset = slot_idx % block_size;
+ const int n = num_heads * head_size;
+ for (int i = threadIdx.x; i < n; i += blockDim.x) {
+ const int64_t src_key_idx = token_idx * key_stride + i;
+ const int64_t src_value_idx = token_idx * value_stride + i;
+ const int head_idx = i / head_size;
+ const int head_offset = i % head_size;
+ // Same index for key and value caches (identical layout/stride).
+ const int64_t tgt_value_idx = block_idx * block_stride
+ + block_offset * num_heads * head_size
+ + head_idx * head_size
+ + head_offset;
+ k_cache[tgt_value_idx] = key[src_key_idx];
+ v_cache[tgt_value_idx] = value[src_value_idx];
+ }
+}
+} // namespace vllm
+
+// Expands to one reshape_and_cache_kernel launch for a concrete
+// (key dtype, cache dtype, fp8?) combination.
+// NOTE(review): the kernel's template-argument list and launch
+// configuration inside <<<...>>>, and the data_ptr template argument,
+// appear stripped by patch extraction (upstream:
+// `<KV_T, CACHE_T, IS_FP8_KV_CACHE><<<grid, block, 0, stream>>>`).
+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
+ vllm::reshape_and_cache_kernel<<>>( \
+ reinterpret_cast(key.data_ptr()), \
+ reinterpret_cast(value.data_ptr()), \
+ reinterpret_cast(key_cache.data_ptr()), \
+ reinterpret_cast(value_cache.data_ptr()), \
+ slot_mapping.data_ptr(), \
+ key_stride, \
+ value_stride, \
+ num_heads, \
+ head_size, \
+ block_size, \
+ x, \
+ kv_scale);
+
+// Host entry: dispatch on kv_cache_dtype ("auto" = same dtype as key,
+// "fp8" = quantize to uint8 cache) and on key's scalar type, then launch
+// one thread block per token.
+void reshape_and_cache(
+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
+ torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
+ torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
+ torch::Tensor& slot_mapping, // [num_tokens]
+ const std::string& kv_cache_dtype,
+ const float kv_scale)
+{
+ int num_tokens = key.size(0);
+ int num_heads = key.size(1);
+ int head_size = key.size(2);
+ int block_size = key_cache.size(3);
+ int x = key_cache.size(4);
+
+ int key_stride = key.stride(0);
+ int value_stride = value.stride(0);
+
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * head_size, 512));
+ const at::musa::OptionalMUSAGuard device_guard(device_of(key));
+ const musaStream_t stream = at::musa::getCurrentMUSAStream();
+ if (kv_cache_dtype == "auto") {
+ if (key.dtype() == at::ScalarType::Float) {
+ CALL_RESHAPE_AND_CACHE(float, float, false);
+ } else if (key.dtype() == at::ScalarType::Half) {
+ CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
+ } else if (key.dtype() == at::ScalarType::BFloat16) {
+ CALL_RESHAPE_AND_CACHE(__mt_bfloat16, __mt_bfloat16, false);
+ }
+ // NOTE(review): any other key dtype silently launches nothing here --
+ // consider a TORCH_CHECK(false, ...) fallback.
+ } else if (kv_cache_dtype == "fp8") {
+ if (key.dtype() == at::ScalarType::Float) {
+ CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+ } else if (key.dtype() == at::ScalarType::Half) {
+ CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
+ } else if (key.dtype() == at::ScalarType::BFloat16) {
+ CALL_RESHAPE_AND_CACHE(__mt_bfloat16, uint8_t, true);
+ }
+ } else {
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+ }
+}
+
+// Host entry for the flash-layout scatter. Only the "auto" kv-cache
+// dtype is supported; key/value caches must share stride(0).
+// NOTE(review): the dispatch lambda's launch configuration inside
+// <<<...>>> and the data_ptr template arguments appear stripped by patch
+// extraction.
+void reshape_and_cache_flash(
+ torch::Tensor& key, // [num_tokens, num_heads, head_size]
+ torch::Tensor& value, // [num_tokens, num_heads, head_size]
+ torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
+ torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
+ torch::Tensor& slot_mapping, // [num_tokens]
+ const std::string& kv_cache_dtype)
+{
+ // FIXME: only support auto datatype, does not support fp8
+ if (kv_cache_dtype != "auto") {
+ TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+ }
+ int num_tokens = key.size(0);
+ int num_heads = key.size(1);
+ int head_size = key.size(2);
+ int block_size = k_cache.size(1);
+
+ int key_stride = key.stride(0);
+ int value_stride = value.stride(0);
+ int block_stride = k_cache.stride(0);
+ // Kernel uses one target index for both caches, so strides must match.
+ TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
+
+ dim3 grid(num_tokens);
+ dim3 block(std::min(num_heads * head_size, 512));
+ const at::musa::OptionalMUSAGuard device_guard(device_of(key));
+ const musaStream_t stream = at::musa::getCurrentMUSAStream();
+ VLLM_DISPATCH_FLOATING_TYPES(
+ key.scalar_type(),
+ "reshape_and_cache_flash",
+ [&] {
+ vllm::reshape_and_cache_flash_kernel<<>>(
+ key.data_ptr(),
+ value.data_ptr(),
+ k_cache.data_ptr(),
+ v_cache.data_ptr(),
+ slot_mapping.data_ptr(),
+ block_stride,
+ key_stride,
+ value_stride,
+ num_heads,
+ head_size,
+ block_size);
+ });
+}
+
+namespace vllm {
+
+// Element-wise cache conversion kernel: one thread block per cache block,
+// threads stride over block_stride elements converting Tin -> Tout via
+// the compiled-in fp8 backend (E5M2 unscaled or E4M3).
+// NOTE(review): template parameter list (upstream `template<typename
+// Tout, typename Tin>`) and the vec_conversion<> arguments appear
+// stripped by patch extraction.
+template
+__global__ void convert_fp8_kernel(
+ const Tin* __restrict__ src_cache,
+ Tout* __restrict__ dst_cache,
+ const int64_t block_stride) {
+ const int64_t block_idx = blockIdx.x;
+ for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
+ int64_t idx = block_idx * block_stride + i;
+#if defined(ENABLE_FP8_E5M2)
+ dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion(src_cache[idx]);
+#elif defined(ENABLE_FP8_E4M3)
+ dst_cache[idx] = fp8_e4m3::vec_conversion(src_cache[idx]);
+#else
+ // No fp8 backend compiled in.
+ assert(false);
+#endif
+ }
+}
+
+} // namespace vllm
+
+// Expands to one convert_fp8_kernel launch for a concrete (Tout, Tin)
+// pair.
+// NOTE(review): the kernel's template-argument list and launch
+// configuration inside <<<...>>>, and the reinterpret_cast types, appear
+// stripped by patch extraction (upstream:
+// `<Tout, Tin><<<grid, block, 0, stream>>>`).
+#define CALL_CONVERT_FP8(Tout, Tin) \
+ vllm::convert_fp8_kernel<<>>( \
+ reinterpret_cast(src_cache.data_ptr()), \
+ reinterpret_cast(dst_cache.data_ptr()), \
+ block_stride);
+
+// Unit-test helper: converts a whole cache tensor between fp8 (stored as
+// uint8) and float/half/bf16. Direction is inferred from dtypes: if src
+// is a float type it quantizes to uint8, otherwise dst's float type
+// selects the dequantize path.
+void convert_fp8(
+ torch::Tensor& src_cache,
+ torch::Tensor& dst_cache)
+{
+ torch::Device src_device = src_cache.device();
+ torch::Device dst_device = dst_cache.device();
+ TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
+ TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU")
+ TORCH_CHECK(
+ src_device.index() == dst_device.index(),
+ "src and dst must be on the same GPU");
+ at::musa::OptionalMUSAGuard device_guard(src_device);
+
+ int64_t num_blocks = src_cache.size(0);
+ int64_t block_stride = src_cache.stride(0);
+
+ dim3 grid(num_blocks);
+ dim3 block(std::min(block_stride, int64_t(512)));
+ const musaStream_t stream = at::musa::getCurrentMUSAStream();
+
+ if (src_cache.dtype() == at::ScalarType::Float) {
+ CALL_CONVERT_FP8(uint8_t, float);
+ } else if (src_cache.dtype() == at::ScalarType::Half) {
+ CALL_CONVERT_FP8(uint8_t, uint16_t);
+ } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
+ CALL_CONVERT_FP8(uint8_t, __mt_bfloat16);
+ } else if (dst_cache.dtype() == at::ScalarType::Float) {
+ CALL_CONVERT_FP8(float, uint8_t);
+ } else if (dst_cache.dtype() == at::ScalarType::Half) {
+ CALL_CONVERT_FP8(uint16_t, uint8_t);
+ } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
+ CALL_CONVERT_FP8(__mt_bfloat16, uint8_t);
+ }
+ // NOTE(review): if no branch matches (e.g. uint8 -> uint8), this
+ // silently does nothing -- consider a TORCH_CHECK(false, ...) fallback.
+}
diff --git a/csrc_musa/cpu/activation.cpp b/csrc_musa/cpu/activation.cpp
new file mode 100644
index 0000000..1bd24eb
--- /dev/null
+++ b/csrc_musa/cpu/activation.cpp
@@ -0,0 +1,148 @@
+#include "cpu_types.hpp"
+
+namespace {
+template
+void activation_kernel(int num_tokens, int d, scalar_t *__restrict__ input,
+ scalar_t *__restrict__ output) {
+ using scalar_vec_t = vec_op::vec_t