Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4536f74251 | |||
| 2bd9bd4cc2 | |||
| 5aef6c175a |
362
CMakeLists.txt
362
CMakeLists.txt
@@ -1,42 +1,48 @@
|
|||||||
cmake_minimum_required(VERSION 3.21)
|
cmake_minimum_required(VERSION 3.26)
|
||||||
|
|
||||||
|
# When building directly using CMake, make sure you run the install step
|
||||||
|
# (it places the .so files in the correct location).
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# mkdir build && cd build
|
||||||
|
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. ..
|
||||||
|
# cmake --build . --target install
|
||||||
|
#
|
||||||
|
# If you want to only build one target, make sure to install it manually:
|
||||||
|
# cmake --build . --target _C
|
||||||
|
# cmake --install . --component _C
|
||||||
project(vllm_extensions LANGUAGES CXX)
|
project(vllm_extensions LANGUAGES CXX)
|
||||||
|
|
||||||
option(VLLM_TARGET_DEVICE "Target device backend for vLLM" "musa")
|
set(CMAKE_CXX_STANDARD 17)
|
||||||
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
|
|
||||||
|
|
||||||
|
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
|
||||||
|
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
|
||||||
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
|
||||||
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
|
||||||
|
|
||||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||||
|
|
||||||
|
# Suppress potential warnings about unused manually-specified variables
|
||||||
|
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||||
|
|
||||||
|
# Prevent installation of dependencies (cutlass) by default.
|
||||||
|
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
|
||||||
|
|
||||||
#
|
#
|
||||||
# Supported python versions. These versions will be searched in order, the
|
# Supported python versions. These versions will be searched in order, the
|
||||||
# first match will be selected. These should be kept in sync with setup.py.
|
# first match will be selected. These should be kept in sync with setup.py.
|
||||||
#
|
#
|
||||||
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
|
set(PYTHON_SUPPORTED_VERSIONS "3.10" "3.11" "3.12" "3.13")
|
||||||
|
|
||||||
# Supported NVIDIA architectures.
|
# ROCm installation prefix. Default to /opt/rocm but allow override via
|
||||||
# set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
|
# -DROCM_PATH=/your/rocm/path when invoking cmake.
|
||||||
|
if(NOT DEFINED ROCM_PATH)
|
||||||
# Supported MUSA architectures.
|
set(ROCM_PATH "/opt/rocm" CACHE PATH "ROCm installation prefix")
|
||||||
set(MUSA_SUPPORTED_ARCHS "220")
|
else()
|
||||||
|
set(ROCM_PATH ${ROCM_PATH} CACHE PATH "ROCm installation prefix" FORCE)
|
||||||
# Supported AMD GPU architectures.
|
endif()
|
||||||
# set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
|
|
||||||
|
|
||||||
#
|
|
||||||
# Supported/expected torch versions for CUDA/ROCm.
|
|
||||||
#
|
|
||||||
# Currently, having an incorrect pytorch version results in a warning
|
|
||||||
# rather than an error.
|
|
||||||
#
|
|
||||||
# Note: the CUDA torch version is derived from pyproject.toml and various
|
|
||||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
|
||||||
# versions are derived from Dockerfile.rocm
|
|
||||||
#
|
|
||||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.2.0")
|
|
||||||
set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
|
|
||||||
set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Try to find python package with an executable that exactly matches
|
# Try to find python package with an executable that exactly matches
|
||||||
@@ -55,53 +61,6 @@ endif()
|
|||||||
#
|
#
|
||||||
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
|
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
|
||||||
|
|
||||||
include(/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/cmake/utils.cmake)
|
|
||||||
|
|
||||||
add_definitions(-DTORCH_MUSA_ARCH=220)
|
|
||||||
set(MUSA_CSRCS)
|
|
||||||
set(CMAKE_MODULE_PATH /opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/cmake/modules)
|
|
||||||
set(DEPENDENT_LIBRARIES "")
|
|
||||||
set(DEPENDENT_INCLUDE_DIRS "")
|
|
||||||
find_package(MUDNN)
|
|
||||||
|
|
||||||
if(MUDNN_FOUND)
|
|
||||||
list(APPEND DEPENDENT_INCLUDE_DIRS ${MUDNN_INCLUDE_DIRS})
|
|
||||||
list(APPEND DEPENDENT_LIBRARIES ${MUDNN_LIBRARIES})
|
|
||||||
else()
|
|
||||||
message(WARNING " The environment variable MUSA_HOME may be not specified."
|
|
||||||
"Using default MUDNN PATH: /usr/local/musa")
|
|
||||||
|
|
||||||
list(APPEND DEPENDENT_INCLUDE_DIRS "/usr/local/musa/include")
|
|
||||||
list(APPEND DEPENDENT_LIBRARIES "/usr/local/musa/lib/libmudnn.so")
|
|
||||||
set(MUDNN_PATH "/usr/local/musa")
|
|
||||||
set(MUDNN_LIBRARIES "/usr/local/musa/lib/libmudnn.so")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
find_package(MUSAToolkits)
|
|
||||||
|
|
||||||
if(MUSAToolkits_FOUND)
|
|
||||||
list(APPEND DEPENDENT_INCLUDE_DIRS ${MUSAToolkits_INCLUDE_DIRS})
|
|
||||||
list(APPEND DEPENDENT_LIBRARIES ${MUSAToolkits_LIBRARIES})
|
|
||||||
else()
|
|
||||||
message(WARNING " The environment variable MUSA_HOME may be not specified."
|
|
||||||
"Using default MUSATOOLKITS PATH: /usr/local/musa")
|
|
||||||
|
|
||||||
list(APPEND DEPENDENT_INCLUDE_DIRS "/usr/local/musa/include/")
|
|
||||||
list(APPEND DEPENDENT_LIBRARIES "/usr/local/musa/lib/libmusart.so")
|
|
||||||
set(ENV{MUSA_HOME} "/usr/local/musa")
|
|
||||||
set(MUSATOOLKITS_PATH "/usr/local/musa")
|
|
||||||
set(MUSAToolkits_LIBRARIES "/usr/local/musa/lib/")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(DEFINED PYTHON_INCLUDE_DIR)
|
|
||||||
include_directories(${PYTHON_INCLUDE_DIR})
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR, "Cannot find installed Python head file directory")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
list(APPEND CMAKE_MODULE_PATH $ENV{MUSA_HOME}/cmake)
|
|
||||||
find_package(MUSA REQUIRED)
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Import torch cmake configuration.
|
# Import torch cmake configuration.
|
||||||
# Torch also imports CUDA (and partially HIP) languages with some customizations,
|
# Torch also imports CUDA (and partially HIP) languages with some customizations,
|
||||||
@@ -110,29 +69,15 @@ find_package(MUSA REQUIRED)
|
|||||||
#
|
#
|
||||||
find_package(Torch REQUIRED)
|
find_package(Torch REQUIRED)
|
||||||
|
|
||||||
#
|
|
||||||
# Normally `torch.utils.cpp_extension.CUDAExtension` would add
|
|
||||||
# `libtorch_python.so` for linking against an extension. Torch's cmake
|
|
||||||
# configuration does not include this library (presumably since the cmake
|
|
||||||
# config is used for standalone C++ binaries that link against torch).
|
|
||||||
# The `libtorch_python.so` library defines some of the glue code between
|
|
||||||
# torch/python via pybind and is required by VLLM extensions for this
|
|
||||||
# reason. So, add it by manually with `find_library` using torch's
|
|
||||||
# installed library path.
|
|
||||||
#
|
|
||||||
find_library(torch_python_LIBRARY torch_python PATHS
|
|
||||||
"${TORCH_INSTALL_PREFIX}/lib")
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Forward the non-CUDA device extensions to external CMake scripts.
|
# Forward the non-CUDA device extensions to external CMake scripts.
|
||||||
#
|
#
|
||||||
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
|
if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
|
||||||
NOT VLLM_TARGET_DEVICE STREQUAL "musa" AND
|
|
||||||
NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
|
NOT VLLM_TARGET_DEVICE STREQUAL "rocm")
|
||||||
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
|
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
|
||||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
|
return()
|
||||||
endif()
|
endif()
|
||||||
return()
|
return()
|
||||||
endif()
|
endif()
|
||||||
@@ -141,226 +86,99 @@ endif()
|
|||||||
# Set up GPU language and check the torch version and warn if it isn't
|
# Set up GPU language and check the torch version and warn if it isn't
|
||||||
# what is expected.
|
# what is expected.
|
||||||
#
|
#
|
||||||
if (NOT HIP_FOUND AND MUSA_FOUND)
|
if (VLLM_TARGET_DEVICE STREQUAL "cuda")
|
||||||
set(VLLM_GPU_LANG "MUSA")
|
# Include CUDA specific configuration
|
||||||
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cuda.cmake)
|
||||||
if (NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_CUDA})
|
elseif(VLLM_TARGET_DEVICE STREQUAL "rocm")
|
||||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_CUDA} "
|
# Include ROCm specific configuration
|
||||||
"expected for CUDA build, saw ${Torch_VERSION} instead.")
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/hip.cmake)
|
||||||
endif()
|
elseif(VLLM_TARGET_DEVICE STREQUAL "cpu")
|
||||||
elseif(HIP_FOUND)
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
|
||||||
set(VLLM_GPU_LANG "HIP")
|
elseif(VLLM_TARGET_DEVICE STREQUAL "musa")
|
||||||
|
include(${CMAKE_CURRENT_LIST_DIR}/cmake/musa.cmake)
|
||||||
# Importing torch recognizes and sets up some HIP/ROCm configuration but does
|
|
||||||
# not let cmake recognize .hip files. In order to get cmake to understand the
|
|
||||||
# .hip extension automatically, HIP must be enabled explicitly.
|
|
||||||
enable_language(HIP)
|
|
||||||
|
|
||||||
# ROCm 5.x
|
|
||||||
if (ROCM_VERSION_DEV_MAJOR EQUAL 5 AND
|
|
||||||
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_5X})
|
|
||||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_5X} "
|
|
||||||
"expected for ROCMm 5.x build, saw ${Torch_VERSION} instead.")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# ROCm 6.x
|
|
||||||
if (ROCM_VERSION_DEV_MAJOR EQUAL 6 AND
|
|
||||||
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM_6X})
|
|
||||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM_6X} "
|
|
||||||
"expected for ROCMm 6.x build, saw ${Torch_VERSION} instead.")
|
|
||||||
endif()
|
|
||||||
else()
|
else()
|
||||||
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
#
|
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
|
||||||
# Override the GPU architectures detected by cmake/torch and filter them by
|
include(FetchContent)
|
||||||
# the supported versions for the current language.
|
file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists
|
||||||
# The final set of arches is stored in `VLLM_GPU_ARCHES`.
|
message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}")
|
||||||
#
|
|
||||||
# override_gpu_arches(VLLM_GPU_ARCHES
|
|
||||||
# ${VLLM_GPU_LANG}
|
|
||||||
# "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}")
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Query torch for additional GPU compilation flags for the given
|
# Define other extension targets
|
||||||
# `VLLM_GPU_LANG`.
|
|
||||||
# The final set of arches is stored in `VLLM_GPU_FLAGS`.
|
|
||||||
#
|
#
|
||||||
get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG})
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# Set nvcc parallelism.
|
# cumem_allocator extension
|
||||||
|
# Architecture-specific cumem configurations are included from cmake/cuda.cmake or cmake/hip.cmake
|
||||||
#
|
#
|
||||||
if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
|
||||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||||
|
message(STATUS "Enabling cumem allocator extension.")
|
||||||
|
define_extension_target(
|
||||||
|
cumem_allocator
|
||||||
|
DESTINATION vllm
|
||||||
|
LANGUAGE CXX
|
||||||
|
SOURCES ${VLLM_CUMEM_EXT_SRC}
|
||||||
|
LIBRARIES ${CUMEM_LIBS}
|
||||||
|
USE_SABI 3.8
|
||||||
|
WITH_SOABI)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
#
|
|
||||||
# Define extension targets
|
|
||||||
#
|
|
||||||
|
|
||||||
#
|
#
|
||||||
# _C extension
|
# _C extension
|
||||||
#
|
#
|
||||||
|
|
||||||
set(VLLM_EXT_SRC
|
# VLLM_EXT_SRC is defined in the architecture-specific cmake files (cuda.cmake or hip.cmake)
|
||||||
"csrc_musa/cache_kernels.mu"
|
|
||||||
"csrc_musa/attention/attention_kernels.mu"
|
|
||||||
"csrc_musa/pos_encoding_kernels.mu"
|
|
||||||
"csrc_musa/activation_kernels.mu"
|
|
||||||
"csrc_musa/layernorm_kernels.mu"
|
|
||||||
"csrc_musa/quantization/squeezellm/quant_cuda_kernel.mu"
|
|
||||||
"csrc_musa/quantization/gptq/q_gemm.mu"
|
|
||||||
"csrc_musa/quantization/fp8/fp8_cuda_kernels.mu"
|
|
||||||
"csrc_musa/musa_utils_kernels.mu"
|
|
||||||
"csrc_musa/moe_align_block_size_kernels.mu"
|
|
||||||
"csrc_musa/pybind.cpp")
|
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "MUSA")
|
message(STATUS "Enabling C extension.")
|
||||||
list(APPEND VLLM_EXT_SRC
|
define_extension_target(
|
||||||
"csrc_musa/quantization/aqlm/gemm_kernels.mu"
|
_C
|
||||||
"csrc_musa/quantization/awq/gemm_kernels.mu"
|
|
||||||
"csrc_musa/quantization/marlin/marlin_cuda_kernel.mu"
|
|
||||||
"csrc_musa/quantization/gptq_marlin/gptq_marlin.mu"
|
|
||||||
"csrc_musa/quantization/gptq_marlin/gptq_marlin_repack.mu"
|
|
||||||
"csrc_musa/custom_all_reduce.mu")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
string(APPEND MUSA_MCC_FLAGS
|
|
||||||
|
|
||||||
)
|
|
||||||
string(APPEND MUSA_MCC_FLAGS " -U__CUDA__")
|
|
||||||
|
|
||||||
set(MUSA_VERBOSE_BUILD ON)
|
|
||||||
|
|
||||||
|
|
||||||
musa_include_directories(
|
|
||||||
/opt/conda/envs/py39/include/python3.9
|
|
||||||
/usr/local/musa/include
|
|
||||||
/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/generated_cuda_compatible/aten/src
|
|
||||||
/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/generated_cuda_compatible/include
|
|
||||||
/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/share/generated_cuda_compatible/include/torch/csrc/api/include
|
|
||||||
/opt/conda/envs/py39/lib/python3.9/site-packages
|
|
||||||
/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa
|
|
||||||
)
|
|
||||||
|
|
||||||
musa_add_library(vllm_C SHARED ${VLLM_EXT_SRC})
|
|
||||||
set(INSTALL_BIN_DIR "bin")
|
|
||||||
set(INSTALL_LIB_DIR "lib64")
|
|
||||||
set(INSTALL_INC_DIR "include")
|
|
||||||
set(INSTALL_SHARE_DIR "share")
|
|
||||||
set(INSTALL_DOC_DIR "docs")
|
|
||||||
|
|
||||||
define_gpu_extension_target(
|
|
||||||
vllm_C
|
|
||||||
DESTINATION vllm
|
DESTINATION vllm
|
||||||
LANGUAGE ${VLLM_GPU_LANG}
|
LANGUAGE ${VLLM_GPU_LANG}
|
||||||
SOURCES ${VLLM_EXT_SRC}
|
SOURCES ${VLLM_EXT_SRC}
|
||||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||||
|
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||||
|
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||||
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
target_link_libraries(vllm_C ${DEPENDENT_LIBRARIES})
|
# If CUTLASS is compiled on NVCC >= 12.5, it by default uses
|
||||||
target_link_libraries(vllm_C "/opt/conda/envs/py39/lib/python3.9/site-packages/torch_musa/lib/libmusa_python.so")
|
# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the
|
||||||
#
|
# driver API. This causes problems when linking with earlier versions of CUDA.
|
||||||
|
# Setting this variable sidesteps the issue by calling the driver directly.
|
||||||
|
target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
|
||||||
|
|
||||||
# _moe_C extension
|
# _moe_C extension
|
||||||
#
|
# Architecture-specific MOE configurations are included from cmake/cuda.cmake or cmake/hip.cmake
|
||||||
|
|
||||||
set(VLLM_MOE_EXT_SRC
|
message(STATUS "Enabling moe extension.")
|
||||||
"csrc_musa/moe/moe_ops.cpp"
|
define_extension_target(
|
||||||
"csrc_musa/moe/topk_softmax_kernels.mu")
|
|
||||||
|
|
||||||
define_gpu_extension_target(
|
|
||||||
_moe_C
|
_moe_C
|
||||||
DESTINATION vllm
|
DESTINATION vllm
|
||||||
LANGUAGE ${VLLM_GPU_LANG}
|
LANGUAGE ${VLLM_GPU_LANG}
|
||||||
SOURCES ${VLLM_MOE_EXT_SRC}
|
SOURCES ${VLLM_MOE_EXT_SRC}
|
||||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||||
|
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||||
|
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||||
|
USE_SABI 3
|
||||||
WITH_SOABI)
|
WITH_SOABI)
|
||||||
|
|
||||||
#
|
# Architecture-specific ROCm configurations are included from cmake/hip.cmake
|
||||||
# _punica_C extension
|
|
||||||
#
|
|
||||||
|
|
||||||
set(VLLM_PUNICA_EXT_SRC
|
# For CUDA and HIP builds also build the triton_kernels external package.
|
||||||
"csrc_musa/punica/bgmv/bgmv_bf16_bf16_bf16.mu"
|
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||||
"csrc_musa/punica/bgmv/bgmv_bf16_fp32_bf16.mu"
|
include(cmake/external_projects/triton_kernels.cmake)
|
||||||
"csrc_musa/punica/bgmv/bgmv_fp16_fp16_fp16.mu"
|
|
||||||
"csrc_musa/punica/bgmv/bgmv_fp16_fp32_fp16.mu"
|
|
||||||
"csrc_musa/punica/bgmv/bgmv_fp32_bf16_bf16.mu"
|
|
||||||
"csrc_musa/punica/bgmv/bgmv_fp32_fp16_fp16.mu"
|
|
||||||
"csrc_musa/punica/punica_ops.cc")
|
|
||||||
|
|
||||||
#
|
|
||||||
# Copy GPU compilation flags+update for punica
|
|
||||||
#
|
|
||||||
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
|
||||||
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
|
|
||||||
"-D__MUSA_NO_HALF_OPERATORS__"
|
|
||||||
"-D__MUSA_NO_HALF_CONVERSIONS__"
|
|
||||||
"-D__MUSA_NO_BFLOAT16_CONVERSIONS__"
|
|
||||||
"-D__MUSA_NO_HALF2_OPERATORS__")
|
|
||||||
|
|
||||||
#
|
|
||||||
# Filter out CUDA architectures < 8.0 for punica.
|
|
||||||
#
|
|
||||||
# if (${VLLM_GPU_LANG} STREQUAL "CUDA")
|
|
||||||
# set(VLLM_PUNICA_GPU_ARCHES)
|
|
||||||
# foreach(ARCH ${VLLM_GPU_ARCHES})
|
|
||||||
# string_to_ver(CODE_VER ${ARCH})
|
|
||||||
# if (CODE_VER GREATER_EQUAL 8.0)
|
|
||||||
# list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
|
|
||||||
# endif()
|
|
||||||
# endforeach()
|
|
||||||
# message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
|
||||||
# endif()
|
|
||||||
|
|
||||||
if (VLLM_PUNICA_GPU_ARCHES)
|
|
||||||
define_gpu_extension_target(
|
|
||||||
_punica_C
|
|
||||||
DESTINATION vllm
|
|
||||||
LANGUAGE ${VLLM_GPU_LANG}
|
|
||||||
SOURCES ${VLLM_PUNICA_EXT_SRC}
|
|
||||||
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
|
|
||||||
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
|
|
||||||
WITH_SOABI)
|
|
||||||
else()
|
|
||||||
message(WARNING "Unable to create _punica_C target because none of the "
|
|
||||||
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
#
|
# For CUDA we also build and ship some external projects.
|
||||||
# Add the `default` target which detects which extensions should be
|
if (VLLM_GPU_LANG STREQUAL "CUDA")
|
||||||
# built based on platform/architecture. This is the same logic that
|
include(cmake/external_projects/flashmla.cmake)
|
||||||
# setup.py uses to select which extensions should be built and should
|
include(cmake/external_projects/qutlass.cmake)
|
||||||
# be kept in sync.
|
|
||||||
#
|
|
||||||
# The `default` target makes direct use of cmake easier since knowledge
|
|
||||||
# of which extensions are supported has been factored in, e.g.
|
|
||||||
#
|
|
||||||
# mkdir build && cd build
|
|
||||||
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
|
|
||||||
# cmake --build . --target default
|
|
||||||
#
|
|
||||||
add_custom_target(default)
|
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "MUSA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
# vllm-flash-attn should be last as it overwrites some CMake functions
|
||||||
message(STATUS "Enabling C extension.")
|
include(cmake/external_projects/vllm_flash_attn.cmake)
|
||||||
add_dependencies(default _C)
|
endif ()
|
||||||
endif()
|
|
||||||
|
|
||||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "MUSA")
|
|
||||||
message(STATUS "Enabling moe extension.")
|
|
||||||
add_dependencies(default _moe_C)
|
|
||||||
|
|
||||||
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
|
|
||||||
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
|
|
||||||
# there are supported target arches.
|
|
||||||
if (VLLM_PUNICA_GPU_ARCHES AND
|
|
||||||
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
|
|
||||||
message(STATUS "Enabling punica extension.")
|
|
||||||
add_dependencies(default _punica_C)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|||||||
127
CODE_OF_CONDUCT.md
Normal file
127
CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
|
||||||
|
# vLLM Code of Conduct
|
||||||
|
|
||||||
|
## Our Pledge
|
||||||
|
|
||||||
|
We as members, contributors, and leaders pledge to make participation in our
|
||||||
|
community a harassment-free experience for everyone, regardless of age, body
|
||||||
|
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||||
|
identity and expression, level of experience, education, socioeconomic status,
|
||||||
|
nationality, personal appearance, race, caste, color, religion, or sexual
|
||||||
|
identity and orientation.
|
||||||
|
|
||||||
|
We pledge to act and interact in ways that contribute to an open, welcoming,
|
||||||
|
diverse, inclusive, and healthy community.
|
||||||
|
|
||||||
|
## Our Standards
|
||||||
|
|
||||||
|
Examples of behavior that contributes to a positive environment for our
|
||||||
|
community include:
|
||||||
|
|
||||||
|
* Demonstrating empathy and kindness toward other people
|
||||||
|
* Being respectful of differing opinions, viewpoints, and experiences
|
||||||
|
* Giving and gracefully accepting constructive feedback
|
||||||
|
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||||
|
and learning from the experience
|
||||||
|
* Focusing on what is best not just for us as individuals, but for the overall
|
||||||
|
community
|
||||||
|
|
||||||
|
Examples of unacceptable behavior include:
|
||||||
|
|
||||||
|
* The use of sexualized language or imagery, and sexual attention or advances of
|
||||||
|
any kind
|
||||||
|
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||||
|
* Public or private harassment
|
||||||
|
* Publishing others' private information, such as a physical or email address,
|
||||||
|
without their explicit permission
|
||||||
|
* Other conduct which could reasonably be considered inappropriate in a
|
||||||
|
professional setting
|
||||||
|
|
||||||
|
## Enforcement Responsibilities
|
||||||
|
|
||||||
|
Community leaders are responsible for clarifying and enforcing our standards of
|
||||||
|
acceptable behavior and will take appropriate and fair corrective action in
|
||||||
|
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||||
|
or harmful.
|
||||||
|
|
||||||
|
Community leaders have the right and responsibility to remove, edit, or reject
|
||||||
|
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||||
|
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||||
|
decisions when appropriate.
|
||||||
|
|
||||||
|
## Scope
|
||||||
|
|
||||||
|
This Code of Conduct applies within all community spaces, and also applies when
|
||||||
|
an individual is officially representing the community in public spaces.
|
||||||
|
Examples of representing our community include using an official email address,
|
||||||
|
posting via an official social media account, or acting as an appointed
|
||||||
|
representative at an online or offline/IRL event.
|
||||||
|
|
||||||
|
## Enforcement
|
||||||
|
|
||||||
|
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||||
|
reported to the community leaders responsible for enforcement in the #code-of-conduct
|
||||||
|
channel in the [vLLM Slack](https://slack.vllm.ai).
|
||||||
|
All complaints will be reviewed and investigated promptly and fairly.
|
||||||
|
|
||||||
|
All community leaders are obligated to respect the privacy and security of the
|
||||||
|
reporter of any incident.
|
||||||
|
|
||||||
|
## Enforcement Guidelines
|
||||||
|
|
||||||
|
Community leaders will follow these Community Impact Guidelines in determining
|
||||||
|
the consequences for any action they deem in violation of this Code of Conduct:
|
||||||
|
|
||||||
|
### 1. Correction
|
||||||
|
|
||||||
|
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||||
|
unprofessional or unwelcome in the community.
|
||||||
|
|
||||||
|
**Consequence**: A private, written warning from community leaders, providing
|
||||||
|
clarity around the nature of the violation and an explanation of why the
|
||||||
|
behavior was inappropriate. A public apology may be requested.
|
||||||
|
|
||||||
|
### 2. Warning
|
||||||
|
|
||||||
|
**Community Impact**: A violation through a single incident or series of
|
||||||
|
actions.
|
||||||
|
|
||||||
|
**Consequence**: A warning with consequences for continued behavior. No
|
||||||
|
interaction with the people involved, including unsolicited interaction with
|
||||||
|
those enforcing the Code of Conduct, for a specified period of time. This
|
||||||
|
includes avoiding interactions in community spaces as well as external channels
|
||||||
|
like social media. Violating these terms may lead to a temporary or permanent
|
||||||
|
ban.
|
||||||
|
|
||||||
|
### 3. Temporary Ban
|
||||||
|
|
||||||
|
**Community Impact**: A serious violation of community standards, including
|
||||||
|
sustained inappropriate behavior.
|
||||||
|
|
||||||
|
**Consequence**: A temporary ban from any sort of interaction or public
|
||||||
|
communication with the community for a specified period of time. No public or
|
||||||
|
private interaction with the people involved, including unsolicited interaction
|
||||||
|
with those enforcing the Code of Conduct, is allowed during this period.
|
||||||
|
Violating these terms may lead to a permanent ban.
|
||||||
|
|
||||||
|
### 4. Permanent Ban
|
||||||
|
|
||||||
|
**Community Impact**: Demonstrating a pattern of violation of community
|
||||||
|
standards, including sustained inappropriate behavior, harassment of an
|
||||||
|
individual, or aggression toward or disparagement of classes of individuals.
|
||||||
|
|
||||||
|
**Consequence**: A permanent ban from any sort of public interaction within the
|
||||||
|
community.
|
||||||
|
|
||||||
|
## Attribution
|
||||||
|
|
||||||
|
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/),
|
||||||
|
version 2.1, available at
|
||||||
|
[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html).
|
||||||
|
|
||||||
|
Community Impact Guidelines were inspired by
|
||||||
|
[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion).
|
||||||
|
|
||||||
|
For answers to common questions about this code of conduct, see the
|
||||||
|
[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at
|
||||||
|
[Contributor Covenant translations](https://www.contributor-covenant.org/translations).
|
||||||
@@ -1,56 +1,3 @@
|
|||||||
# Contributing to vLLM
|
# Contributing to vLLM
|
||||||
|
|
||||||
Thank you for your interest in contributing to vLLM!
|
You may find information about contributing to vLLM on [docs.vllm.ai](https://docs.vllm.ai/en/latest/contributing).
|
||||||
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
|
|
||||||
There are several ways you can contribute to the project:
|
|
||||||
|
|
||||||
- Identify and report any issues or bugs.
|
|
||||||
- Request or add a new model.
|
|
||||||
- Suggest or implement new features.
|
|
||||||
|
|
||||||
However, remember that contributions aren't just about code.
|
|
||||||
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
|
|
||||||
|
|
||||||
Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
|
|
||||||
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
|
|
||||||
Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
|
|
||||||
|
|
||||||
|
|
||||||
## Setup for development
|
|
||||||
|
|
||||||
### Build from source
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -e . # This may take several minutes.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Testing
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install -r requirements-dev.txt
|
|
||||||
|
|
||||||
# linting and formatting
|
|
||||||
bash format.sh
|
|
||||||
# Static type checking
|
|
||||||
mypy
|
|
||||||
# Unit tests
|
|
||||||
pytest tests/
|
|
||||||
```
|
|
||||||
**Note:** Currently, the repository does not pass the mypy tests.
|
|
||||||
|
|
||||||
|
|
||||||
## Contributing Guidelines
|
|
||||||
|
|
||||||
### Issue Reporting
|
|
||||||
|
|
||||||
If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
|
|
||||||
If not, please file a new issue, providing as much relevant information as possible.
|
|
||||||
|
|
||||||
### Pull Requests & Code Reviews
|
|
||||||
|
|
||||||
Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE.md) for detailed guide for contribution.
|
|
||||||
|
|
||||||
### Thank You
|
|
||||||
|
|
||||||
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
|
|
||||||
Your contributions make vLLM a great tool for everyone!
|
|
||||||
|
|||||||
34
DCO
Normal file
34
DCO
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
Developer Certificate of Origin
|
||||||
|
Version 1.1
|
||||||
|
|
||||||
|
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
||||||
|
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies of this
|
||||||
|
license document, but changing it is not allowed.
|
||||||
|
|
||||||
|
|
||||||
|
Developer's Certificate of Origin 1.1
|
||||||
|
|
||||||
|
By making a contribution to this project, I certify that:
|
||||||
|
|
||||||
|
(a) The contribution was created in whole or in part by me and I
|
||||||
|
have the right to submit it under the open source license
|
||||||
|
indicated in the file; or
|
||||||
|
|
||||||
|
(b) The contribution is based upon previous work that, to the best
|
||||||
|
of my knowledge, is covered under an appropriate open source
|
||||||
|
license and I have the right under that license to submit that
|
||||||
|
work with modifications, whether created in whole or in part
|
||||||
|
by me, under the same open source license (unless I am
|
||||||
|
permitted to submit under a different license), as indicated
|
||||||
|
in the file; or
|
||||||
|
|
||||||
|
(c) The contribution was provided directly to me by some other
|
||||||
|
person who certified (a), (b) or (c) and I have not modified
|
||||||
|
it.
|
||||||
|
|
||||||
|
(d) I understand and agree that this project and the contribution
|
||||||
|
are public and that a record of the contribution (including all
|
||||||
|
personal information I submit with it, including my sign-off) is
|
||||||
|
maintained indefinitely and may be redistributed consistent with
|
||||||
|
this project or the open source license(s) involved.
|
||||||
163
Dockerfile
163
Dockerfile
@@ -1,163 +0,0 @@
|
|||||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
|
||||||
# to run the OpenAI compatible server.
|
|
||||||
|
|
||||||
# Please update any changes made here to
|
|
||||||
# docs/source/dev/dockerfile/dockerfile.rst and
|
|
||||||
# docs/source/assets/dev/dockerfile-stages-dependency.png
|
|
||||||
|
|
||||||
#################### BASE BUILD IMAGE ####################
|
|
||||||
# prepare basic build environment
|
|
||||||
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS dev
|
|
||||||
|
|
||||||
RUN apt-get update -y \
|
|
||||||
&& apt-get install -y python3-pip git
|
|
||||||
|
|
||||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
|
||||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
|
||||||
# this won't be needed for future versions of this docker image
|
|
||||||
# or future versions of triton.
|
|
||||||
RUN ldconfig /usr/local/cuda-12.4/compat/
|
|
||||||
|
|
||||||
WORKDIR /workspace
|
|
||||||
|
|
||||||
# install build and runtime dependencies
|
|
||||||
COPY requirements-common.txt requirements-common.txt
|
|
||||||
COPY requirements-cuda.txt requirements-cuda.txt
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install -r requirements-cuda.txt
|
|
||||||
|
|
||||||
# install development dependencies
|
|
||||||
COPY requirements-dev.txt requirements-dev.txt
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install -r requirements-dev.txt
|
|
||||||
|
|
||||||
# cuda arch list used by torch
|
|
||||||
# can be useful for both `dev` and `test`
|
|
||||||
# explicitly set the list to avoid issues with torch 2.2
|
|
||||||
# see https://github.com/pytorch/pytorch/pull/123243
|
|
||||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
|
||||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
|
||||||
#################### BASE BUILD IMAGE ####################
|
|
||||||
|
|
||||||
|
|
||||||
#################### WHEEL BUILD IMAGE ####################
|
|
||||||
FROM dev AS build
|
|
||||||
|
|
||||||
# install build dependencies
|
|
||||||
COPY requirements-build.txt requirements-build.txt
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install -r requirements-build.txt
|
|
||||||
|
|
||||||
# install compiler cache to speed up compilation leveraging local or remote caching
|
|
||||||
RUN apt-get update -y && apt-get install -y ccache
|
|
||||||
|
|
||||||
# files and directories related to build wheels
|
|
||||||
COPY csrc csrc
|
|
||||||
COPY setup.py setup.py
|
|
||||||
COPY cmake cmake
|
|
||||||
COPY CMakeLists.txt CMakeLists.txt
|
|
||||||
COPY requirements-common.txt requirements-common.txt
|
|
||||||
COPY requirements-cuda.txt requirements-cuda.txt
|
|
||||||
COPY pyproject.toml pyproject.toml
|
|
||||||
COPY vllm vllm
|
|
||||||
|
|
||||||
# max jobs used by Ninja to build extensions
|
|
||||||
ARG max_jobs=2
|
|
||||||
ENV MAX_JOBS=${max_jobs}
|
|
||||||
# number of threads used by nvcc
|
|
||||||
ARG nvcc_threads=8
|
|
||||||
ENV NVCC_THREADS=$nvcc_threads
|
|
||||||
# make sure punica kernels are built (for LoRA)
|
|
||||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
|
||||||
|
|
||||||
ENV CCACHE_DIR=/root/.cache/ccache
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/ccache \
|
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
|
||||||
python3 setup.py bdist_wheel --dist-dir=dist
|
|
||||||
|
|
||||||
# check the size of the wheel, we cannot upload wheels larger than 100MB
|
|
||||||
COPY .buildkite/check-wheel-size.py check-wheel-size.py
|
|
||||||
RUN python3 check-wheel-size.py dist
|
|
||||||
|
|
||||||
# the `vllm_nccl` package must be installed from source distribution
|
|
||||||
# pip is too smart to store a wheel in the cache, and other CI jobs
|
|
||||||
# will directly use the wheel from the cache, which is not what we want.
|
|
||||||
# we need to remove it manually
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip cache remove vllm_nccl*
|
|
||||||
#################### EXTENSION Build IMAGE ####################
|
|
||||||
|
|
||||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
|
||||||
FROM dev as flash-attn-builder
|
|
||||||
# max jobs used for build
|
|
||||||
ARG max_jobs=2
|
|
||||||
ENV MAX_JOBS=${max_jobs}
|
|
||||||
# flash attention version
|
|
||||||
ARG flash_attn_version=v2.5.8
|
|
||||||
ENV FLASH_ATTN_VERSION=${flash_attn_version}
|
|
||||||
|
|
||||||
WORKDIR /usr/src/flash-attention-v2
|
|
||||||
|
|
||||||
# Download the wheel or build it if a pre-compiled release doesn't exist
|
|
||||||
RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
|
|
||||||
--no-build-isolation --no-deps --no-cache-dir
|
|
||||||
|
|
||||||
#################### FLASH_ATTENTION Build IMAGE ####################
|
|
||||||
|
|
||||||
#################### vLLM installation IMAGE ####################
|
|
||||||
# image with vLLM installed
|
|
||||||
FROM nvidia/cuda:12.4.1-base-ubuntu22.04 AS vllm-base
|
|
||||||
WORKDIR /vllm-workspace
|
|
||||||
|
|
||||||
RUN apt-get update -y \
|
|
||||||
&& apt-get install -y python3-pip git vim
|
|
||||||
|
|
||||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
|
||||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
|
||||||
# this won't be needed for future versions of this docker image
|
|
||||||
# or future versions of triton.
|
|
||||||
RUN ldconfig /usr/local/cuda-12.4/compat/
|
|
||||||
|
|
||||||
# install vllm wheel first, so that torch etc will be installed
|
|
||||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install dist/*.whl --verbose
|
|
||||||
|
|
||||||
RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
|
|
||||||
--mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
|
|
||||||
#################### vLLM installation IMAGE ####################
|
|
||||||
|
|
||||||
|
|
||||||
#################### TEST IMAGE ####################
|
|
||||||
# image to run unit testing suite
|
|
||||||
# note that this uses vllm installed by `pip`
|
|
||||||
FROM vllm-base AS test
|
|
||||||
|
|
||||||
ADD . /vllm-workspace/
|
|
||||||
|
|
||||||
# install development dependencies (for testing)
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install -r requirements-dev.txt
|
|
||||||
|
|
||||||
# doc requires source code
|
|
||||||
# we hide them inside `test_docs/` , so that this source code
|
|
||||||
# will not be imported by other tests
|
|
||||||
RUN mkdir test_docs
|
|
||||||
RUN mv docs test_docs/
|
|
||||||
RUN mv vllm test_docs/
|
|
||||||
|
|
||||||
#################### TEST IMAGE ####################
|
|
||||||
|
|
||||||
#################### OPENAI API SERVER ####################
|
|
||||||
# openai api server alternative
|
|
||||||
FROM vllm-base AS vllm-openai
|
|
||||||
|
|
||||||
# install additional dependencies for openai api server
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install accelerate hf_transfer modelscope
|
|
||||||
|
|
||||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
|
||||||
|
|
||||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
|
||||||
#################### OPENAI API SERVER ####################
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
# This vLLM Dockerfile is used to construct image that can build and run vLLM on x86 CPU platform.
|
|
||||||
|
|
||||||
FROM ubuntu:22.04
|
|
||||||
|
|
||||||
RUN apt-get update -y \
|
|
||||||
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \
|
|
||||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
|
||||||
|
|
||||||
RUN pip install --upgrade pip \
|
|
||||||
&& pip install wheel packaging ninja setuptools>=49.4.0 numpy
|
|
||||||
|
|
||||||
COPY ./ /workspace/vllm
|
|
||||||
|
|
||||||
WORKDIR /workspace/vllm
|
|
||||||
|
|
||||||
RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
|
||||||
|
|
||||||
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
|
|
||||||
|
|
||||||
CMD ["/bin/bash"]
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# default base image
|
|
||||||
ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04"
|
|
||||||
|
|
||||||
FROM $BASE_IMAGE
|
|
||||||
|
|
||||||
RUN echo "Base image is $BASE_IMAGE"
|
|
||||||
|
|
||||||
# Install some basic utilities
|
|
||||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
|
||||||
|
|
||||||
### Mount Point ###
|
|
||||||
# When launching the container, mount the code directory to /app
|
|
||||||
ARG APP_MOUNT=/app
|
|
||||||
VOLUME [ ${APP_MOUNT} ]
|
|
||||||
WORKDIR ${APP_MOUNT}
|
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
|
||||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
|
||||||
RUN python3 -m pip install sentencepiece transformers==4.36.2 -U
|
|
||||||
RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
|
||||||
RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U
|
|
||||||
|
|
||||||
COPY ./vllm /app/vllm/vllm
|
|
||||||
COPY ./setup.py /app/vllm/setup.py
|
|
||||||
COPY ./requirements-common.txt /app/vllm/requirements-common.txt
|
|
||||||
COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt
|
|
||||||
|
|
||||||
RUN cd /app/vllm \
|
|
||||||
&& python3 -m pip install -U -r requirements-neuron.txt
|
|
||||||
|
|
||||||
ENV VLLM_BUILD_WITH_NEURON 1
|
|
||||||
RUN cd /app/vllm \
|
|
||||||
&& pip install -e . \
|
|
||||||
&& cd ..
|
|
||||||
|
|
||||||
CMD ["/bin/bash"]
|
|
||||||
107
Dockerfile.rocm
107
Dockerfile.rocm
@@ -1,107 +0,0 @@
|
|||||||
# default base image
|
|
||||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
|
||||||
|
|
||||||
FROM $BASE_IMAGE
|
|
||||||
|
|
||||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
|
||||||
|
|
||||||
RUN echo "Base image is $BASE_IMAGE"
|
|
||||||
|
|
||||||
# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
|
|
||||||
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
|
|
||||||
|
|
||||||
|
|
||||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
|
||||||
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
|
|
||||||
|
|
||||||
ARG FA_BRANCH="ae7928c"
|
|
||||||
RUN echo "FA_BRANCH is $FA_BRANCH"
|
|
||||||
|
|
||||||
# whether to build flash-attention
|
|
||||||
# if 0, will not build flash attention
|
|
||||||
# this is useful for gfx target where flash-attention is not supported
|
|
||||||
# In that case, we need to use the python reference attention implementation in vllm
|
|
||||||
ARG BUILD_FA="1"
|
|
||||||
|
|
||||||
# whether to build triton on rocm
|
|
||||||
ARG BUILD_TRITON="1"
|
|
||||||
|
|
||||||
# Install some basic utilities
|
|
||||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
|
||||||
|
|
||||||
# Install some basic utilities
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
curl \
|
|
||||||
ca-certificates \
|
|
||||||
sudo \
|
|
||||||
git \
|
|
||||||
bzip2 \
|
|
||||||
libx11-6 \
|
|
||||||
build-essential \
|
|
||||||
wget \
|
|
||||||
unzip \
|
|
||||||
nvidia-cuda-toolkit \
|
|
||||||
tmux \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
### Mount Point ###
|
|
||||||
# When launching the container, mount the code directory to /app
|
|
||||||
ARG APP_MOUNT=/vllm-workspace
|
|
||||||
VOLUME [ ${APP_MOUNT} ]
|
|
||||||
WORKDIR ${APP_MOUNT}
|
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
|
||||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
|
||||||
|
|
||||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
|
||||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
|
||||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
|
||||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
|
||||||
|
|
||||||
# Install ROCm flash-attention
|
|
||||||
RUN if [ "$BUILD_FA" = "1" ]; then \
|
|
||||||
mkdir libs \
|
|
||||||
&& cd libs \
|
|
||||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
|
||||||
&& cd flash-attention \
|
|
||||||
&& git checkout ${FA_BRANCH} \
|
|
||||||
&& git submodule update --init \
|
|
||||||
&& export GPU_ARCHS=${FA_GFX_ARCHS} \
|
|
||||||
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
|
|
||||||
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
|
|
||||||
&& python3 setup.py install \
|
|
||||||
&& cd ..; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
|
||||||
# Manually removed it so that later steps of numpy upgrade can continue
|
|
||||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
|
||||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
|
||||||
|
|
||||||
# build triton
|
|
||||||
RUN if [ "$BUILD_TRITON" = "1" ]; then \
|
|
||||||
mkdir -p libs \
|
|
||||||
&& cd libs \
|
|
||||||
&& pip uninstall -y triton \
|
|
||||||
&& git clone https://github.com/ROCm/triton.git \
|
|
||||||
&& cd triton/python \
|
|
||||||
&& pip3 install . \
|
|
||||||
&& cd ../..; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
WORKDIR /vllm-workspace
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip numba
|
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
|
||||||
pip install -U -r requirements-rocm.txt \
|
|
||||||
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h ./rocm_patch/rocm_bf16.patch \
|
|
||||||
&& python3 setup.py install \
|
|
||||||
&& cp build/lib.linux-x86_64-cpython-39/vllm/_C.cpython-39-x86_64-linux-gnu.so vllm/ \
|
|
||||||
&& cd ..
|
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
|
||||||
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3
|
|
||||||
|
|
||||||
CMD ["/bin/bash"]
|
|
||||||
645
LICENSE
645
LICENSE
@@ -1,7 +1,3 @@
|
|||||||
The vllm_musa from Moore Threads is licensed under the Apache License 2.0 listed below.
|
|
||||||
Copyright (c) 2022-2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
|
|
||||||
Terms of the Apache License 2.0
|
|
||||||
-------------------------------------------------------------------------
|
|
||||||
Apache License
|
Apache License
|
||||||
Version 2.0, January 2004
|
Version 2.0, January 2004
|
||||||
http://www.apache.org/licenses/
|
http://www.apache.org/licenses/
|
||||||
@@ -203,644 +199,3 @@ Terms of the Apache License 2.0
|
|||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
|
|
||||||
-------------------------------------------------------------------------
|
|
||||||
The following copyright statements and licenses apply to various open source software/model
|
|
||||||
packages (or portions thereof) that are distributed with this vllm_musa. vllm_musa that
|
|
||||||
includes this file does not necessarily use all the open source software packages referred
|
|
||||||
to below and may also only use portions of a given package. Some open source software
|
|
||||||
packages referred to below may have been modified by Moore Threads Technology Co., Ltd
|
|
||||||
|
|
||||||
-------------------------------------------------------------------------
|
|
||||||
vllm
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|
||||||
------------------------------------------------------------------------------------
|
|
||||||
Contains code from https://github.com/punica-ai/punica
|
|
||||||
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright {yyyy} {name of copyright owner}
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|
||||||
------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
This product bundles various third-party components under other open source licenses.
|
|
||||||
This section summarizes those components and their licenses. See licenses/
|
|
||||||
for text of these licenses.
|
|
||||||
|
|
||||||
|
|
||||||
Apache-2.0
|
|
||||||
* third_party/nvbench (with LLVM exception)
|
|
||||||
* third_party/flashinfer
|
|
||||||
|
|
||||||
BSD-3-Clause:
|
|
||||||
* third_party/cutlass
|
|
||||||
|
|
||||||
------------------------------------------------------------------------------------
|
|
||||||
Contains code from https://github.com/IST-DASLab/marlin
|
|
||||||
|
|
||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright {yyyy} {name of copyright owner}
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|
||||||
------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
This product bundles various third-party components under other open source licenses.
|
|
||||||
This section summarizes those components and their licenses. See licenses/
|
|
||||||
for text of these licenses.
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
include LICENSE
|
include LICENSE
|
||||||
include requirements-common.txt
|
include requirements/common.txt
|
||||||
include requirements-cuda.txt
|
include requirements/cuda.txt
|
||||||
include requirements-rocm.txt
|
include requirements/rocm.txt
|
||||||
include requirements-neuron.txt
|
include requirements/cpu.txt
|
||||||
include requirements-cpu.txt
|
|
||||||
include CMakeLists.txt
|
include CMakeLists.txt
|
||||||
|
|
||||||
recursive-include cmake *
|
recursive-include cmake *
|
||||||
|
|||||||
293
PKG-INFO
Normal file
293
PKG-INFO
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
Metadata-Version: 2.4
|
||||||
|
Name: vllm
|
||||||
|
Version: 0.13.0
|
||||||
|
Summary: A high-throughput and memory-efficient inference and serving engine for LLMs
|
||||||
|
Author: vLLM Team
|
||||||
|
License-Expression: Apache-2.0
|
||||||
|
Project-URL: Homepage, https://github.com/vllm-project/vllm
|
||||||
|
Project-URL: Documentation, https://docs.vllm.ai/en/latest/
|
||||||
|
Project-URL: Slack, https://slack.vllm.ai/
|
||||||
|
Classifier: Programming Language :: Python :: 3.10
|
||||||
|
Classifier: Programming Language :: Python :: 3.11
|
||||||
|
Classifier: Programming Language :: Python :: 3.12
|
||||||
|
Classifier: Programming Language :: Python :: 3.13
|
||||||
|
Classifier: Intended Audience :: Developers
|
||||||
|
Classifier: Intended Audience :: Information Technology
|
||||||
|
Classifier: Intended Audience :: Science/Research
|
||||||
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
||||||
|
Classifier: Topic :: Scientific/Engineering :: Information Analysis
|
||||||
|
Requires-Python: <3.14,>=3.10
|
||||||
|
Description-Content-Type: text/markdown
|
||||||
|
License-File: LICENSE
|
||||||
|
Requires-Dist: regex
|
||||||
|
Requires-Dist: cachetools
|
||||||
|
Requires-Dist: psutil
|
||||||
|
Requires-Dist: sentencepiece
|
||||||
|
Requires-Dist: numpy
|
||||||
|
Requires-Dist: requests>=2.26.0
|
||||||
|
Requires-Dist: tqdm
|
||||||
|
Requires-Dist: blake3
|
||||||
|
Requires-Dist: py-cpuinfo
|
||||||
|
Requires-Dist: transformers<5,>=4.56.0
|
||||||
|
Requires-Dist: tokenizers>=0.21.1
|
||||||
|
Requires-Dist: protobuf
|
||||||
|
Requires-Dist: fastapi[standard]>=0.115.0
|
||||||
|
Requires-Dist: aiohttp
|
||||||
|
Requires-Dist: openai>=1.99.1
|
||||||
|
Requires-Dist: pydantic>=2.12.0
|
||||||
|
Requires-Dist: prometheus_client>=0.18.0
|
||||||
|
Requires-Dist: pillow
|
||||||
|
Requires-Dist: prometheus-fastapi-instrumentator>=7.0.0
|
||||||
|
Requires-Dist: tiktoken>=0.6.0
|
||||||
|
Requires-Dist: lm-format-enforcer==0.11.3
|
||||||
|
Requires-Dist: llguidance<1.4.0,>=1.3.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" or platform_machine == "s390x" or platform_machine == "ppc64le"
|
||||||
|
Requires-Dist: outlines_core==0.2.11
|
||||||
|
Requires-Dist: diskcache==5.6.3
|
||||||
|
Requires-Dist: lark==1.2.2
|
||||||
|
Requires-Dist: xgrammar==0.1.27; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64" or platform_machine == "s390x" or platform_machine == "ppc64le"
|
||||||
|
Requires-Dist: typing_extensions>=4.10
|
||||||
|
Requires-Dist: filelock>=3.16.1
|
||||||
|
Requires-Dist: partial-json-parser
|
||||||
|
Requires-Dist: pyzmq>=25.0.0
|
||||||
|
Requires-Dist: msgspec
|
||||||
|
Requires-Dist: gguf>=0.17.0
|
||||||
|
Requires-Dist: mistral_common[image]>=1.8.5
|
||||||
|
Requires-Dist: opencv-python-headless>=4.11.0
|
||||||
|
Requires-Dist: pyyaml
|
||||||
|
Requires-Dist: six>=1.16.0; python_version > "3.11"
|
||||||
|
Requires-Dist: setuptools<81.0.0,>=77.0.3; python_version > "3.11"
|
||||||
|
Requires-Dist: einops
|
||||||
|
Requires-Dist: compressed-tensors==0.12.2
|
||||||
|
Requires-Dist: depyf==0.20.0
|
||||||
|
Requires-Dist: cloudpickle
|
||||||
|
Requires-Dist: watchfiles
|
||||||
|
Requires-Dist: python-json-logger
|
||||||
|
Requires-Dist: scipy
|
||||||
|
Requires-Dist: ninja
|
||||||
|
Requires-Dist: pybase64
|
||||||
|
Requires-Dist: cbor2
|
||||||
|
Requires-Dist: ijson
|
||||||
|
Requires-Dist: setproctitle
|
||||||
|
Requires-Dist: openai-harmony>=0.0.3
|
||||||
|
Requires-Dist: anthropic==0.71.0
|
||||||
|
Requires-Dist: model-hosting-container-standards<1.0.0,>=0.1.9
|
||||||
|
Requires-Dist: mcp
|
||||||
|
Requires-Dist: numba==0.61.2
|
||||||
|
Requires-Dist: ray[cgraph]>=2.48.0
|
||||||
|
Requires-Dist: torch==2.9.0
|
||||||
|
Requires-Dist: torchaudio==2.9.0
|
||||||
|
Requires-Dist: torchvision==0.24.0
|
||||||
|
Requires-Dist: flashinfer-python==0.5.3
|
||||||
|
Provides-Extra: bench
|
||||||
|
Requires-Dist: pandas; extra == "bench"
|
||||||
|
Requires-Dist: matplotlib; extra == "bench"
|
||||||
|
Requires-Dist: seaborn; extra == "bench"
|
||||||
|
Requires-Dist: datasets; extra == "bench"
|
||||||
|
Provides-Extra: tensorizer
|
||||||
|
Requires-Dist: tensorizer==2.10.1; extra == "tensorizer"
|
||||||
|
Provides-Extra: fastsafetensors
|
||||||
|
Requires-Dist: fastsafetensors>=0.1.10; extra == "fastsafetensors"
|
||||||
|
Provides-Extra: runai
|
||||||
|
Requires-Dist: runai-model-streamer[gcs,s3]>=0.15.3; extra == "runai"
|
||||||
|
Provides-Extra: audio
|
||||||
|
Requires-Dist: librosa; extra == "audio"
|
||||||
|
Requires-Dist: soundfile; extra == "audio"
|
||||||
|
Requires-Dist: mistral_common[audio]; extra == "audio"
|
||||||
|
Provides-Extra: video
|
||||||
|
Provides-Extra: flashinfer
|
||||||
|
Provides-Extra: petit-kernel
|
||||||
|
Requires-Dist: petit-kernel; extra == "petit-kernel"
|
||||||
|
Dynamic: license-file
|
||||||
|
Dynamic: provides-extra
|
||||||
|
Dynamic: requires-dist
|
||||||
|
|
||||||
|
<!-- markdownlint-disable MD001 MD041 -->
|
||||||
|
<p align="center">
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png">
|
||||||
|
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-light.png" width=55%>
|
||||||
|
</picture>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
<h3 align="center">
|
||||||
|
Easy, fast, and cheap LLM serving for everyone
|
||||||
|
</h3>
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||||
|
</p>
|
||||||
|
|
||||||
|
---
|
||||||
|
Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) and [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco for our latest updates on vLLM and to meet the vLLM team! Register now for the largest vLLM community events of the year!
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*Latest News* 🔥
|
||||||
|
|
||||||
|
- [2025/11] We hosted [vLLM Bangkok Meetup](https://luma.com/v0f647nv). We explored vLLM and LMCache inference and low-resource language adaptation with speakers from Embedded LLM, AMD, and Red Hat. Please find the meetup slides [here](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing).
|
||||||
|
- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
|
||||||
|
- [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link).
|
||||||
|
- [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6).
|
||||||
|
- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
|
||||||
|
- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
|
||||||
|
- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing).
|
||||||
|
- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
|
||||||
|
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
|
||||||
|
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Previous News</summary>
|
||||||
|
|
||||||
|
- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view).
|
||||||
|
- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
|
||||||
|
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
|
||||||
|
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
|
||||||
|
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
|
||||||
|
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
|
||||||
|
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
|
||||||
|
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
|
||||||
|
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
|
||||||
|
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||||
|
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||||
|
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||||
|
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
|
||||||
|
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
|
||||||
|
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
|
||||||
|
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
|
||||||
|
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
||||||
|
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||||
|
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||||
|
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) with a16z! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||||
|
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||||
|
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## About
|
||||||
|
|
||||||
|
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||||
|
|
||||||
|
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
|
||||||
|
|
||||||
|
vLLM is fast with:
|
||||||
|
|
||||||
|
- State-of-the-art serving throughput
|
||||||
|
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
||||||
|
- Continuous batching of incoming requests
|
||||||
|
- Fast model execution with CUDA/HIP graph
|
||||||
|
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
|
||||||
|
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
|
||||||
|
- Speculative decoding
|
||||||
|
- Chunked prefill
|
||||||
|
|
||||||
|
vLLM is flexible and easy to use with:
|
||||||
|
|
||||||
|
- Seamless integration with popular Hugging Face models
|
||||||
|
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||||
|
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||||
|
- Streaming outputs
|
||||||
|
- OpenAI-compatible API server
|
||||||
|
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||||
|
- Prefix caching support
|
||||||
|
- Multi-LoRA support
|
||||||
|
|
||||||
|
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||||
|
|
||||||
|
- Transformer-like LLMs (e.g., Llama)
|
||||||
|
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
|
||||||
|
- Embedding Models (e.g., E5-Mistral)
|
||||||
|
- Multi-modal LLMs (e.g., LLaVA)
|
||||||
|
|
||||||
|
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install vllm
|
||||||
|
```
|
||||||
|
|
||||||
|
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
||||||
|
|
||||||
|
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
|
||||||
|
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
|
||||||
|
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
We welcome and value any contributions and collaborations.
|
||||||
|
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
|
||||||
|
|
||||||
|
## Sponsors
|
||||||
|
|
||||||
|
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
|
||||||
|
|
||||||
|
<!-- Note: Please sort them in alphabetical order. -->
|
||||||
|
<!-- Note: Please keep these consistent with docs/community/sponsors.md -->
|
||||||
|
Cash Donations:
|
||||||
|
|
||||||
|
- a16z
|
||||||
|
- Dropbox
|
||||||
|
- Sequoia Capital
|
||||||
|
- Skywork AI
|
||||||
|
- ZhenFund
|
||||||
|
|
||||||
|
Compute Resources:
|
||||||
|
|
||||||
|
- Alibaba Cloud
|
||||||
|
- AMD
|
||||||
|
- Anyscale
|
||||||
|
- Arm
|
||||||
|
- AWS
|
||||||
|
- Crusoe Cloud
|
||||||
|
- Databricks
|
||||||
|
- DeepInfra
|
||||||
|
- Google Cloud
|
||||||
|
- IBM
|
||||||
|
- Intel
|
||||||
|
- Lambda Lab
|
||||||
|
- Nebius
|
||||||
|
- Novita AI
|
||||||
|
- NVIDIA
|
||||||
|
- Red Hat
|
||||||
|
- Replicate
|
||||||
|
- Roblox
|
||||||
|
- RunPod
|
||||||
|
- Trainy
|
||||||
|
- UC Berkeley
|
||||||
|
- UC San Diego
|
||||||
|
- Volcengine
|
||||||
|
|
||||||
|
Slack Sponsor: Anyscale
|
||||||
|
|
||||||
|
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
|
||||||
|
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@inproceedings{kwon2023efficient,
|
||||||
|
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||||
|
author={Woosuk Kwon and Zhuohan Li and Siyuan Zhuang and Ying Sheng and Lianmin Zheng and Cody Hao Yu and Joseph E. Gonzalez and Hao Zhang and Ion Stoica},
|
||||||
|
booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Contact Us
|
||||||
|
|
||||||
|
<!-- --8<-- [start:contact-us] -->
|
||||||
|
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
|
||||||
|
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
|
||||||
|
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||||
|
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
|
||||||
|
- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)
|
||||||
|
<!-- --8<-- [end:contact-us] -->
|
||||||
|
|
||||||
|
## Media Kit
|
||||||
|
|
||||||
|
- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)
|
||||||
212
README.md
212
README.md
@@ -1,19 +1,8 @@
|
|||||||
# enginex-mthreads-vllm
|
<!-- markdownlint-disable MD001 MD041 -->
|
||||||
|
|
||||||
环境要求:Driver Version:3.0.0-rc-KuaE2.0
|
|
||||||
|
|
||||||
设备型号:摩尔线程(MThreads)S4000
|
|
||||||
|
|
||||||
vllm 版本:v0.8.4
|
|
||||||
|
|
||||||
源码地址:https://github.com/MooreThreads/vllm_musa
|
|
||||||
|
|
||||||
镜像:git.modelhub.org.cn:9443/enginex-mthreads/vllm-musa-qy2-py310:v0.8.4-release
|
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<picture>
|
<picture>
|
||||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-dark.png">
|
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png">
|
||||||
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/source/assets/logos/vllm-logo-text-light.png" width=55%>
|
<img alt="vLLM" src="https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-light.png" width=55%>
|
||||||
</picture>
|
</picture>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -22,104 +11,161 @@ Easy, fast, and cheap LLM serving for everyone
|
|||||||
</h3>
|
</h3>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://vllm.ai"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://discord.gg/jz7wjKhh6g"><b>Discord</b></a> |
|
| <a href="https://docs.vllm.ai"><b>Documentation</b></a> | <a href="https://blog.vllm.ai/"><b>Blog</b></a> | <a href="https://arxiv.org/abs/2309.06180"><b>Paper</b></a> | <a href="https://x.com/vllm_project"><b>Twitter/X</b></a> | <a href="https://discuss.vllm.ai"><b>User Forum</b></a> | <a href="https://slack.vllm.ai"><b>Developer Slack</b></a> |
|
||||||
|
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
*Latest News* 🔥
|
---
|
||||||
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
Join us at the [PyTorch Conference, October 22-23](https://events.linuxfoundation.org/pytorch-conference/) and [Ray Summit, November 3-5](https://www.anyscale.com/ray-summit/2025) in San Francisco for our latest updates on vLLM and to meet the vLLM team! Register now for the largest vLLM community events of the year!
|
||||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
|
||||||
- [2024/01] Added ROCm 6.0 support to vLLM.
|
|
||||||
- [2023/12] Added ROCm 5.7 support to vLLM.
|
|
||||||
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
|
||||||
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
|
||||||
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
|
||||||
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
|
||||||
- [2023/07] Added support for LLaMA-2! You can run and serve 7B/13B/70B LLaMA-2s on vLLM with a single command!
|
|
||||||
- [2023/06] Serving vLLM On any Cloud with SkyPilot. Check out a 1-click [example](https://github.com/skypilot-org/skypilot/blob/master/llm/vllm) to start the vLLM demo, and the [blog post](https://blog.skypilot.co/serving-llm-24x-faster-on-the-cloud-with-vllm-and-skypilot/) for the story behind vLLM development on the clouds.
|
|
||||||
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
*Latest News* 🔥
|
||||||
|
|
||||||
|
- [2025/11] We hosted [vLLM Bangkok Meetup](https://luma.com/v0f647nv). We explored vLLM and LMCache inference and low-resource language adaptation with speakers from Embedded LLM, AMD, and Red Hat. Please find the meetup slides [here](https://drive.google.com/drive/folders/1H0DS57F8HQ5q3kSOSoRmucPJWL3E0A_X?usp=sharing).
|
||||||
|
- [2025/11] We hosted [the first vLLM Europe Meetup in Zurich](https://luma.com/0gls27kb) focused on quantization, distributed inference, and reinforcement learning at scale with speakers from Mistral, IBM, and Red Hat. Please find the meetup slides [here](https://docs.google.com/presentation/d/1UC9PTLCHYXQpOmJDSFg6Sljra3iVXzc09DeEI7dnxMc/edit?usp=sharing) and recording [here](https://www.youtube.com/watch?v=6m6ZE6yVEDI)
|
||||||
|
- [2025/11] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/xSrYXjNgr1HbCP4ExYNG1w) focusing on distributed inference and diverse accelerator support with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1nQJ8ZkLSjKxvu36sSHaceVXtttbLvvu-?usp=drive_link).
|
||||||
|
- [2025/10] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/__xb4OyOsImz-9eAVrdlcg) focused on hands-on vLLM inference optimization! Please find the meetup slides [here](https://drive.google.com/drive/folders/1KqwjsFJLfEsC8wlDugnrR61zsWHt94Q6).
|
||||||
|
- [2025/09] We hosted [vLLM Toronto Meetup](https://luma.com/e80e0ymm) focused on tackling inference at scale and speculative decoding with speakers from NVIDIA and Red Hat! Please find the meetup slides [here](https://docs.google.com/presentation/d/1IYJYmJcu9fLpID5N5RbW_vO0XLo0CGOR14IXOjB61V8/edit?usp=sharing).
|
||||||
|
- [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
|
||||||
|
- [2025/08] We hosted [vLLM Singapore Meetup](https://www.sginnovate.com/event/vllm-sg-meet). We shared V1 updates, disaggregated serving and MLLM speedups with speakers from Embedded LLM, AMD, WekaIO, and A*STAR. Please find the meetup slides [here](https://drive.google.com/drive/folders/1ncf3GyqLdqFaB6IeB834E5TZJPLAOiXZ?usp=sharing).
|
||||||
|
- [2025/08] We hosted [vLLM Shanghai Meetup](https://mp.weixin.qq.com/s/pDmAXHcN7Iqc8sUKgJgGtg) focusing on building, developing, and integrating with vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1OvLx39wnCGy_WKq8SiVKf7YcxxYI3WCH).
|
||||||
|
- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/).
|
||||||
|
- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html).
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Previous News</summary>
|
||||||
|
|
||||||
|
- [2025/08] We hosted [vLLM Korea Meetup](https://luma.com/cgcgprmh) with Red Hat and Rebellions! We shared the latest advancements in vLLM along with project spotlights from the vLLM Korea community. Please find the meetup slides [here](https://drive.google.com/file/d/1bcrrAE1rxUgx0mjIeOWT6hNe2RefC5Hm/view).
|
||||||
|
- [2025/08] We hosted [vLLM Beijing Meetup](https://mp.weixin.qq.com/s/dgkWg1WFpWGO2jCdTqQHxA) focusing on large-scale LLM deployment! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Pid6NSFLU43DZRi0EaTcPgXsAzDvbBqF) and the recording [here](https://www.chaspark.com/#/live/1166916873711665152).
|
||||||
|
- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing).
|
||||||
|
- [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
|
||||||
|
- [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing).
|
||||||
|
- [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing).
|
||||||
|
- [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0).
|
||||||
|
- [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted.
|
||||||
|
- [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing).
|
||||||
|
- [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone!
|
||||||
|
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing).
|
||||||
|
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
|
||||||
|
- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users!
|
||||||
|
- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing).
|
||||||
|
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
|
||||||
|
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
|
||||||
|
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
||||||
|
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||||
|
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||||
|
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) with a16z! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||||
|
- [2023/08] We would like to express our sincere gratitude to [Andreessen Horowitz](https://a16z.com/2023/08/30/supporting-the-open-source-ai-community/) (a16z) for providing a generous grant to support the open-source development and research of vLLM.
|
||||||
|
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## About
|
## About
|
||||||
|
|
||||||
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
vLLM is a fast and easy-to-use library for LLM inference and serving.
|
||||||
|
|
||||||
|
Originally developed in the [Sky Computing Lab](https://sky.cs.berkeley.edu) at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry.
|
||||||
|
|
||||||
vLLM is fast with:
|
vLLM is fast with:
|
||||||
|
|
||||||
- State-of-the-art serving throughput
|
- State-of-the-art serving throughput
|
||||||
- Efficient management of attention key and value memory with **PagedAttention**
|
- Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html)
|
||||||
- Continuous batching of incoming requests
|
- Continuous batching of incoming requests
|
||||||
- Fast model execution with CUDA/HIP graph
|
- Fast model execution with CUDA/HIP graph
|
||||||
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
|
- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8
|
||||||
- Optimized CUDA kernels
|
- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer
|
||||||
|
- Speculative decoding
|
||||||
|
- Chunked prefill
|
||||||
|
|
||||||
vLLM is flexible and easy to use with:
|
vLLM is flexible and easy to use with:
|
||||||
|
|
||||||
- Seamless integration with popular Hugging Face models
|
- Seamless integration with popular Hugging Face models
|
||||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||||
- Tensor parallelism support for distributed inference
|
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||||
- Streaming outputs
|
- Streaming outputs
|
||||||
- OpenAI-compatible API server
|
- OpenAI-compatible API server
|
||||||
- Support NVIDIA GPUs and AMD GPUs
|
- Support for NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, Arm CPUs, and TPU. Additionally, support for diverse hardware plugins such as Intel Gaudi, IBM Spyre and Huawei Ascend.
|
||||||
- (Experimental) Prefix caching support
|
- Prefix caching support
|
||||||
- (Experimental) Multi-lora support
|
- Multi-LoRA support
|
||||||
|
|
||||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
vLLM seamlessly supports most popular open-source models on HuggingFace, including:
|
||||||
|
|
||||||
- Aquila & Aquila2 (`BAAI/AquilaChat2-7B`, `BAAI/AquilaChat2-34B`, `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc.)
|
- Transformer-like LLMs (e.g., Llama)
|
||||||
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
|
- Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3)
|
||||||
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
|
- Embedding Models (e.g., E5-Mistral)
|
||||||
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
|
- Multi-modal LLMs (e.g., LLaVA)
|
||||||
- Command-R (`CohereForAI/c4ai-command-r-v01`, etc.)
|
|
||||||
- DBRX (`databricks/dbrx-base`, `databricks/dbrx-instruct` etc.)
|
|
||||||
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
|
|
||||||
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
|
|
||||||
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
|
|
||||||
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
|
|
||||||
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
|
|
||||||
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
|
|
||||||
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
|
|
||||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
|
||||||
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
|
|
||||||
- Jais (`core42/jais-13b`, `core42/jais-13b-chat`, `core42/jais-30b-v3`, `core42/jais-30b-chat-v3`, etc.)
|
|
||||||
- LLaMA, Llama 2, and Meta Llama 3 (`meta-llama/Meta-Llama-3-8B-Instruct`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
|
||||||
- MiniCPM (`openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, etc.)
|
|
||||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
|
||||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, `mistral-community/Mixtral-8x22B-v0.1`, etc.)
|
|
||||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
|
||||||
- OLMo (`allenai/OLMo-1B-hf`, `allenai/OLMo-7B-hf`, etc.)
|
|
||||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
|
||||||
- Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
|
|
||||||
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
|
|
||||||
- Phi-3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
|
|
||||||
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
|
|
||||||
- Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
|
|
||||||
- Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
|
|
||||||
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
|
|
||||||
- Starcoder2(`bigcode/starcoder2-3b`, `bigcode/starcoder2-7b`, `bigcode/starcoder2-15b`, etc.)
|
|
||||||
- Xverse (`xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.)
|
|
||||||
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)
|
|
||||||
|
|
||||||
Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
|
Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html).
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
Install vLLM with `pip` or [from source](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#build-wheel-from-source):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install vllm
|
pip install vllm
|
||||||
```
|
```
|
||||||
|
|
||||||
## Getting Started
|
Visit our [documentation](https://docs.vllm.ai/en/latest/) to learn more.
|
||||||
|
|
||||||
Visit our [documentation](https://vllm.readthedocs.io/en/latest/) to get started.
|
- [Installation](https://docs.vllm.ai/en/latest/getting_started/installation.html)
|
||||||
- [Installation](https://vllm.readthedocs.io/en/latest/getting_started/installation.html)
|
- [Quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html)
|
||||||
- [Quickstart](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html)
|
- [List of Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html)
|
||||||
- [Supported Models](https://vllm.readthedocs.io/en/latest/models/supported_models.html)
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
We welcome and value any contributions and collaborations.
|
We welcome and value any contributions and collaborations.
|
||||||
Please check out [CONTRIBUTING.md](./CONTRIBUTING.md) for how to get involved.
|
Please check out [Contributing to vLLM](https://docs.vllm.ai/en/latest/contributing/index.html) for how to get involved.
|
||||||
|
|
||||||
|
## Sponsors
|
||||||
|
|
||||||
|
vLLM is a community project. Our compute resources for development and testing are supported by the following organizations. Thank you for your support!
|
||||||
|
|
||||||
|
<!-- Note: Please sort them in alphabetical order. -->
|
||||||
|
<!-- Note: Please keep these consistent with docs/community/sponsors.md -->
|
||||||
|
Cash Donations:
|
||||||
|
|
||||||
|
- a16z
|
||||||
|
- Dropbox
|
||||||
|
- Sequoia Capital
|
||||||
|
- Skywork AI
|
||||||
|
- ZhenFund
|
||||||
|
|
||||||
|
Compute Resources:
|
||||||
|
|
||||||
|
- Alibaba Cloud
|
||||||
|
- AMD
|
||||||
|
- Anyscale
|
||||||
|
- Arm
|
||||||
|
- AWS
|
||||||
|
- Crusoe Cloud
|
||||||
|
- Databricks
|
||||||
|
- DeepInfra
|
||||||
|
- Google Cloud
|
||||||
|
- IBM
|
||||||
|
- Intel
|
||||||
|
- Lambda Lab
|
||||||
|
- Nebius
|
||||||
|
- Novita AI
|
||||||
|
- NVIDIA
|
||||||
|
- Red Hat
|
||||||
|
- Replicate
|
||||||
|
- Roblox
|
||||||
|
- RunPod
|
||||||
|
- Trainy
|
||||||
|
- UC Berkeley
|
||||||
|
- UC San Diego
|
||||||
|
- Volcengine
|
||||||
|
|
||||||
|
Slack Sponsor: Anyscale
|
||||||
|
|
||||||
|
We also have an official fundraising venue through [OpenCollective](https://opencollective.com/vllm). We plan to use the fund to support the development, maintenance, and adoption of vLLM.
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs/2309.06180):
|
||||||
|
|
||||||
```bibtex
|
```bibtex
|
||||||
@inproceedings{kwon2023efficient,
|
@inproceedings{kwon2023efficient,
|
||||||
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
title={Efficient Memory Management for Large Language Model Serving with PagedAttention},
|
||||||
@@ -129,6 +175,16 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## vllm with MUSA
|
## Contact Us
|
||||||
|
|
||||||
Please refer to [README_vllm_musa](./README_vllm_musa.md).
|
<!-- --8<-- [start:contact-us] -->
|
||||||
|
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues)
|
||||||
|
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
|
||||||
|
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||||
|
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
|
||||||
|
- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)
|
||||||
|
<!-- --8<-- [end:contact-us] -->
|
||||||
|
|
||||||
|
## Media Kit
|
||||||
|
|
||||||
|
- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit)
|
||||||
|
|||||||
@@ -1,66 +0,0 @@
|
|||||||
# vllm_musa
|
|
||||||
|
|
||||||
摩尔线程致力于构建完善好用的国产GPU应用生态,自主研发了MUSA架构及软件平台。vllm项目是业界广泛使用的大语言模型的推理和服务引擎,使用CUDA/ROCm提供GPU加速能力。为了方便摩尔线程GPU用户使用vllm框架,我们发起vllm_musa开源项目为vllm提供MUSA加速,让用户可释放摩尔线程GPU的澎湃算力。
|
|
||||||
|
|
||||||
现有的vllm代码不支持摩尔线程GPU作为后端,因此我们新增了MUSA设备后端。vllm_musa接口与官方接口一致,用户无需改动业务代码,开箱即用。
|
|
||||||
|
|
||||||
MUSA的一大优势是CUDA兼容,通过musify工具,我们可以快速将官方代码porting至MUSA软件栈,用户可以根据文档自行升级vllm版本并适配MUSA软件栈。
|
|
||||||
|
|
||||||
## 依赖
|
|
||||||
|
|
||||||
- musa_toolkit >= dev3.0.0
|
|
||||||
- pytorch >= v2.2.0
|
|
||||||
- [torch_musa](https://github.com/MooreThreads/torch_musa) >= v1.3.0
|
|
||||||
- triton >= v2.2.0
|
|
||||||
- ray >= 2.9
|
|
||||||
- vllm v0.4.2
|
|
||||||
|
|
||||||
## 使用
|
|
||||||
### 编译
|
|
||||||
运行 `bash build_musa.sh`
|
|
||||||
### 测试示例
|
|
||||||
```
|
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
from transformers import AutoTokenizer, LlamaForCausalLM
|
|
||||||
import transformers
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
import torch_musa
|
|
||||||
|
|
||||||
|
|
||||||
model_path = <path_to_llm_model>
|
|
||||||
|
|
||||||
prompts = [
|
|
||||||
"Hello, my name is",
|
|
||||||
"The president of the United States is",
|
|
||||||
"The capital of France is",
|
|
||||||
"The future of AI is",
|
|
||||||
]
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
|
|
||||||
llm = LLM(model=model_path, trust_remote_code=True, device="musa")
|
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
# Print the outputs.
|
|
||||||
for output in outputs:
|
|
||||||
prompt = output.prompt
|
|
||||||
generated_text = output.outputs[0].text
|
|
||||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Porting
|
|
||||||
|
|
||||||
当前仓库porting自vllm v0.4.2版本。如果用户希望使用更高版本的vllm,只需要运行`musa_porting.py`将原生CUDA代码适配到MUSA代码即可。当然随着vllm的迭代可能会有些代码成为漏网之鱼,没有porting成功,用户可自行修改`musa_porting.py`文件中的文本替换规则。从而发挥MUSA强大的CUDA兼容能力。
|
|
||||||
|
|
||||||
### 步骤
|
|
||||||
1. 运行 `python musa_porting.py`
|
|
||||||
2. 将`CMakeLists.txt`中需要编译的文件后缀从`.cu`修改为`.mu`
|
|
||||||
3. 编译运行vllm_musa
|
|
||||||
|
|
||||||
## 贡献
|
|
||||||
|
|
||||||
欢迎广大用户及开发者使用、反馈,助力vllm_musa功能及性能持续完善。
|
|
||||||
|
|
||||||
社区共建,期待广大开发者与我们一道,共同打造MUSA软件生态。我们将陆续推出一系列开源软件MUSA加速项目。
|
|
||||||
90
RELEASE.md
Normal file
90
RELEASE.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# Releasing vLLM
|
||||||
|
|
||||||
|
vLLM releases offer a reliable version of the code base, packaged into a binary format that can be conveniently accessed via PyPI. These releases also serve as key milestones for the development team to communicate with the community about newly available features, improvements, and upcoming changes that could affect users, including potential breaking changes.
|
||||||
|
|
||||||
|
## Release Versioning
|
||||||
|
|
||||||
|
vLLM uses a “right-shifted” versioning scheme where a new patch release is out every 2 weeks. And patch releases contain features and bug fixes (as opposed to semver where patch release contains only backwards-compatible bug fixes). When critical fixes need to be made, special release post1 is released.
|
||||||
|
|
||||||
|
* _major_ major architectural milestone and when incompatible API changes are made, similar to PyTorch 2.0.
|
||||||
|
* _minor_ major features
|
||||||
|
* _patch_ features and backwards-compatible bug fixes
|
||||||
|
* _post1_ or _patch-1_ backwards-compatible bug fixes, either explicit or implicit post release
|
||||||
|
|
||||||
|
## Release Cadence
|
||||||
|
|
||||||
|
Patch release is released on bi-weekly basis. Post release 1-3 days after patch release and uses same branch as patch release.
|
||||||
|
Following is the release cadence for year 2025. All future release dates below are tentative. Please note: Post releases are optional.
|
||||||
|
|
||||||
|
| Release Date | Patch release versions | Post Release versions |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| Jan 2025 | 0.7.0 | --- |
|
||||||
|
| Feb 2025 | 0.7.1, 0.7.2, 0.7.3 | --- |
|
||||||
|
| Mar 2025 | 0.7.4, 0.7.5 | --- |
|
||||||
|
| Apr 2025 | 0.7.6, 0.7.7 | --- |
|
||||||
|
| May 2025 | 0.7.8, 0.7.9 | --- |
|
||||||
|
| Jun 2025 | 0.7.10, 0.7.11 | --- |
|
||||||
|
| Jul 2025 | 0.7.12, 0.7.13 | --- |
|
||||||
|
| Aug 2025 | 0.7.14, 0.7.15 | --- |
|
||||||
|
| Sep 2025 | 0.7.16, 0.7.17 | --- |
|
||||||
|
| Oct 2025 | 0.7.18, 0.7.19 | --- |
|
||||||
|
| Nov 2025 | 0.7.20, 0.7.21 | --- |
|
||||||
|
| Dec 2025 | 0.7.22, 0.7.23 | --- |
|
||||||
|
|
||||||
|
## Release branch
|
||||||
|
|
||||||
|
Each release is built from a dedicated release branch.
|
||||||
|
|
||||||
|
* For _major_, _minor_, _patch_ releases, the release branch cut is performed 1-2 days before release is live.
|
||||||
|
* For post releases, previously cut release branch is reused
|
||||||
|
* Release builds are triggered via push to RC tag like vX.Y.Z-rc1 . This enables us to build and test multiple RCs for each release.
|
||||||
|
* Final tag : vX.Y.Z does not trigger the build but used for Release notes and assets.
|
||||||
|
* After branch cut is created we monitor the main branch for any reverts and apply these reverts to a release branch.
|
||||||
|
|
||||||
|
## Release Cherry-Pick Criteria
|
||||||
|
|
||||||
|
After branch cut, we approach finalizing the release branch with clear criteria on what cherry picks are allowed in. Note: a cherry pick is a process to land a PR in the release branch after branch cut. These are typically limited to ensure that the team has sufficient time to complete a thorough round of testing on a stable code base.
|
||||||
|
|
||||||
|
* Regression fixes - that address functional/performance regression against the most recent release (e.g. 0.7.0 for 0.7.1 release)
|
||||||
|
* Critical fixes - critical fixes for severe issue such as silent incorrectness, backwards compatibility, crashes, deadlocks, (large) memory leaks
|
||||||
|
* Fixes to new features introduced in the most recent release (e.g. 0.7.0 for 0.7.1 release)
|
||||||
|
* Documentation improvements
|
||||||
|
* Release branch specific changes (e.g. change version identifiers or CI fixes)
|
||||||
|
|
||||||
|
Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes.
|
||||||
|
|
||||||
|
## Manual validations
|
||||||
|
|
||||||
|
### E2E Performance Validation
|
||||||
|
|
||||||
|
Before each release, we perform end-to-end performance validation to ensure no regressions are introduced. This validation uses the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) on PyTorch CI.
|
||||||
|
|
||||||
|
**Current Coverage:**
|
||||||
|
|
||||||
|
* Models: Llama3, Llama4, and Mixtral
|
||||||
|
* Hardware: NVIDIA H100 and AMD MI300x
|
||||||
|
* _Note: Coverage may change based on new model releases and hardware availability_
|
||||||
|
|
||||||
|
**Performance Validation Process:**
|
||||||
|
|
||||||
|
**Step 1: Get Access**
|
||||||
|
Request write access to the [pytorch/pytorch-integration-testing](https://github.com/pytorch/pytorch-integration-testing) repository to run the benchmark workflow.
|
||||||
|
|
||||||
|
**Step 2: Review Benchmark Setup**
|
||||||
|
Familiarize yourself with the benchmark configurations:
|
||||||
|
|
||||||
|
* [CUDA setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/cuda)
|
||||||
|
* [ROCm setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/rocm)
|
||||||
|
|
||||||
|
**Step 3: Run the Benchmark**
|
||||||
|
Navigate to the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) and configure:
|
||||||
|
|
||||||
|
* **vLLM branch**: Set to the release branch (e.g., `releases/v0.9.2`)
|
||||||
|
* **vLLM commit**: Set to the RC commit hash
|
||||||
|
|
||||||
|
**Step 4: Review Results**
|
||||||
|
Once the workflow completes, benchmark results will be available on the [vLLM benchmark dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) under the corresponding branch and commit.
|
||||||
|
|
||||||
|
**Step 5: Performance Comparison**
|
||||||
|
Compare the current results against the previous release to verify no performance regressions have occurred. Here is an
|
||||||
|
example of [v0.9.1 vs v0.9.2](https://hud.pytorch.org/benchmark/llms?startTime=Thu%2C%2017%20Apr%202025%2021%3A43%3A50%20GMT&stopTime=Wed%2C%2016%20Jul%202025%2021%3A43%3A50%20GMT&granularity=week&lBranch=releases/v0.9.1&lCommit=b6553be1bc75f046b00046a4ad7576364d03c835&rBranch=releases/v0.9.2&rCommit=a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f&repoName=vllm-project%2Fvllm&benchmarkName=&modelName=All%20Models&backendName=All%20Backends&modeName=All%20Modes&dtypeName=All%20DType&deviceName=All%20Devices&archName=All%20Platforms).
|
||||||
50
SECURITY.md
Normal file
50
SECURITY.md
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# Security Policy
|
||||||
|
|
||||||
|
## Reporting security issues
|
||||||
|
|
||||||
|
Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new).
|
||||||
|
|
||||||
|
## Issue triage
|
||||||
|
|
||||||
|
Reports will then be triaged by the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html).
|
||||||
|
|
||||||
|
## Threat model
|
||||||
|
|
||||||
|
Please see the [Security Guide in the vLLM documentation](https://docs.vllm.ai/en/latest/usage/security.html) for more information on vLLM's security assumptions and recommendations.
|
||||||
|
|
||||||
|
Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models.
|
||||||
|
|
||||||
|
## Issue severity
|
||||||
|
|
||||||
|
We will determine the risk of each issue, taking into account our experience dealing with past issues, versions affected, common defaults, and use cases. We use the following severity categories:
|
||||||
|
|
||||||
|
### CRITICAL Severity
|
||||||
|
|
||||||
|
Vulnerabilities that allow remote attackers to execute arbitrary code, take full control of the system, or significantly compromise confidentiality, integrity, or availability without any interaction or privileges needed, examples include remote code execution via network, deserialization issues that allow exploit chains. Generally those issues which are rated as CVSS ≥ 9.0.
|
||||||
|
|
||||||
|
### HIGH Severity
|
||||||
|
|
||||||
|
Serious security flaws that allow elevated impact—like RCE in specific, limited contexts or significant data loss—but require advanced conditions or some trust, examples include RCE in advanced deployment modes (e.g. multi-node), or high impact issues where some sort of privileged network access is required. These issues typically have CVSS scores between 7.0 and 8.9
|
||||||
|
|
||||||
|
### MODERATE Severity
|
||||||
|
|
||||||
|
Vulnerabilities that cause denial of service or partial disruption, but do not allow arbitrary code execution or data breach and have limited impact. These issues have a CVSS rating between 4.0 and 6.9
|
||||||
|
|
||||||
|
### LOW Severity
|
||||||
|
|
||||||
|
Minor issues such as informational disclosures, logging errors, non-exploitable flaws, or weaknesses that require local or high-privilege access and offer negligible impact. Examples include side channel attacks or hash collisions. These issues often have CVSS scores less than 4.0
|
||||||
|
|
||||||
|
## Prenotification policy
|
||||||
|
|
||||||
|
For certain security issues of CRITICAL, HIGH, or MODERATE severity level, we may prenotify certain organizations or vendors that ship vLLM. The purpose of this prenotification is to allow for a coordinated release of fixes for severe issues.
|
||||||
|
|
||||||
|
* This prenotification will be in the form of a private email notification. It may also include adding security contacts to the GitHub security advisory, typically a few days before release.
|
||||||
|
|
||||||
|
* If you wish to be added to the prenotification group, please send an email copying all the members of the [vulnerability management team](https://docs.vllm.ai/en/latest/contributing/vulnerability_management.html). Each vendor contact will be analyzed on a case-by-case basis.
|
||||||
|
|
||||||
|
* Organizations and vendors who either ship or use vLLM, are eligible to join the prenotification group if they meet at least one of the following qualifications
|
||||||
|
* Substantial internal deployment leveraging the upstream vLLM project.
|
||||||
|
* Established internal security teams and comprehensive compliance measures.
|
||||||
|
* Active and consistent contributions to the upstream vLLM project.
|
||||||
|
|
||||||
|
* We may withdraw organizations from receiving future prenotifications if they release fixes or any other information about issues before they are public. Group membership may also change based on policy refinements for who may be included.
|
||||||
@@ -1,8 +1,20 @@
|
|||||||
# Benchmarking vLLM
|
# Benchmarks
|
||||||
|
|
||||||
## Downloading the ShareGPT dataset
|
This directory used to contain vLLM's benchmark scripts and utilities for performance testing and evaluation.
|
||||||
|
|
||||||
You can download the dataset by running:
|
## Contents
|
||||||
```bash
|
|
||||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
- **Serving benchmarks**: Scripts for testing online inference performance (latency, throughput)
|
||||||
```
|
- **Throughput benchmarks**: Scripts for testing offline batch inference performance
|
||||||
|
- **Specialized benchmarks**: Tools for testing specific features like structured output, prefix caching, long document QA, request prioritization, and multi-modal inference
|
||||||
|
- **Dataset utilities**: Framework for loading and sampling from various benchmark datasets (ShareGPT, HuggingFace datasets, synthetic data, etc.)
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
For detailed usage instructions, examples, and dataset information, see the [Benchmark CLI documentation](https://docs.vllm.ai/en/latest/contributing/benchmarks.html#benchmark-cli).
|
||||||
|
|
||||||
|
For full CLI reference see:
|
||||||
|
|
||||||
|
- <https://docs.vllm.ai/en/latest/cli/bench/latency.html>
|
||||||
|
- <https://docs.vllm.ai/en/latest/cli/bench/serve.html>
|
||||||
|
- <https://docs.vllm.ai/en/latest/cli/bench/throughput.html>
|
||||||
|
|||||||
218
benchmarks/auto_tune/README.md
Normal file
218
benchmarks/auto_tune/README.md
Normal file
@@ -0,0 +1,218 @@
|
|||||||
|
# Automated vLLM Server Parameter Tuning
|
||||||
|
|
||||||
|
This script automates the process of finding the optimal server parameter combination (`max-num-seqs` and `max-num-batched-tokens`) to maximize throughput for a vLLM server. It also supports additional constraints such as E2E latency and prefix cache hit rate.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Prerequisites](#prerequisites)
|
||||||
|
- [Configuration](#configuration)
|
||||||
|
- [How to Run](#how-to-run)
|
||||||
|
- [Example Use Cases](#example-use-cases)
|
||||||
|
- [Output](#output)
|
||||||
|
- [How It Works](#how-it-works)
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
Before running the script, please ensure the following steps are completed:
|
||||||
|
|
||||||
|
1. **Clone vLLM & Set Up Branch**: Clone the vLLM repository and check out to your desired branch.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/vllm-project/vllm.git
|
||||||
|
cd vllm
|
||||||
|
# git checkout <your-branch>
|
||||||
|
```
|
||||||
|
|
||||||
|
1. **Install Environment**: Install or update the correct running environment. For TPU usage, activate your `conda` environment and install the corresponding `torch` and `torch_xla` versions.
|
||||||
|
|
||||||
|
2. **Model Configuration**: If you are using a customized model, ensure its configuration files are correctly placed and accessible.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
You must set the following variables at the top of the script before execution.
|
||||||
|
|
||||||
|
Note: You can also override the default values below via environment variables when running the script.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
MODEL=meta-llama/Llama-3.3-70B-Instruct SYSTEM=TPU TP=8 DOWNLOAD_DIR='' INPUT_LEN=128 OUTPUT_LEN=2048 MAX_MODEL_LEN=2300 MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 NUM_SEQS_LIST="128 256" NUM_BATCHED_TOKENS_LIST="1024 2048 4096" VLLM_LOGGING_LEVEL=DEBUG bash auto_tune.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
| Variable | Description | Example Value |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` |
|
||||||
|
| `MODEL` | **Required.** The Hugging Face model identifier to be served by vllm. | `"meta-llama/Llama-3.1-8B-Instruct"` |
|
||||||
|
| `SYSTEM`| **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (For other systems, it might not support saving profiles) | `"TPU"` |
|
||||||
|
| `TP` | **Required.** The tensor-parallelism size. | `1` |
|
||||||
|
| `DOWNLOAD_DIR` | **Required.** Directory to download and load model weights from. | `""` (default download path) |
|
||||||
|
| `INPUT_LEN` | **Required.** Request input length. | `4000` |
|
||||||
|
| `OUTPUT_LEN` | **Required.** Request output length. | `16` |
|
||||||
|
| `MAX_MODEL_LEN` | **Required.** Max model length. | `4096` |
|
||||||
|
| `MIN_CACHE_HIT_PCT` | Prefix cache hit rate in percentage (0-100). Set to `0` to disable. | `60` |
|
||||||
|
| `MAX_LATENCY_ALLOWED_MS` | The maximum allowed P99 end-to-end latency in milliseconds. Set to a very large number (e.g., `100000000000`) to effectively ignore the latency constraint. | `500` |
|
||||||
|
| `NUM_SEQS_LIST` | A space-separated string of `max-num-seqs` values to test. | `"128 256"` |
|
||||||
|
| `NUM_BATCHED_TOKENS_LIST` | A space-separated string of `max-num-batched-tokens` values to test. | `"1024 2048 4096"` |
|
||||||
|
|
||||||
|
**Note**: The default `NUM_SEQS_LIST` and `NUM_BATCHED_TOKENS_LIST` are set for medium-sized inputs/outputs. For very short contexts (e.g., 20 input, 20 output tokens), you may need to test larger values for `max-num-seqs`.
|
||||||
|
|
||||||
|
## How to Run
|
||||||
|
|
||||||
|
1. **Configure**: Edit the script and set the variables in the [Configuration](#configuration) section.
|
||||||
|
2. **Execute**: Run the script. Since the process can take a long time, it is highly recommended to use a terminal multiplexer like `tmux` or `screen` to prevent the script from stopping if your connection is lost.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd <FOLDER_OF_THIS_SCRIPT>
|
||||||
|
bash auto_tune.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Please note that the `bash auto_tune.sh` command cannot contain full or partial path with keyword `vllm`, otherwise `pkill -f vllm` command will also kill this script itself.
|
||||||
|
|
||||||
|
## Example Use Cases
|
||||||
|
|
||||||
|
Here are a few examples of how to configure the script for different goals:
|
||||||
|
|
||||||
|
### 1. Maximize Throughput (No Latency Constraint)
|
||||||
|
|
||||||
|
- **Goal**: Find the best `max-num-seqs` and `max-num-batched-tokens` to get the highest possible throughput for 1800 input tokens and 20 output tokens.
|
||||||
|
- **Configuration**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INPUT_LEN=1800
|
||||||
|
OUTPUT_LEN=20
|
||||||
|
MAX_MODEL_LEN=2048
|
||||||
|
MIN_CACHE_HIT_PCT=0
|
||||||
|
MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Maximize Throughput with a Latency Requirement
|
||||||
|
|
||||||
|
- **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms.
|
||||||
|
- **Configuration**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INPUT_LEN=1800
|
||||||
|
OUTPUT_LEN=20
|
||||||
|
MAX_MODEL_LEN=2048
|
||||||
|
MIN_CACHE_HIT_PCT=0
|
||||||
|
MAX_LATENCY_ALLOWED_MS=500
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Maximize Throughput with Prefix Caching and Latency Requirements
|
||||||
|
|
||||||
|
- **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
|
||||||
|
- **Configuration**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INPUT_LEN=1800
|
||||||
|
OUTPUT_LEN=20
|
||||||
|
MAX_MODEL_LEN=2048
|
||||||
|
MIN_CACHE_HIT_PCT=60
|
||||||
|
MAX_LATENCY_ALLOWED_MS=500
|
||||||
|
```
|
||||||
|
|
||||||
|
## Output
|
||||||
|
|
||||||
|
After the script finishes, you will find the results in a new, timestamped directory created inside `$BASE/auto-benchmark/`.
|
||||||
|
|
||||||
|
- **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run:
|
||||||
|
- `vllm_log_...txt`: The log output from the vLLM server for each parameter combination.
|
||||||
|
- `bm_log_...txt`: The log output from the `vllm bench serve` command for each benchmark run.
|
||||||
|
|
||||||
|
- **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found.
|
||||||
|
|
||||||
|
```text
|
||||||
|
# Example result.txt content
|
||||||
|
hash:a1b2c3d4...
|
||||||
|
max_num_seqs: 128, max_num_batched_tokens: 2048, request_rate: 10.0, e2el: 450.5, throughput: 9.8, goodput: 9.8
|
||||||
|
max_num_seqs: 128, max_num_batched_tokens: 4096 does not meet latency requirement 500
|
||||||
|
...
|
||||||
|
best_max_num_seqs: 256, best_num_batched_tokens: 2048, best_throughput: 12.5, profile saved in: /home/user/vllm/auto-benchmark/2024_08_01_10_30/profile
|
||||||
|
```
|
||||||
|
|
||||||
|
If it cannot find the best parameters, the final row will be `best_max_num_seqs: 0, best_num_batched_tokens: 0, best_throughput: 0`. This can be due to either the server not starting properly, or the latency requirement being too strict.
|
||||||
|
|
||||||
|
- **Profiler Trace**: A directory named `profile` is created inside the log directory. It contains the profiler trace file (e.g., `.xplane.pb` for TPU or a `.json` trace for GPU) from the single best-performing run.
|
||||||
|
|
||||||
|
## How It Works
|
||||||
|
|
||||||
|
The script follows a systematic process to find the optimal parameters:
|
||||||
|
|
||||||
|
1. **Find Max GPU Memory Utilization**: The script first determines the highest safe `gpu-memory-utilization` (starting from 0.98 and decreasing) that does not cause an Out-Of-Memory (OOM) error when launching the server. This ensures the benchmark runs use the maximum available memory without crashing.
|
||||||
|
|
||||||
|
2. **Iterate and Benchmark**: It then enters a nested loop, iterating through every combination of `max-num-seqs` and `max-num-batched-tokens` provided in the configuration lists.
|
||||||
|
|
||||||
|
3. **Latency-Aware Throughput Search**: For each parameter combination:
|
||||||
|
- The vLLM server is started.
|
||||||
|
- A benchmark is first run with an infinite request rate (`--request-rate inf`).
|
||||||
|
- If the resulting P99 E2E latency is within the `MAX_LATENCY_ALLOWED_MS` limit, this throughput is considered the maximum for this configuration.
|
||||||
|
- If the latency is too high, the script performs a search by iteratively decreasing the request rate until the latency constraint is met. This finds the highest sustainable throughput for the given parameters and latency requirement.
|
||||||
|
|
||||||
|
4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far.
|
||||||
|
|
||||||
|
5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard.
|
||||||
|
|
||||||
|
## Batched `auto_tune`
|
||||||
|
|
||||||
|
The `batch_auto_tune.sh` script allows you to run multiple `auto_tune.sh` experiments sequentially from a single configuration file. It iterates through a list of parameter sets, executes `auto_tune.sh` for each, and records the results back into the input file.
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- **jq**: This script requires `jq` to parse the JSON configuration file.
|
||||||
|
- **gcloud**: If you plan to upload results to Google Cloud Storage, the `gcloud` CLI must be installed and authenticated.
|
||||||
|
|
||||||
|
### How to Run
|
||||||
|
|
||||||
|
1. **Create a JSON configuration file**: Create a file (e.g., `runs_config.json`) containing an array of JSON objects. Each object defines the parameters for a single `auto_tune.sh` run.
|
||||||
|
|
||||||
|
2. **Execute the script**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash batch_auto_tune.sh <path_to_json_file> [gcs_upload_path]
|
||||||
|
```
|
||||||
|
|
||||||
|
- `<path_to_json_file>`: **Required.** Path to your JSON configuration file.
|
||||||
|
- `[gcs_upload_path]`: **Optional.** A GCS path (e.g., `gs://my-bucket/benchmark-results`) where the detailed results and profiles for each run will be uploaded. If this is empty, the results will be available on the local filesystem (see the log for `RESULT_FILE=/path/to/results/file.txt`).
|
||||||
|
|
||||||
|
### Configuration File
|
||||||
|
|
||||||
|
The JSON configuration file should contain an array of objects. Each object's keys correspond to the configuration variables for `auto_tune.sh` (see the [Configuration table above](#configuration)). These keys will be converted to uppercase environment variables for each run.
|
||||||
|
|
||||||
|
Here is an example `runs_config.json` with two benchmark configurations:
|
||||||
|
|
||||||
|
```json
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"base": "/home/user",
|
||||||
|
"model": "meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
"system": "TPU", # OR GPU
|
||||||
|
"tp": 8,
|
||||||
|
"input_len": 128,
|
||||||
|
"output_len": 2048,
|
||||||
|
"max_model_len": 2300,
|
||||||
|
"num_seqs_list": "128 256",
|
||||||
|
"num_batched_tokens_list": "8192 16384"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"base": "/home/user",
|
||||||
|
"model": "meta-llama/Llama-3.1-70B-Instruct",
|
||||||
|
"system": "TPU", # OR GPU
|
||||||
|
"tp": 8,
|
||||||
|
"input_len": 4000,
|
||||||
|
"output_len": 16,
|
||||||
|
"max_model_len": 4096,
|
||||||
|
"num_seqs_list": "64 128",
|
||||||
|
"num_batched_tokens_list": "4096 8192",
|
||||||
|
"max_latency_allowed_ms": 500
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
### Output
|
||||||
|
|
||||||
|
The script modifies the input JSON file in place, adding the results of each run to the corresponding object. The following fields are added:
|
||||||
|
|
||||||
|
- `run_id`: A unique identifier for the run, derived from the timestamp.
|
||||||
|
- `status`: The outcome of the run (`SUCCESS`, `FAILURE`, or `WARNING_NO_RESULT_FILE`).
|
||||||
|
- `results`: The content of the `result.txt` file from the `auto_tune.sh` run.
|
||||||
|
- `gcs_results`: The GCS URL where the run's artifacts are stored (if a GCS path was provided).
|
||||||
|
|
||||||
|
A summary of successful and failed runs is also printed to the console upon completion.
|
||||||
323
benchmarks/auto_tune/auto_tune.sh
Normal file
323
benchmarks/auto_tune/auto_tune.sh
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# This script aims to tune the best server parameter combinations to maximize throughput for given requirement.
|
||||||
|
# See details in README (benchmarks/auto_tune/README.md).
|
||||||
|
|
||||||
|
TAG=$(date +"%Y_%m_%d_%H_%M")
|
||||||
|
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||||
|
VLLM_LOGGING_LEVEL=${VLLM_LOGGING_LEVEL:-INFO}
|
||||||
|
BASE=${BASE:-"$SCRIPT_DIR/../../.."}
|
||||||
|
MODEL=${MODEL:-"meta-llama/Llama-3.1-8B-Instruct"}
|
||||||
|
SYSTEM=${SYSTEM:-"TPU"}
|
||||||
|
TP=${TP:-1}
|
||||||
|
DOWNLOAD_DIR=${DOWNLOAD_DIR:-""}
|
||||||
|
INPUT_LEN=${INPUT_LEN:-4000}
|
||||||
|
OUTPUT_LEN=${OUTPUT_LEN:-16}
|
||||||
|
MAX_MODEL_LEN=${MAX_MODEL_LEN:-4096}
|
||||||
|
MIN_CACHE_HIT_PCT=${MIN_CACHE_HIT_PCT:-0}
|
||||||
|
MAX_LATENCY_ALLOWED_MS=${MAX_LATENCY_ALLOWED_MS:-100000000000}
|
||||||
|
NUM_SEQS_LIST=${NUM_SEQS_LIST:-"128 256"}
|
||||||
|
NUM_BATCHED_TOKENS_LIST=${NUM_BATCHED_TOKENS_LIST:-"512 1024 2048 4096"}
|
||||||
|
HOSTNAME=$(hostname)
|
||||||
|
if [[ -z "$HOSTNAME" ]]; then
|
||||||
|
echo "Error: Failed to determine hostname." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
LOG_FOLDER="$BASE/auto-benchmark/$TAG"
|
||||||
|
RESULT="$LOG_FOLDER/result.txt"
|
||||||
|
PROFILE_PATH="$LOG_FOLDER/profile"
|
||||||
|
|
||||||
|
echo "====================== AUTO TUNE PARAMETERS ===================="
|
||||||
|
echo "SCRIPT_DIR=$SCRIPT_DIR"
|
||||||
|
echo "BASE=$BASE"
|
||||||
|
echo "MODEL=$MODEL"
|
||||||
|
echo "SYSTEM=$SYSTEM"
|
||||||
|
echo "TP=$TP"
|
||||||
|
echo "DOWNLOAD_DIR=$DOWNLOAD_DIR"
|
||||||
|
echo "INPUT_LEN=$INPUT_LEN"
|
||||||
|
echo "OUTPUT_LEN=$OUTPUT_LEN"
|
||||||
|
echo "MAX_MODEL_LEN=$MAX_MODEL_LEN"
|
||||||
|
echo "MIN_CACHE_HIT_PCT=$MIN_CACHE_HIT_PCT"
|
||||||
|
echo "MAX_LATENCY_ALLOWED_MS=$MAX_LATENCY_ALLOWED_MS"
|
||||||
|
echo "NUM_SEQS_LIST=$NUM_SEQS_LIST"
|
||||||
|
echo "NUM_BATCHED_TOKENS_LIST=$NUM_BATCHED_TOKENS_LIST"
|
||||||
|
echo "VLLM_LOGGING_LEVEL=$VLLM_LOGGING_LEVEL"
|
||||||
|
echo "RESULT_FILE=$RESULT"
|
||||||
|
echo "====================== AUTO TUNEPARAMETERS ===================="
|
||||||
|
|
||||||
|
rm -rf $LOG_FOLDER
|
||||||
|
rm -rf $PROFILE_PATH
|
||||||
|
mkdir -p $LOG_FOLDER
|
||||||
|
mkdir -p $PROFILE_PATH
|
||||||
|
|
||||||
|
cd "$BASE/vllm"
|
||||||
|
|
||||||
|
pip install -q datasets
|
||||||
|
|
||||||
|
current_hash=$(git rev-parse HEAD)
|
||||||
|
echo "hash:$current_hash" >> "$RESULT"
|
||||||
|
echo "current_hash: $current_hash"
|
||||||
|
|
||||||
|
TOTAL_LEN=$((INPUT_LEN + OUTPUT_LEN))
|
||||||
|
RED='\033[0;31m'
|
||||||
|
if (( TOTAL_LEN > MAX_MODEL_LEN )); then
|
||||||
|
echo -e "${RED}FAILED: INPUT_LEN($INPUT_LEN) + OUTPUT_LEN($OUTPUT_LEN) = $TOTAL_LEN, which is > MAX_MODEL_LEN = $MAX_MODEL_LEN.\033[0m" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
best_throughput=0
|
||||||
|
best_max_num_seqs=0
|
||||||
|
best_num_batched_tokens=0
|
||||||
|
best_goodput=0
|
||||||
|
best_request_rate=0
|
||||||
|
|
||||||
|
start_server() {
|
||||||
|
local gpu_memory_utilization=$1
|
||||||
|
local max_num_seqs=$2
|
||||||
|
local max_num_batched_tokens=$3
|
||||||
|
local vllm_log=$4
|
||||||
|
local profile_dir=$5
|
||||||
|
|
||||||
|
pkill -if "vllm serve" || true
|
||||||
|
|
||||||
|
# Define the common arguments as a bash array.
|
||||||
|
# Each argument and its value are separate elements.
|
||||||
|
local common_args_array=(
|
||||||
|
"$MODEL"
|
||||||
|
"--disable-log-requests"
|
||||||
|
"--port" "8004"
|
||||||
|
"--host" "$HOSTNAME"
|
||||||
|
"--gpu-memory-utilization" "$gpu_memory_utilization"
|
||||||
|
"--max-num-seqs" "$max_num_seqs"
|
||||||
|
"--max-num-batched-tokens" "$max_num_batched_tokens"
|
||||||
|
"--tensor-parallel-size" "$TP"
|
||||||
|
"--enable-prefix-caching"
|
||||||
|
"--load-format" "dummy"
|
||||||
|
"--download-dir" "$DOWNLOAD_DIR"
|
||||||
|
"--max-model-len" "$MAX_MODEL_LEN"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use the array expansion "${common_args_array[@]}"
|
||||||
|
# This correctly passes each element as a separate argument.
|
||||||
|
if [[ -n "$profile_dir" ]]; then
|
||||||
|
# Start server with profiling enabled
|
||||||
|
local profile_config_json="{\"profiler\": \"torch\", \"torch_profiler_dir\": \"$profile_dir\"}"
|
||||||
|
VLLM_SERVER_DEV_MODE=1 \
|
||||||
|
vllm serve --profiler-config "$profile_config_json" "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||||
|
else
|
||||||
|
# Start server without profiling
|
||||||
|
VLLM_SERVER_DEV_MODE=1 \
|
||||||
|
vllm serve "${common_args_array[@]}" > "$vllm_log" 2>&1 &
|
||||||
|
fi
|
||||||
|
local server_pid=$!
|
||||||
|
|
||||||
|
# wait for 10 minutes...
|
||||||
|
server_started=0
|
||||||
|
for i in {1..60}; do
|
||||||
|
# This line checks whether the server is still alive or not,
|
||||||
|
# since that we should always have permission to send signal to the server process.
|
||||||
|
kill -0 $server_pid 2> /dev/null || break
|
||||||
|
|
||||||
|
RESPONSE=$(curl -s -X GET "http://${HOSTNAME}:8004/health" -w "%{http_code}" -o /dev/stdout)
|
||||||
|
STATUS_CODE=$(echo "$RESPONSE" | tail -n 1)
|
||||||
|
if [[ "$STATUS_CODE" -eq 200 ]]; then
|
||||||
|
server_started=1
|
||||||
|
break
|
||||||
|
else
|
||||||
|
sleep 10
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if (( ! server_started )); then
|
||||||
|
echo "server did not start within 10 minutes or crashed. Please check server log at $vllm_log".
|
||||||
|
return 1
|
||||||
|
else
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
run_benchmark() {
|
||||||
|
local max_num_seqs=$1
|
||||||
|
local max_num_batched_tokens=$2
|
||||||
|
local gpu_memory_utilization=$3
|
||||||
|
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||||
|
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
|
||||||
|
echo "vllm_log: $vllm_log"
|
||||||
|
echo
|
||||||
|
rm -f $vllm_log
|
||||||
|
pkill -if "vllm serve" || true
|
||||||
|
|
||||||
|
echo "starting server..."
|
||||||
|
# Call start_server without a profile_dir to avoid profiling overhead
|
||||||
|
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log ""
|
||||||
|
result=$?
|
||||||
|
if [[ "$result" -eq 1 ]]; then
|
||||||
|
echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||||
|
else
|
||||||
|
echo "server started."
|
||||||
|
fi
|
||||||
|
echo
|
||||||
|
|
||||||
|
echo "run benchmark test..."
|
||||||
|
meet_latency_requirement=0
|
||||||
|
# get a basic qps by using request-rate inf
|
||||||
|
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
|
||||||
|
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||||
|
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||||
|
# --profile flag is removed from this call
|
||||||
|
vllm bench serve \
|
||||||
|
--backend vllm \
|
||||||
|
--model $MODEL \
|
||||||
|
--dataset-name random \
|
||||||
|
--random-input-len $adjusted_input_len \
|
||||||
|
--random-output-len $OUTPUT_LEN \
|
||||||
|
--ignore-eos \
|
||||||
|
--disable-tqdm \
|
||||||
|
--request-rate inf \
|
||||||
|
--percentile-metrics ttft,tpot,itl,e2el \
|
||||||
|
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||||
|
--num-prompts 1000 \
|
||||||
|
--random-prefix-len $prefix_len \
|
||||||
|
--host "$HOSTNAME" \
|
||||||
|
--port 8004 &> "$bm_log"
|
||||||
|
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||||
|
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||||
|
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||||
|
|
||||||
|
if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then
|
||||||
|
meet_latency_requirement=1
|
||||||
|
request_rate=inf
|
||||||
|
fi
|
||||||
|
|
||||||
|
if (( ! meet_latency_requirement )); then
|
||||||
|
# start from request-rate as int(throughput) + 1
|
||||||
|
request_rate=$((${throughput%.*} + 1))
|
||||||
|
while ((request_rate > 0)); do
|
||||||
|
# clear prefix cache
|
||||||
|
curl -X POST http://${HOSTNAME}:8004/reset_prefix_cache
|
||||||
|
sleep 5
|
||||||
|
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
|
||||||
|
vllm bench serve \
|
||||||
|
--backend vllm \
|
||||||
|
--model $MODEL \
|
||||||
|
--dataset-name random \
|
||||||
|
--random-input-len $adjusted_input_len \
|
||||||
|
--random-output-len $OUTPUT_LEN \
|
||||||
|
--ignore-eos \
|
||||||
|
--disable-tqdm \
|
||||||
|
--request-rate $request_rate \
|
||||||
|
--percentile-metrics ttft,tpot,itl,e2el \
|
||||||
|
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||||
|
--num-prompts 100 \
|
||||||
|
--random-prefix-len $prefix_len \
|
||||||
|
--host "$HOSTNAME" \
|
||||||
|
--port 8004 &> "$bm_log"
|
||||||
|
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||||
|
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||||
|
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||||
|
if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then
|
||||||
|
meet_latency_requirement=1
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
request_rate=$((request_rate-1))
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
# write the results and update the best result.
|
||||||
|
if ((meet_latency_requirement)); then
|
||||||
|
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, throughput: $throughput, goodput: $goodput"
|
||||||
|
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, throughput: $throughput, goodput: $goodput" >> "$RESULT"
|
||||||
|
if (( $(echo "$throughput > $best_throughput" | bc -l) )); then
|
||||||
|
best_throughput=$throughput
|
||||||
|
best_max_num_seqs=$max_num_seqs
|
||||||
|
best_num_batched_tokens=$max_num_batched_tokens
|
||||||
|
best_goodput=$goodput
|
||||||
|
best_request_rate=$request_rate
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
|
||||||
|
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" >> "$RESULT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
|
||||||
|
|
||||||
|
pkill -if "vllm serve" || true
|
||||||
|
sleep 10
|
||||||
|
echo "===================="
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
read -r -a num_seqs_list <<< "$NUM_SEQS_LIST"
|
||||||
|
read -r -a num_batched_tokens_list <<< "$NUM_BATCHED_TOKENS_LIST"
|
||||||
|
|
||||||
|
# first find out the max gpu-memory-utilization without HBM OOM.
|
||||||
|
gpu_memory_utilization=0.98
|
||||||
|
find_gpu_memory_utilization=0
|
||||||
|
while (( $(echo "$gpu_memory_utilization >= 0.9" | bc -l) )); do
|
||||||
|
# Pass empty string for profile_dir argument
|
||||||
|
start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" ""
|
||||||
|
result=$?
|
||||||
|
if [[ "$result" -eq 0 ]]; then
|
||||||
|
find_gpu_memory_utilization=1
|
||||||
|
break
|
||||||
|
else
|
||||||
|
gpu_memory_utilization=$(echo "$gpu_memory_utilization - 0.01" | bc)
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ "$find_gpu_memory_utilization" -eq 1 ]]; then
|
||||||
|
echo "Using gpu_memory_utilization=$gpu_memory_utilization to serve model."
|
||||||
|
else
|
||||||
|
echo "Cannot find a proper gpu_memory_utilization over 0.9 to serve the model, please check logs in $LOG_FOLDER."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
for num_seqs in "${num_seqs_list[@]}"; do
|
||||||
|
for num_batched_tokens in "${num_batched_tokens_list[@]}"; do
|
||||||
|
run_benchmark $num_seqs $num_batched_tokens $gpu_memory_utilization
|
||||||
|
done
|
||||||
|
done
|
||||||
|
echo "finish permutations"
|
||||||
|
|
||||||
|
# =================================================================================
|
||||||
|
# FINAL PROFILING RUN FOR THE BEST CONFIGURATION
|
||||||
|
# =================================================================================
|
||||||
|
if (( $(echo "$best_throughput > 0" | bc -l) )); then
|
||||||
|
echo
|
||||||
|
echo "Benchmark tuning finished. Now running profiling on the best configuration found..."
|
||||||
|
echo "Best config: max_num_seqs: $best_max_num_seqs, max_num_batched_tokens: $best_num_batched_tokens, throughput: $best_throughput"
|
||||||
|
echo
|
||||||
|
|
||||||
|
vllm_log="$LOG_FOLDER/vllm_log_BEST_PROFILE.txt"
|
||||||
|
bm_log="$LOG_FOLDER/bm_log_BEST_PROFILE.txt"
|
||||||
|
|
||||||
|
# Start server with the best params and profiling ENABLED
|
||||||
|
echo "Starting server for profiling..."
|
||||||
|
start_server $gpu_memory_utilization $best_max_num_seqs $best_num_batched_tokens "$vllm_log" "$PROFILE_PATH"
|
||||||
|
|
||||||
|
# Run benchmark with the best params and the --profile flag
|
||||||
|
echo "Running benchmark with profiling..."
|
||||||
|
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||||
|
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||||
|
vllm bench serve \
|
||||||
|
--backend vllm \
|
||||||
|
--model $MODEL \
|
||||||
|
--dataset-name random \
|
||||||
|
--random-input-len $adjusted_input_len \
|
||||||
|
--random-output-len $OUTPUT_LEN \
|
||||||
|
--ignore-eos \
|
||||||
|
--disable-tqdm \
|
||||||
|
--request-rate $best_request_rate \
|
||||||
|
--percentile-metrics ttft,tpot,itl,e2el \
|
||||||
|
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||||
|
--num-prompts 100 \
|
||||||
|
--random-prefix-len $prefix_len \
|
||||||
|
--host "$HOSTNAME" \
|
||||||
|
--port 8004 \
|
||||||
|
--profile &> "$bm_log"
|
||||||
|
else
|
||||||
|
echo "No configuration met the latency requirements. Skipping final profiling run."
|
||||||
|
fi
|
||||||
|
pkill -if "vllm serve" || true
|
||||||
|
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
|
||||||
|
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
|
||||||
128
benchmarks/auto_tune/batch_auto_tune.sh
Executable file
128
benchmarks/auto_tune/batch_auto_tune.sh
Executable file
@@ -0,0 +1,128 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
INPUT_JSON="$1"
|
||||||
|
GCS_PATH="$2" # Optional GCS path for uploading results for each run
|
||||||
|
|
||||||
|
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
|
||||||
|
AUTOTUNE_SCRIPT="$SCRIPT_DIR/auto_tune.sh"
|
||||||
|
|
||||||
|
if [[ -z "$INPUT_JSON" ]]; then
|
||||||
|
echo "Error: Input JSON file not provided."
|
||||||
|
echo "Usage: $0 <path_to_json_file> [gcs_upload_path]"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ ! -f "$INPUT_JSON" ]]; then
|
||||||
|
echo "Error: File not found at '$INPUT_JSON'"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! command -v jq &> /dev/null; then
|
||||||
|
echo "Error: 'jq' command not found. Please install jq to process the JSON input."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -n "$GCS_PATH" ]] && ! command -v gcloud &> /dev/null; then
|
||||||
|
echo "Error: 'gcloud' command not found, but a GCS_PATH was provided."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
SUCCESS_COUNT=0
|
||||||
|
FAILURE_COUNT=0
|
||||||
|
FAILED_RUNS=()
|
||||||
|
SCRIPT_START_TIME=$(date +%s)
|
||||||
|
|
||||||
|
json_content=$(cat "$INPUT_JSON")
|
||||||
|
if ! num_runs=$(echo "$json_content" | jq 'length'); then
|
||||||
|
echo "Error: Invalid JSON in $INPUT_JSON. 'jq' failed to get array length." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Found $num_runs benchmark configurations in $INPUT_JSON."
|
||||||
|
echo "Starting benchmark runs..."
|
||||||
|
echo "--------------------------------------------------"
|
||||||
|
|
||||||
|
for i in $(seq 0 $(($num_runs - 1))); do
|
||||||
|
run_object=$(echo "$json_content" | jq ".[$i]")
|
||||||
|
|
||||||
|
RUN_START_TIME=$(date +%s)
|
||||||
|
ENV_VARS_ARRAY=()
|
||||||
|
# Dynamically create env vars from the JSON object's keys
|
||||||
|
for key in $(echo "$run_object" | jq -r 'keys_unsorted[]'); do
|
||||||
|
value=$(echo "$run_object" | jq -r ".$key")
|
||||||
|
var_name=$(echo "$key" | tr '[:lower:]' '[:upper:]' | tr -cd 'A-Z0-9_')
|
||||||
|
ENV_VARS_ARRAY+=("${var_name}=${value}")
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "Executing run #$((i+1))/$num_runs with parameters: ${ENV_VARS_ARRAY[*]}"
|
||||||
|
|
||||||
|
# Execute auto_tune.sh and capture output
|
||||||
|
RUN_OUTPUT_FILE=$(mktemp)
|
||||||
|
if env "${ENV_VARS_ARRAY[@]}" bash "$AUTOTUNE_SCRIPT" > >(tee -a "$RUN_OUTPUT_FILE") 2>&1; then
|
||||||
|
STATUS="SUCCESS"
|
||||||
|
((SUCCESS_COUNT++))
|
||||||
|
else
|
||||||
|
STATUS="FAILURE"
|
||||||
|
((FAILURE_COUNT++))
|
||||||
|
FAILED_RUNS+=("Run #$((i+1)): $(echo $run_object | jq -c .)")
|
||||||
|
fi
|
||||||
|
|
||||||
|
RUN_OUTPUT=$(<"$RUN_OUTPUT_FILE")
|
||||||
|
rm "$RUN_OUTPUT_FILE"
|
||||||
|
|
||||||
|
# Parse results and optionally upload them to GCS
|
||||||
|
RUN_ID=""
|
||||||
|
RESULTS=""
|
||||||
|
GCS_RESULTS_URL=""
|
||||||
|
if [[ "$STATUS" == "SUCCESS" ]]; then
|
||||||
|
RESULT_FILE_PATH=$(echo "$RUN_OUTPUT" | grep 'RESULT_FILE=' | tail -n 1 | cut -d'=' -f2 | tr -s '/' || true)
|
||||||
|
|
||||||
|
if [[ -n "$RESULT_FILE_PATH" && -f "$RESULT_FILE_PATH" ]]; then
|
||||||
|
RUN_ID=$(basename "$(dirname "$RESULT_FILE_PATH")")
|
||||||
|
RESULT_DIR=$(dirname "$RESULT_FILE_PATH")
|
||||||
|
RESULTS=$(cat "$RESULT_FILE_PATH")
|
||||||
|
|
||||||
|
if [[ -n "$GCS_PATH" ]]; then
|
||||||
|
GCS_RESULTS_URL="${GCS_PATH}/${RUN_ID}"
|
||||||
|
echo "Uploading results to GCS..."
|
||||||
|
if gcloud storage rsync --recursive "$RESULT_DIR/" "$GCS_RESULTS_URL"; then
|
||||||
|
echo "GCS upload successful."
|
||||||
|
else
|
||||||
|
echo "Warning: GCS upload failed for RUN_ID $RUN_ID."
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Warning: Could not find result file for a successful run."
|
||||||
|
STATUS="WARNING_NO_RESULT_FILE"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Add the results back into the JSON object for this run
|
||||||
|
json_content=$(echo "$json_content" | jq --argjson i "$i" --arg run_id "$RUN_ID" --arg status "$STATUS" --arg results "$RESULTS" --arg gcs_results "$GCS_RESULTS_URL" \
|
||||||
|
'.[$i] += {run_id: $run_id, status: $status, results: $results, gcs_results: $gcs_results}')
|
||||||
|
|
||||||
|
RUN_END_TIME=$(date +%s)
|
||||||
|
echo "Run finished in $((RUN_END_TIME - RUN_START_TIME)) seconds. Status: $STATUS"
|
||||||
|
echo "--------------------------------------------------"
|
||||||
|
|
||||||
|
# Save intermediate progress back to the file
|
||||||
|
echo "$json_content" > "$INPUT_JSON.tmp" && mv "$INPUT_JSON.tmp" "$INPUT_JSON"
|
||||||
|
|
||||||
|
done
|
||||||
|
|
||||||
|
SCRIPT_END_TIME=$(date +%s)
|
||||||
|
echo "All benchmark runs completed in $((SCRIPT_END_TIME - SCRIPT_START_TIME)) seconds."
|
||||||
|
echo
|
||||||
|
echo "====================== SUMMARY ======================"
|
||||||
|
echo "Successful runs: $SUCCESS_COUNT"
|
||||||
|
echo "Failed runs: $FAILURE_COUNT"
|
||||||
|
echo "==================================================="
|
||||||
|
|
||||||
|
if [[ $FAILURE_COUNT -gt 0 ]]; then
|
||||||
|
echo "Details of failed runs (see JSON file for full parameters):"
|
||||||
|
for failed in "${FAILED_RUNS[@]}"; do
|
||||||
|
echo " - $failed"
|
||||||
|
done
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Updated results have been saved to '$INPUT_JSON'."
|
||||||
@@ -1,13 +1,21 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import io
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import huggingface_hub.constants
|
||||||
from tqdm.asyncio import tqdm
|
from tqdm.asyncio import tqdm
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||||
|
|
||||||
|
# NOTE(simon): do not import vLLM here so the benchmark script
|
||||||
|
# can run without vLLM installed.
|
||||||
|
|
||||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
|
||||||
|
|
||||||
@@ -19,8 +27,13 @@ class RequestFuncInput:
|
|||||||
prompt_len: int
|
prompt_len: int
|
||||||
output_len: int
|
output_len: int
|
||||||
model: str
|
model: str
|
||||||
best_of: int = 1
|
model_name: str | None = None
|
||||||
use_beam_search: bool = False
|
logprobs: int | None = None
|
||||||
|
extra_body: dict | None = None
|
||||||
|
multi_modal_content: dict | list[dict] | None = None
|
||||||
|
ignore_eos: bool = False
|
||||||
|
language: str | None = None
|
||||||
|
request_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -28,49 +41,65 @@ class RequestFuncOutput:
|
|||||||
generated_text: str = ""
|
generated_text: str = ""
|
||||||
success: bool = False
|
success: bool = False
|
||||||
latency: float = 0.0
|
latency: float = 0.0
|
||||||
|
output_tokens: int = 0
|
||||||
ttft: float = 0.0 # Time to first token
|
ttft: float = 0.0 # Time to first token
|
||||||
itl: List[float] = field(
|
itl: list[float] = field(default_factory=list) # list of inter-token latencies
|
||||||
default_factory=list) # List of inter-token latencies
|
tpot: float = 0.0 # avg next-token latencies
|
||||||
prompt_len: int = 0
|
prompt_len: int = 0
|
||||||
error: str = ""
|
error: str = ""
|
||||||
|
|
||||||
|
|
||||||
async def async_request_tgi(
|
async def async_request_tgi(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
pbar: Optional[tqdm] = None,
|
pbar: tqdm | None = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith("generate_stream")
|
assert api_url.endswith("generate_stream")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
async with aiohttp.ClientSession(
|
||||||
assert not request_func_input.use_beam_search
|
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||||
|
) as session:
|
||||||
params = {
|
params = {
|
||||||
"best_of": request_func_input.best_of,
|
|
||||||
"max_new_tokens": request_func_input.output_len,
|
"max_new_tokens": request_func_input.output_len,
|
||||||
"do_sample": True,
|
"do_sample": True,
|
||||||
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
"temperature": 0.01, # TGI does not accept 0.0 temperature.
|
||||||
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
"top_p": 0.99, # TGI does not accept 1.0 top_p.
|
||||||
|
"truncate": request_func_input.prompt_len,
|
||||||
|
"ignore_eos_token": request_func_input.ignore_eos,
|
||||||
}
|
}
|
||||||
payload = {
|
payload = {
|
||||||
"inputs": request_func_input.prompt,
|
"inputs": request_func_input.prompt,
|
||||||
"parameters": params,
|
"parameters": params,
|
||||||
}
|
}
|
||||||
|
headers = None
|
||||||
|
if request_func_input.request_id:
|
||||||
|
headers = {"x-request-id": request_func_input.request_id}
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
if request_func_input.ignore_eos:
|
||||||
|
output.output_tokens = request_func_input.output_len
|
||||||
|
else:
|
||||||
|
output.output_tokens = None
|
||||||
|
|
||||||
ttft = 0.0
|
ttft = 0.0
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload) as response:
|
async with session.post(
|
||||||
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||||
|
|
||||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
# NOTE: Sometimes TGI returns a ping response without
|
||||||
"data:")
|
# any data, we should skip it.
|
||||||
|
if chunk_bytes.startswith(":"):
|
||||||
|
continue
|
||||||
|
chunk = chunk_bytes.removeprefix("data:")
|
||||||
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
@@ -81,14 +110,16 @@ async def async_request_tgi(
|
|||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp -
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
most_recent_timestamp)
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
output.latency = most_recent_timestamp - st
|
output.latency = most_recent_timestamp - st
|
||||||
output.success = True
|
output.success = True
|
||||||
output.generated_text = data["generated_text"]
|
output.generated_text = data["generated_text"]
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
except Exception:
|
except Exception:
|
||||||
output.success = False
|
output.success = False
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
@@ -101,14 +132,14 @@ async def async_request_tgi(
|
|||||||
|
|
||||||
async def async_request_trt_llm(
|
async def async_request_trt_llm(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
pbar: Optional[tqdm] = None,
|
pbar: tqdm | None = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith("generate_stream")
|
assert api_url.endswith("generate_stream")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
async with aiohttp.ClientSession(
|
||||||
assert not request_func_input.use_beam_search
|
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||||
assert request_func_input.best_of == 1
|
) as session:
|
||||||
payload = {
|
payload = {
|
||||||
"accumulate_tokens": True,
|
"accumulate_tokens": True,
|
||||||
"text_input": request_func_input.prompt,
|
"text_input": request_func_input.prompt,
|
||||||
@@ -117,6 +148,11 @@ async def async_request_trt_llm(
|
|||||||
"max_tokens": request_func_input.output_len,
|
"max_tokens": request_func_input.output_len,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
|
if request_func_input.ignore_eos:
|
||||||
|
payload["min_length"] = request_func_input.output_len
|
||||||
|
headers = None
|
||||||
|
if request_func_input.request_id:
|
||||||
|
headers = {"x-request-id": request_func_input.request_id}
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
@@ -124,28 +160,28 @@ async def async_request_trt_llm(
|
|||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload) as response:
|
async with session.post(
|
||||||
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
chunk = chunk_bytes.decode("utf-8").removeprefix("data:")
|
||||||
"data:")
|
|
||||||
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
output.generated_text += data["text_output"]
|
output.generated_text += data["text_output"]
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0.0:
|
if ttft == 0.0:
|
||||||
ttft = time.perf_counter() - st
|
ttft = timestamp - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp -
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
most_recent_timestamp)
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
@@ -167,18 +203,27 @@ async def async_request_trt_llm(
|
|||||||
|
|
||||||
async def async_request_deepspeed_mii(
|
async def async_request_deepspeed_mii(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
pbar: Optional[tqdm] = None,
|
pbar: tqdm | None = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
api_url = request_func_input.api_url
|
||||||
assert request_func_input.best_of == 1
|
assert api_url.endswith(("completions", "profile")), (
|
||||||
assert not request_func_input.use_beam_search
|
"OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||||
|
) as session:
|
||||||
payload = {
|
payload = {
|
||||||
|
"model": request_func_input.model,
|
||||||
"prompt": request_func_input.prompt,
|
"prompt": request_func_input.prompt,
|
||||||
"max_tokens": request_func_input.output_len,
|
"max_tokens": request_func_input.output_len,
|
||||||
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
"temperature": 0.01, # deepspeed-mii does not accept 0.0 temp.
|
||||||
"top_p": 1.0,
|
"top_p": 1.0,
|
||||||
}
|
}
|
||||||
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
|
if request_func_input.request_id:
|
||||||
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
@@ -189,12 +234,22 @@ async def async_request_deepspeed_mii(
|
|||||||
|
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
async with session.post(url=request_func_input.api_url,
|
async with session.post(
|
||||||
json=payload) as response:
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
parsed_resp = await response.json()
|
parsed_resp = await response.json()
|
||||||
output.latency = time.perf_counter() - st
|
output.latency = time.perf_counter() - st
|
||||||
output.generated_text = parsed_resp["text"][0]
|
if "choices" in parsed_resp:
|
||||||
|
output.generated_text = parsed_resp["choices"][0]["text"]
|
||||||
|
elif "text" in parsed_resp:
|
||||||
|
output.generated_text = parsed_resp["text"][0]
|
||||||
|
else:
|
||||||
|
output.error = (
|
||||||
|
"Unexpected response format: "
|
||||||
|
"neither 'choices' nor 'text' found"
|
||||||
|
)
|
||||||
|
output.success = False
|
||||||
output.success = True
|
output.success = True
|
||||||
else:
|
else:
|
||||||
output.error = response.reason or ""
|
output.error = response.reason or ""
|
||||||
@@ -211,152 +266,91 @@ async def async_request_deepspeed_mii(
|
|||||||
|
|
||||||
async def async_request_openai_completions(
|
async def async_request_openai_completions(
|
||||||
request_func_input: RequestFuncInput,
|
request_func_input: RequestFuncInput,
|
||||||
pbar: Optional[tqdm] = None,
|
pbar: tqdm | None = None,
|
||||||
) -> RequestFuncOutput:
|
) -> RequestFuncOutput:
|
||||||
api_url = request_func_input.api_url
|
api_url = request_func_input.api_url
|
||||||
assert api_url.endswith(
|
assert api_url.endswith(("completions", "profile")), (
|
||||||
"v1/completions"
|
"OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||||
), "OpenAI Completions API URL must end with 'v1/completions'."
|
)
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
async with aiohttp.ClientSession(
|
||||||
assert not request_func_input.use_beam_search
|
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||||
|
) as session:
|
||||||
payload = {
|
payload = {
|
||||||
"model": request_func_input.model,
|
"model": request_func_input.model_name
|
||||||
|
if request_func_input.model_name
|
||||||
|
else request_func_input.model,
|
||||||
"prompt": request_func_input.prompt,
|
"prompt": request_func_input.prompt,
|
||||||
"temperature": 0.0,
|
"temperature": 0.0,
|
||||||
"best_of": request_func_input.best_of,
|
"repetition_penalty": 1.0,
|
||||||
"max_tokens": request_func_input.output_len,
|
"max_tokens": request_func_input.output_len,
|
||||||
|
"logprobs": request_func_input.logprobs,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
|
"stream_options": {
|
||||||
|
"include_usage": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
headers = {
|
if request_func_input.ignore_eos:
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
|
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||||
}
|
if request_func_input.extra_body:
|
||||||
|
payload.update(request_func_input.extra_body)
|
||||||
|
headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
|
||||||
|
if request_func_input.request_id:
|
||||||
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
output = RequestFuncOutput()
|
||||||
output.prompt_len = request_func_input.prompt_len
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
generated_text = ""
|
generated_text = ""
|
||||||
ttft = 0.0
|
|
||||||
st = time.perf_counter()
|
st = time.perf_counter()
|
||||||
most_recent_timestamp = st
|
most_recent_timestamp = st
|
||||||
try:
|
try:
|
||||||
async with session.post(url=api_url, json=payload,
|
async with session.post(
|
||||||
headers=headers) as response:
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
|
first_chunk_received = False
|
||||||
async for chunk_bytes in response.content:
|
async for chunk_bytes in response.content:
|
||||||
chunk_bytes = chunk_bytes.strip()
|
chunk_bytes = chunk_bytes.strip()
|
||||||
if not chunk_bytes:
|
if not chunk_bytes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||||
"data: ")
|
if chunk != "[DONE]":
|
||||||
if chunk == "[DONE]":
|
|
||||||
latency = time.perf_counter() - st
|
|
||||||
else:
|
|
||||||
data = json.loads(chunk)
|
data = json.loads(chunk)
|
||||||
|
|
||||||
if data["choices"][0]["text"]:
|
# NOTE: Some completion API might have a last
|
||||||
|
# usage summary response without a token so we
|
||||||
|
# want to check a token was generated
|
||||||
|
if choices := data.get("choices"):
|
||||||
|
# Note that text could be empty here
|
||||||
|
# e.g. for special tokens
|
||||||
|
text = choices[0].get("text")
|
||||||
timestamp = time.perf_counter()
|
timestamp = time.perf_counter()
|
||||||
# First token
|
# First token
|
||||||
if ttft == 0.0:
|
if not first_chunk_received:
|
||||||
ttft = time.perf_counter() - st
|
first_chunk_received = True
|
||||||
output.ttft = ttft
|
|
||||||
|
|
||||||
# Decoding phase
|
|
||||||
# NOTE: Some completion API might have a last
|
|
||||||
# usage summary response without a token so we
|
|
||||||
# do not want to include as inter-token-latency
|
|
||||||
elif data.get("usage", None) is None:
|
|
||||||
output.itl.append(timestamp -
|
|
||||||
most_recent_timestamp)
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
|
||||||
generated_text += data["choices"][0]["text"]
|
|
||||||
|
|
||||||
output.generated_text = generated_text
|
|
||||||
output.success = True
|
|
||||||
output.latency = latency
|
|
||||||
except Exception:
|
|
||||||
output.success = False
|
|
||||||
exc_info = sys.exc_info()
|
|
||||||
output.error = "".join(traceback.format_exception(*exc_info))
|
|
||||||
|
|
||||||
if pbar:
|
|
||||||
pbar.update(1)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
async def async_request_openai_chat_completions(
|
|
||||||
request_func_input: RequestFuncInput,
|
|
||||||
pbar: Optional[tqdm] = None,
|
|
||||||
) -> RequestFuncOutput:
|
|
||||||
api_url = request_func_input.api_url
|
|
||||||
assert api_url.endswith(
|
|
||||||
"v1/chat/completions"
|
|
||||||
), "OpenAI Chat Completions API URL must end with 'v1/chat/completions'."
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
|
||||||
assert not request_func_input.use_beam_search
|
|
||||||
payload = {
|
|
||||||
"model": request_func_input.model,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": request_func_input.prompt,
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"temperature": 0.0,
|
|
||||||
"max_tokens": request_func_input.output_len,
|
|
||||||
"stream": True,
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
|
||||||
}
|
|
||||||
|
|
||||||
output = RequestFuncOutput()
|
|
||||||
output.prompt_len = request_func_input.prompt_len
|
|
||||||
|
|
||||||
generated_text = ""
|
|
||||||
ttft = 0.0
|
|
||||||
st = time.perf_counter()
|
|
||||||
most_recent_timestamp = st
|
|
||||||
try:
|
|
||||||
async with session.post(url=api_url, json=payload,
|
|
||||||
headers=headers) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
async for chunk_bytes in response.content:
|
|
||||||
chunk_bytes = chunk_bytes.strip()
|
|
||||||
if not chunk_bytes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"),
|
|
||||||
"data: ")
|
|
||||||
if chunk == "[DONE]":
|
|
||||||
latency = time.perf_counter() - st
|
|
||||||
else:
|
|
||||||
timestamp = time.perf_counter()
|
|
||||||
data = json.loads(chunk)
|
|
||||||
|
|
||||||
delta = data["choices"][0]["delta"]
|
|
||||||
if delta.get("content", None):
|
|
||||||
# First token
|
|
||||||
if ttft == 0.0:
|
|
||||||
ttft = time.perf_counter() - st
|
ttft = time.perf_counter() - st
|
||||||
output.ttft = ttft
|
output.ttft = ttft
|
||||||
|
|
||||||
# Decoding phase
|
# Decoding phase
|
||||||
else:
|
else:
|
||||||
output.itl.append(timestamp -
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
most_recent_timestamp)
|
|
||||||
|
|
||||||
generated_text += delta["content"]
|
|
||||||
|
|
||||||
most_recent_timestamp = timestamp
|
|
||||||
|
|
||||||
|
most_recent_timestamp = timestamp
|
||||||
|
generated_text += text or ""
|
||||||
|
if usage := data.get("usage"):
|
||||||
|
output.output_tokens = usage.get("completion_tokens")
|
||||||
|
if first_chunk_received:
|
||||||
|
output.success = True
|
||||||
|
else:
|
||||||
|
output.success = False
|
||||||
|
output.error = (
|
||||||
|
"Never received a valid chunk to calculate TTFT."
|
||||||
|
"This response will be marked as failed!"
|
||||||
|
)
|
||||||
output.generated_text = generated_text
|
output.generated_text = generated_text
|
||||||
output.success = True
|
output.latency = most_recent_timestamp - st
|
||||||
output.latency = latency
|
|
||||||
else:
|
else:
|
||||||
output.error = response.reason or ""
|
output.error = response.reason or ""
|
||||||
output.success = False
|
output.success = False
|
||||||
@@ -370,12 +364,276 @@ async def async_request_openai_chat_completions(
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
|
async def async_request_openai_chat_completions(
|
||||||
# introduced in Python 3.9
|
request_func_input: RequestFuncInput,
|
||||||
def remove_prefix(text: str, prefix: str) -> str:
|
pbar: tqdm | None = None,
|
||||||
if text.startswith(prefix):
|
) -> RequestFuncOutput:
|
||||||
return text[len(prefix):]
|
api_url = request_func_input.api_url
|
||||||
return text
|
assert api_url.endswith(("chat/completions", "profile")), (
|
||||||
|
"OpenAI Chat Completions API URL must end with 'chat/completions'."
|
||||||
|
)
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||||
|
) as session:
|
||||||
|
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||||
|
if request_func_input.multi_modal_content:
|
||||||
|
mm_content = request_func_input.multi_modal_content
|
||||||
|
if isinstance(mm_content, list):
|
||||||
|
content.extend(mm_content)
|
||||||
|
elif isinstance(mm_content, dict):
|
||||||
|
content.append(mm_content)
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
"multi_modal_content must be a dict or list[dict] for openai-chat"
|
||||||
|
)
|
||||||
|
payload = {
|
||||||
|
"model": request_func_input.model_name
|
||||||
|
if request_func_input.model_name
|
||||||
|
else request_func_input.model,
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": content},
|
||||||
|
],
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_completion_tokens": request_func_input.output_len,
|
||||||
|
"stream": True,
|
||||||
|
"stream_options": {
|
||||||
|
"include_usage": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if request_func_input.ignore_eos:
|
||||||
|
payload["ignore_eos"] = request_func_input.ignore_eos
|
||||||
|
if request_func_input.extra_body:
|
||||||
|
payload.update(request_func_input.extra_body)
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||||
|
}
|
||||||
|
if request_func_input.request_id:
|
||||||
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
generated_text = ""
|
||||||
|
ttft = 0.0
|
||||||
|
st = time.perf_counter()
|
||||||
|
most_recent_timestamp = st
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url=api_url, json=payload, headers=headers
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for chunk_bytes in response.content:
|
||||||
|
chunk_bytes = chunk_bytes.strip()
|
||||||
|
if not chunk_bytes:
|
||||||
|
continue
|
||||||
|
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||||
|
# NOTE: SSE comments (often used as pings) start with a colon.
|
||||||
|
# These are not JSON data payload and should be skipped.
|
||||||
|
if chunk_bytes.startswith(":"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = chunk_bytes.removeprefix("data: ")
|
||||||
|
|
||||||
|
if chunk != "[DONE]":
|
||||||
|
timestamp = time.perf_counter()
|
||||||
|
data = json.loads(chunk)
|
||||||
|
|
||||||
|
if choices := data.get("choices"):
|
||||||
|
content = choices[0]["delta"].get("content")
|
||||||
|
# First token
|
||||||
|
if ttft == 0.0:
|
||||||
|
ttft = timestamp - st
|
||||||
|
output.ttft = ttft
|
||||||
|
|
||||||
|
# Decoding phase
|
||||||
|
else:
|
||||||
|
output.itl.append(timestamp - most_recent_timestamp)
|
||||||
|
|
||||||
|
generated_text += content or ""
|
||||||
|
elif usage := data.get("usage"):
|
||||||
|
output.output_tokens = usage.get("completion_tokens")
|
||||||
|
|
||||||
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
|
output.generated_text = generated_text
|
||||||
|
output.success = True
|
||||||
|
output.latency = most_recent_timestamp - st
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
|
except Exception:
|
||||||
|
output.success = False
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
async def async_request_openai_audio(
|
||||||
|
request_func_input: RequestFuncInput,
|
||||||
|
pbar: tqdm | None = None,
|
||||||
|
) -> RequestFuncOutput:
|
||||||
|
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||||
|
import soundfile
|
||||||
|
|
||||||
|
api_url = request_func_input.api_url
|
||||||
|
assert api_url.endswith(("transcriptions", "translations")), (
|
||||||
|
"OpenAI Chat Completions API URL must end with 'transcriptions' "
|
||||||
|
)
|
||||||
|
"or `translations`."
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession(
|
||||||
|
trust_env=True, timeout=AIOHTTP_TIMEOUT
|
||||||
|
) as session:
|
||||||
|
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||||
|
payload = {
|
||||||
|
"model": request_func_input.model_name
|
||||||
|
if request_func_input.model_name
|
||||||
|
else request_func_input.model,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_completion_tokens": request_func_input.output_len,
|
||||||
|
"stream": True,
|
||||||
|
"language": "en",
|
||||||
|
# Flattened due to multipart/form-data
|
||||||
|
"stream_include_usage": True,
|
||||||
|
"stream_continuous_usage_stats": True,
|
||||||
|
}
|
||||||
|
if request_func_input.extra_body:
|
||||||
|
payload.update(request_func_input.extra_body)
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
|
||||||
|
}
|
||||||
|
if request_func_input.request_id:
|
||||||
|
headers["x-request-id"] = request_func_input.request_id
|
||||||
|
|
||||||
|
# Send audio file
|
||||||
|
def to_bytes(y, sr):
|
||||||
|
buffer = io.BytesIO()
|
||||||
|
soundfile.write(buffer, y, sr, format="WAV")
|
||||||
|
buffer.seek(0)
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
mm_audio = request_func_input.multi_modal_content
|
||||||
|
if not isinstance(mm_audio, dict) or "audio" not in mm_audio:
|
||||||
|
raise TypeError("multi_modal_content must be a dict containing 'audio'")
|
||||||
|
with to_bytes(*mm_audio["audio"]) as f:
|
||||||
|
form = aiohttp.FormData()
|
||||||
|
form.add_field("file", f, content_type="audio/wav")
|
||||||
|
for key, value in payload.items():
|
||||||
|
form.add_field(key, str(value))
|
||||||
|
|
||||||
|
output = RequestFuncOutput()
|
||||||
|
output.prompt_len = request_func_input.prompt_len
|
||||||
|
|
||||||
|
generated_text = ""
|
||||||
|
ttft = 0.0
|
||||||
|
st = time.perf_counter()
|
||||||
|
most_recent_timestamp = st
|
||||||
|
try:
|
||||||
|
async with session.post(
|
||||||
|
url=api_url, data=form, headers=headers
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
async for chunk_bytes in response.content:
|
||||||
|
chunk_bytes = chunk_bytes.strip()
|
||||||
|
if not chunk_bytes:
|
||||||
|
continue
|
||||||
|
|
||||||
|
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||||
|
if chunk != "[DONE]":
|
||||||
|
timestamp = time.perf_counter()
|
||||||
|
data = json.loads(chunk)
|
||||||
|
|
||||||
|
if choices := data.get("choices"):
|
||||||
|
content = choices[0]["delta"].get("content")
|
||||||
|
# First token
|
||||||
|
if ttft == 0.0:
|
||||||
|
ttft = timestamp - st
|
||||||
|
output.ttft = ttft
|
||||||
|
|
||||||
|
# Decoding phase
|
||||||
|
else:
|
||||||
|
output.itl.append(
|
||||||
|
timestamp - most_recent_timestamp
|
||||||
|
)
|
||||||
|
|
||||||
|
generated_text += content or ""
|
||||||
|
elif usage := data.get("usage"):
|
||||||
|
output.output_tokens = usage.get(
|
||||||
|
"completion_tokens"
|
||||||
|
)
|
||||||
|
|
||||||
|
most_recent_timestamp = timestamp
|
||||||
|
|
||||||
|
output.generated_text = generated_text
|
||||||
|
output.success = True
|
||||||
|
output.latency = most_recent_timestamp - st
|
||||||
|
else:
|
||||||
|
output.error = response.reason or ""
|
||||||
|
output.success = False
|
||||||
|
except Exception:
|
||||||
|
output.success = False
|
||||||
|
exc_info = sys.exc_info()
|
||||||
|
output.error = "".join(traceback.format_exception(*exc_info))
|
||||||
|
|
||||||
|
if pbar:
|
||||||
|
pbar.update(1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||||
|
if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true":
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import get_lock
|
||||||
|
|
||||||
|
# Use file lock to prevent multiple processes from
|
||||||
|
# downloading the same model weights at the same time.
|
||||||
|
with get_lock(pretrained_model_name_or_path):
|
||||||
|
model_path = snapshot_download(
|
||||||
|
model_id=pretrained_model_name_or_path,
|
||||||
|
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||||
|
ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"],
|
||||||
|
)
|
||||||
|
|
||||||
|
return model_path
|
||||||
|
return pretrained_model_name_or_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_tokenizer(
|
||||||
|
pretrained_model_name_or_path: str,
|
||||||
|
tokenizer_mode: str = "auto",
|
||||||
|
trust_remote_code: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
|
||||||
|
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||||
|
pretrained_model_name_or_path
|
||||||
|
):
|
||||||
|
pretrained_model_name_or_path = get_model(pretrained_model_name_or_path)
|
||||||
|
if tokenizer_mode == "slow":
|
||||||
|
if kwargs.get("use_fast", False):
|
||||||
|
raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
|
||||||
|
kwargs["use_fast"] = False
|
||||||
|
if tokenizer_mode == "mistral":
|
||||||
|
try:
|
||||||
|
from vllm.tokenizers.mistral import MistralTokenizer
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"MistralTokenizer requires vllm package.\n"
|
||||||
|
"Please install it with `pip install vllm` "
|
||||||
|
"to use mistral tokenizer mode."
|
||||||
|
) from e
|
||||||
|
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
|
||||||
|
else:
|
||||||
|
return AutoTokenizer.from_pretrained(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
trust_remote_code=trust_remote_code,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ASYNC_REQUEST_FUNCS = {
|
ASYNC_REQUEST_FUNCS = {
|
||||||
@@ -385,5 +643,15 @@ ASYNC_REQUEST_FUNCS = {
|
|||||||
"deepspeed-mii": async_request_deepspeed_mii,
|
"deepspeed-mii": async_request_deepspeed_mii,
|
||||||
"openai": async_request_openai_completions,
|
"openai": async_request_openai_completions,
|
||||||
"openai-chat": async_request_openai_chat_completions,
|
"openai-chat": async_request_openai_chat_completions,
|
||||||
|
"openai-audio": async_request_openai_audio,
|
||||||
"tensorrt-llm": async_request_trt_llm,
|
"tensorrt-llm": async_request_trt_llm,
|
||||||
|
"scalellm": async_request_openai_completions,
|
||||||
|
"sglang": async_request_openai_completions,
|
||||||
|
"llama.cpp": async_request_openai_completions,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
OPENAI_COMPATIBLE_BACKENDS = [
|
||||||
|
k
|
||||||
|
for k, v in ASYNC_REQUEST_FUNCS.items()
|
||||||
|
if v in (async_request_openai_completions, async_request_openai_chat_completions)
|
||||||
|
]
|
||||||
|
|||||||
380
benchmarks/benchmark_batch_invariance.py
Executable file
380
benchmarks/benchmark_batch_invariance.py
Executable file
@@ -0,0 +1,380 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark to measure the performance overhead of VLLM_BATCH_INVARIANT mode.
|
||||||
|
|
||||||
|
This benchmark runs the same workload twice:
|
||||||
|
1. With VLLM_BATCH_INVARIANT=0 (baseline)
|
||||||
|
2. With VLLM_BATCH_INVARIANT=1 (batch invariant mode)
|
||||||
|
|
||||||
|
And reports the timing and throughput metrics for comparison.
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
VLLM_BENCH_MODEL: Model to benchmark (default: "Qwen/Qwen3-1.7B")
|
||||||
|
VLLM_BENCH_TP_SIZE: Tensor parallel size (default: 1, use 8 for deepseek)
|
||||||
|
VLLM_BENCH_BATCH_SIZE: Max batch size (default: 128)
|
||||||
|
VLLM_BENCH_NUM_TRIALS: Number of trials to run (default: 5)
|
||||||
|
VLLM_BENCH_MIN_PROMPT: Min prompt length in words (default: 1024)
|
||||||
|
VLLM_BENCH_MAX_PROMPT: Max prompt length in words (default: 2048)
|
||||||
|
VLLM_BENCH_MAX_TOKENS: Max tokens to generate (default: 128)
|
||||||
|
VLLM_BENCH_TEMPERATURE: Temperature for sampling (default: 0.0)
|
||||||
|
VLLM_BENCH_GPU_MEMORY_UTILIZATION: GPU memory utilization (default: 0.4)
|
||||||
|
VLLM_BENCH_MAX_MODEL_LEN: Max model length (default: 5120)
|
||||||
|
VLLM_BENCH_BACKEND: Attention backend (default: FLASH_ATTN)
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
# Benchmark qwen3 (default)
|
||||||
|
python benchmarks/benchmark_batch_invariance.py
|
||||||
|
|
||||||
|
# Benchmark deepseek with 8 GPUs
|
||||||
|
VLLM_BENCH_MODEL="deepseek-ai/DeepSeek-V3" VLLM_BENCH_TP_SIZE=8 \\
|
||||||
|
python benchmarks/benchmark_batch_invariance.py
|
||||||
|
|
||||||
|
# Quick test with fewer trials
|
||||||
|
VLLM_BENCH_NUM_TRIALS=2 VLLM_BENCH_BATCH_SIZE=32 \\
|
||||||
|
python benchmarks/benchmark_batch_invariance.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||||
|
"""Generate a random prompt for benchmarking."""
|
||||||
|
prompt_templates = [
|
||||||
|
"Question: What is the capital of France?\nAnswer: The capital of France is",
|
||||||
|
"Q: How does photosynthesis work?\nA: Photosynthesis is the process by which",
|
||||||
|
"User: Can you explain quantum mechanics?\nAssistant: Quantum mechanics is",
|
||||||
|
"Once upon a time in a distant galaxy, there lived",
|
||||||
|
"The old man walked slowly down the street, remembering",
|
||||||
|
"In the year 2157, humanity finally discovered",
|
||||||
|
"To implement a binary search tree in Python, first we need to",
|
||||||
|
"The algorithm works by iterating through the array and",
|
||||||
|
"Here's how to optimize database queries using indexing:",
|
||||||
|
"The Renaissance was a period in European history that",
|
||||||
|
"Climate change is caused by several factors including",
|
||||||
|
"The human brain contains approximately 86 billion neurons which",
|
||||||
|
"I've been thinking about getting a new laptop because",
|
||||||
|
"Yesterday I went to the store and bought",
|
||||||
|
"My favorite thing about summer is definitely",
|
||||||
|
]
|
||||||
|
|
||||||
|
base_prompt = random.choice(prompt_templates)
|
||||||
|
|
||||||
|
if max_words < min_words:
|
||||||
|
max_words = min_words
|
||||||
|
target_words = random.randint(min_words, max_words)
|
||||||
|
|
||||||
|
if target_words > 50:
|
||||||
|
padding_text = (
|
||||||
|
" This is an interesting topic that deserves more explanation. "
|
||||||
|
* (target_words // 50)
|
||||||
|
)
|
||||||
|
base_prompt = base_prompt + padding_text
|
||||||
|
|
||||||
|
return base_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmark_with_batch_invariant(
|
||||||
|
model: str,
|
||||||
|
tp_size: int,
|
||||||
|
max_batch_size: int,
|
||||||
|
num_trials: int,
|
||||||
|
min_prompt: int,
|
||||||
|
max_prompt: int,
|
||||||
|
max_tokens: int,
|
||||||
|
temperature: float,
|
||||||
|
gpu_mem_util: float,
|
||||||
|
max_model_len: int,
|
||||||
|
backend: str,
|
||||||
|
batch_invariant: bool,
|
||||||
|
seed: int = 12345,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Run the benchmark with the specified configuration.
|
||||||
|
|
||||||
|
Returns a dict with timing and throughput metrics.
|
||||||
|
"""
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
os.environ["VLLM_ATTENTION_BACKEND"] = backend
|
||||||
|
if batch_invariant:
|
||||||
|
os.environ["VLLM_BATCH_INVARIANT"] = "1"
|
||||||
|
else:
|
||||||
|
os.environ["VLLM_BATCH_INVARIANT"] = "0"
|
||||||
|
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print(f"BENCHMARK: VLLM_BATCH_INVARIANT={int(batch_invariant)}")
|
||||||
|
print(f" Model: {model}")
|
||||||
|
print(f" TP Size: {tp_size}")
|
||||||
|
print(f" Backend: {backend}")
|
||||||
|
print(f" Max Batch Size: {max_batch_size}")
|
||||||
|
print(f" Trials: {num_trials}")
|
||||||
|
print(f" Max Tokens: {max_tokens}")
|
||||||
|
print(f"{'=' * 80}\n")
|
||||||
|
|
||||||
|
sampling = SamplingParams(
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=0.95,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
seed=20240919,
|
||||||
|
)
|
||||||
|
|
||||||
|
needle_prompt = "There once was a "
|
||||||
|
|
||||||
|
llm = None
|
||||||
|
try:
|
||||||
|
# Create LLM engine
|
||||||
|
start_init = time.perf_counter()
|
||||||
|
llm = LLM(
|
||||||
|
model=model,
|
||||||
|
max_num_seqs=max_batch_size,
|
||||||
|
gpu_memory_utilization=gpu_mem_util,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
dtype="bfloat16",
|
||||||
|
tensor_parallel_size=tp_size,
|
||||||
|
enable_prefix_caching=False,
|
||||||
|
)
|
||||||
|
init_time = time.perf_counter() - start_init
|
||||||
|
print(f"Engine initialization time: {init_time:.2f}s\n")
|
||||||
|
|
||||||
|
# Generate baseline
|
||||||
|
print("Generating baseline (warmup)...")
|
||||||
|
baseline_out = llm.generate([needle_prompt], sampling)
|
||||||
|
assert len(baseline_out) == 1
|
||||||
|
baseline_text = baseline_out[0].outputs[0].text
|
||||||
|
print(f"Baseline output: '{baseline_text[:50]}...'\n")
|
||||||
|
|
||||||
|
# Run trials and measure timing
|
||||||
|
trial_times: list[float] = []
|
||||||
|
total_tokens = 0
|
||||||
|
total_prompts = 0
|
||||||
|
|
||||||
|
for trial in range(num_trials):
|
||||||
|
# Create a batch
|
||||||
|
prompts: list[str] = []
|
||||||
|
batch_size = random.randint(max_batch_size // 2, max_batch_size)
|
||||||
|
needle_pos = random.randint(0, batch_size - 1)
|
||||||
|
for i in range(batch_size):
|
||||||
|
if i == needle_pos:
|
||||||
|
prompts.append(needle_prompt)
|
||||||
|
else:
|
||||||
|
prompts.append(_random_prompt(min_prompt, max_prompt))
|
||||||
|
|
||||||
|
# Measure time for this trial
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
outputs = llm.generate(prompts, sampling)
|
||||||
|
trial_time = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
trial_times.append(trial_time)
|
||||||
|
total_prompts += len(prompts)
|
||||||
|
|
||||||
|
# Count tokens
|
||||||
|
for output in outputs:
|
||||||
|
if output.outputs:
|
||||||
|
total_tokens += len(output.outputs[0].token_ids)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Trial {trial + 1}/{num_trials}: "
|
||||||
|
f"batch_size={batch_size}, "
|
||||||
|
f"time={trial_time:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify needle output still matches
|
||||||
|
needle_output = outputs[needle_pos]
|
||||||
|
assert needle_output.prompt == needle_prompt
|
||||||
|
|
||||||
|
# Compute statistics
|
||||||
|
avg_time = sum(trial_times) / len(trial_times)
|
||||||
|
min_time = min(trial_times)
|
||||||
|
max_time = max(trial_times)
|
||||||
|
throughput = total_tokens / sum(trial_times)
|
||||||
|
prompts_per_sec = total_prompts / sum(trial_times)
|
||||||
|
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("RESULTS:")
|
||||||
|
print(f" Average time per trial: {avg_time:.2f}s")
|
||||||
|
print(f" Min time: {min_time:.2f}s")
|
||||||
|
print(f" Max time: {max_time:.2f}s")
|
||||||
|
print(f" Total tokens generated: {total_tokens}")
|
||||||
|
print(f" Total prompts processed: {total_prompts}")
|
||||||
|
print(f" Throughput: {throughput:.2f} tokens/s")
|
||||||
|
print(f" Prompts/s: {prompts_per_sec:.2f}")
|
||||||
|
print(f"{'=' * 80}\n")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"init_time": init_time,
|
||||||
|
"avg_time": avg_time,
|
||||||
|
"min_time": min_time,
|
||||||
|
"max_time": max_time,
|
||||||
|
"total_tokens": total_tokens,
|
||||||
|
"total_prompts": total_prompts,
|
||||||
|
"throughput": throughput,
|
||||||
|
"prompts_per_sec": prompts_per_sec,
|
||||||
|
"trial_times": trial_times,
|
||||||
|
}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
if llm is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
llm.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Check platform support
|
||||||
|
if not (current_platform.is_cuda() and current_platform.has_device_capability(90)):
|
||||||
|
print("ERROR: Requires CUDA and >= Hopper (SM90)")
|
||||||
|
print(f"Current platform: {current_platform.device_type}")
|
||||||
|
if current_platform.is_cuda():
|
||||||
|
print(f"Device capability: {current_platform.get_device_capability()}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# Read configuration from environment
|
||||||
|
model = os.getenv("VLLM_BENCH_MODEL", "Qwen/Qwen3-1.7B")
|
||||||
|
tp_size = int(os.getenv("VLLM_BENCH_TP_SIZE", "1"))
|
||||||
|
max_batch_size = int(os.getenv("VLLM_BENCH_BATCH_SIZE", "128"))
|
||||||
|
num_trials = int(os.getenv("VLLM_BENCH_NUM_TRIALS", "5"))
|
||||||
|
min_prompt = int(os.getenv("VLLM_BENCH_MIN_PROMPT", "1024"))
|
||||||
|
max_prompt = int(os.getenv("VLLM_BENCH_MAX_PROMPT", "2048"))
|
||||||
|
max_tokens = int(os.getenv("VLLM_BENCH_MAX_TOKENS", "128"))
|
||||||
|
temperature = float(os.getenv("VLLM_BENCH_TEMPERATURE", "0.0"))
|
||||||
|
gpu_mem_util = float(os.getenv("VLLM_BENCH_GPU_MEMORY_UTILIZATION", "0.4"))
|
||||||
|
max_model_len = int(os.getenv("VLLM_BENCH_MAX_MODEL_LEN", "5120"))
|
||||||
|
backend = os.getenv("VLLM_BENCH_BACKEND", "FLASH_ATTN")
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("VLLM BATCH INVARIANCE BENCHMARK")
|
||||||
|
print("=" * 80)
|
||||||
|
print("\nConfiguration:")
|
||||||
|
print(f" Model: {model}")
|
||||||
|
print(f" Tensor Parallel Size: {tp_size}")
|
||||||
|
print(f" Attention Backend: {backend}")
|
||||||
|
print(f" Max Batch Size: {max_batch_size}")
|
||||||
|
print(f" Number of Trials: {num_trials}")
|
||||||
|
print(f" Prompt Length Range: {min_prompt}-{max_prompt} words")
|
||||||
|
print(f" Max Tokens to Generate: {max_tokens}")
|
||||||
|
print(f" Temperature: {temperature}")
|
||||||
|
print(f" GPU Memory Utilization: {gpu_mem_util}")
|
||||||
|
print(f" Max Model Length: {max_model_len}")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Run benchmark WITHOUT batch invariance (baseline)
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("PHASE 1: Running WITHOUT batch invariance (baseline)")
|
||||||
|
print("=" * 80)
|
||||||
|
baseline_results = run_benchmark_with_batch_invariant(
|
||||||
|
model=model,
|
||||||
|
tp_size=tp_size,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
num_trials=num_trials,
|
||||||
|
min_prompt=min_prompt,
|
||||||
|
max_prompt=max_prompt,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
gpu_mem_util=gpu_mem_util,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
backend=backend,
|
||||||
|
batch_invariant=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run benchmark WITH batch invariance
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("PHASE 2: Running WITH batch invariance")
|
||||||
|
print("=" * 80)
|
||||||
|
batch_inv_results = run_benchmark_with_batch_invariant(
|
||||||
|
model=model,
|
||||||
|
tp_size=tp_size,
|
||||||
|
max_batch_size=max_batch_size,
|
||||||
|
num_trials=num_trials,
|
||||||
|
min_prompt=min_prompt,
|
||||||
|
max_prompt=max_prompt,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
gpu_mem_util=gpu_mem_util,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
backend=backend,
|
||||||
|
batch_invariant=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compare results
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("COMPARISON: Batch Invariance vs Baseline")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
init_overhead_pct = (
|
||||||
|
(batch_inv_results["init_time"] - baseline_results["init_time"])
|
||||||
|
/ baseline_results["init_time"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
time_overhead_pct = (
|
||||||
|
(batch_inv_results["avg_time"] - baseline_results["avg_time"])
|
||||||
|
/ baseline_results["avg_time"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
throughput_change_pct = (
|
||||||
|
(batch_inv_results["throughput"] - baseline_results["throughput"])
|
||||||
|
/ baseline_results["throughput"]
|
||||||
|
* 100
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nInitialization Time:")
|
||||||
|
print(f" Baseline: {baseline_results['init_time']:.2f}s")
|
||||||
|
print(f" Batch Invariant: {batch_inv_results['init_time']:.2f}s")
|
||||||
|
print(f" Overhead: {init_overhead_pct:+.2f}%")
|
||||||
|
|
||||||
|
print("\nAverage Trial Time:")
|
||||||
|
print(f" Baseline: {baseline_results['avg_time']:.2f}s")
|
||||||
|
print(f" Batch Invariant: {batch_inv_results['avg_time']:.2f}s")
|
||||||
|
print(f" Overhead: {time_overhead_pct:+.2f}%")
|
||||||
|
|
||||||
|
print("\nThroughput (tokens/s):")
|
||||||
|
print(f" Baseline: {baseline_results['throughput']:.2f}")
|
||||||
|
print(f" Batch Invariant: {batch_inv_results['throughput']:.2f}")
|
||||||
|
print(f" Change: {throughput_change_pct:+.2f}%")
|
||||||
|
|
||||||
|
print("\nPrompts/s:")
|
||||||
|
print(f" Baseline: {baseline_results['prompts_per_sec']:.2f}")
|
||||||
|
print(f" Batch Invariant: {batch_inv_results['prompts_per_sec']:.2f}")
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("SUMMARY")
|
||||||
|
print("=" * 80)
|
||||||
|
if time_overhead_pct > 0:
|
||||||
|
print(
|
||||||
|
f"Batch invariance mode adds approximately {time_overhead_pct:.1f}% "
|
||||||
|
"overhead"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Batch invariance mode is approximately {-time_overhead_pct:.1f}% "
|
||||||
|
"faster (unexpected!)"
|
||||||
|
)
|
||||||
|
|
||||||
|
if abs(throughput_change_pct) < 1.0:
|
||||||
|
print("Throughput difference is negligible (< 1%)")
|
||||||
|
elif throughput_change_pct < 0:
|
||||||
|
print(
|
||||||
|
f"Throughput decreased by {-throughput_change_pct:.1f}% "
|
||||||
|
"with batch invariance"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
f"Throughput increased by {throughput_change_pct:.1f}% "
|
||||||
|
"with batch invariance (unexpected!)"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 80 + "\n")
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit(main())
|
||||||
74
benchmarks/benchmark_block_pool.py
Normal file
74
benchmarks/benchmark_block_pool.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import gc
|
||||||
|
|
||||||
|
from benchmark_utils import TimeCollector
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
rows = []
|
||||||
|
for allocate_block in args.allocate_blocks:
|
||||||
|
# Enforce a GC collect ahead to minimize the impact among runs
|
||||||
|
gc.collect()
|
||||||
|
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
|
||||||
|
|
||||||
|
get_blocks_times = TimeCollector(TimeCollector.US)
|
||||||
|
free_blocks_times = TimeCollector(TimeCollector.US)
|
||||||
|
for _ in range(args.num_iteration):
|
||||||
|
with get_blocks_times:
|
||||||
|
blocks = block_pool.get_new_blocks(allocate_block)
|
||||||
|
with free_blocks_times:
|
||||||
|
block_pool.free_blocks(blocks)
|
||||||
|
|
||||||
|
rows.append(
|
||||||
|
[get_blocks_times.cnt, args.num_gpu_blocks, allocate_block]
|
||||||
|
+ get_blocks_times.dump_avg_max()
|
||||||
|
+ free_blocks_times.dump_avg_max()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
tabulate(
|
||||||
|
rows,
|
||||||
|
headers=[
|
||||||
|
"Iterations",
|
||||||
|
"Total\nBlocks",
|
||||||
|
"Allocated\nBlocks",
|
||||||
|
"Get Blocks\nAvg (us)",
|
||||||
|
"Get Blocks\nMax (us)",
|
||||||
|
"Free Blocks\nAvg (us)",
|
||||||
|
"Free Blocks\nMax (us)",
|
||||||
|
],
|
||||||
|
tablefmt="grid",
|
||||||
|
floatfmt=".3f",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_main() -> None:
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the performance of BlockPool for KV Cache."
|
||||||
|
)
|
||||||
|
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-iteration",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
|
help="Number of iterations to run to stabilize final data readings",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--allocate-blocks",
|
||||||
|
type=int,
|
||||||
|
nargs="*",
|
||||||
|
default=[10, 50, 100, 500, 1000],
|
||||||
|
help="Number of blocks to allocate",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
invoke_main() # pragma: no cover
|
||||||
120
benchmarks/benchmark_hash.py
Normal file
120
benchmarks/benchmark_hash.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Micro benchmark comparing built-in hash(), SHA-256, and xxHash.
|
||||||
|
|
||||||
|
This focuses on a single test payload shaped like the prefix-cache hash input:
|
||||||
|
(32-byte bytes object, 32-int tuple)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python benchmarks/hash_micro_benchmark.py --iterations 20000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import statistics
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
|
||||||
|
from vllm.utils.hashing import sha256, xxhash
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_test_data(seed: int) -> tuple[bytes, tuple[int, ...]]:
|
||||||
|
"""Generate a deterministic test payload."""
|
||||||
|
random.seed(seed)
|
||||||
|
bytes_data = bytes(random.getrandbits(8) for _ in range(32))
|
||||||
|
int_tuple = tuple(random.randint(1, 1_000_000) for _ in range(32))
|
||||||
|
return (bytes_data, int_tuple)
|
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_func(func: Callable[[tuple], object], data: tuple, iterations: int):
|
||||||
|
"""Return (avg_seconds, std_seconds) for hashing `data` `iterations` times."""
|
||||||
|
times: list[float] = []
|
||||||
|
|
||||||
|
# Warm-up to avoid first-run noise.
|
||||||
|
for _ in range(200):
|
||||||
|
func(data)
|
||||||
|
|
||||||
|
for _ in range(iterations):
|
||||||
|
start = time.perf_counter()
|
||||||
|
func(data)
|
||||||
|
end = time.perf_counter()
|
||||||
|
times.append(end - start)
|
||||||
|
|
||||||
|
avg = statistics.mean(times)
|
||||||
|
std = statistics.stdev(times) if len(times) > 1 else 0.0
|
||||||
|
return avg, std
|
||||||
|
|
||||||
|
|
||||||
|
def _run_benchmarks(
|
||||||
|
benchmarks: Iterable[tuple[str, Callable[[tuple], object]]],
|
||||||
|
data: tuple,
|
||||||
|
iterations: int,
|
||||||
|
):
|
||||||
|
"""Yield (name, avg, std) for each benchmark, skipping unavailable ones."""
|
||||||
|
for name, func in benchmarks:
|
||||||
|
try:
|
||||||
|
avg, std = _benchmark_func(func, data, iterations)
|
||||||
|
except ModuleNotFoundError as exc:
|
||||||
|
print(f"Skipping {name}: {exc}")
|
||||||
|
continue
|
||||||
|
yield name, avg, std
|
||||||
|
|
||||||
|
|
||||||
|
def builtin_hash(data: tuple) -> int:
|
||||||
|
"""Wrapper for Python's built-in hash()."""
|
||||||
|
return hash(data)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument(
|
||||||
|
"--iterations",
|
||||||
|
type=int,
|
||||||
|
default=10_000,
|
||||||
|
help="Number of measured iterations per hash function.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=42, help="Random seed for test payload."
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
data = _generate_test_data(args.seed)
|
||||||
|
benchmarks = (
|
||||||
|
("SHA256 (pickle)", sha256),
|
||||||
|
("xxHash (pickle)", xxhash),
|
||||||
|
("built-in hash()", builtin_hash),
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("HASH FUNCTION MICRO BENCHMARK")
|
||||||
|
print("=" * 60)
|
||||||
|
print("Test data: (32-byte bytes object, 32-int tuple)")
|
||||||
|
print(f"Iterations: {args.iterations:,}")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
results = list(_run_benchmarks(benchmarks, data, args.iterations))
|
||||||
|
builtin_entry = next((r for r in results if r[0] == "built-in hash()"), None)
|
||||||
|
|
||||||
|
print("\nResults:")
|
||||||
|
for name, avg, std in results:
|
||||||
|
print(f" {name:16s}: {avg * 1e6:8.2f} ± {std * 1e6:6.2f} μs")
|
||||||
|
|
||||||
|
if builtin_entry:
|
||||||
|
_, builtin_avg, _ = builtin_entry
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("SUMMARY (relative to built-in hash())")
|
||||||
|
print("=" * 60)
|
||||||
|
for name, avg, _ in results:
|
||||||
|
if name == "built-in hash()":
|
||||||
|
continue
|
||||||
|
speed_ratio = avg / builtin_avg
|
||||||
|
print(f"• {name} is {speed_ratio:.1f}x slower than built-in hash()")
|
||||||
|
else:
|
||||||
|
print("\nBuilt-in hash() result missing; cannot compute speed ratios.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,195 +1,17 @@
|
|||||||
"""Benchmark the latency of processing a single batch of requests."""
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import argparse
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import time
|
import sys
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
if __name__ == "__main__":
|
||||||
import torch
|
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
Please use the following command instead:
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
vllm bench latency
|
||||||
|
|
||||||
|
For help with the new command, run:
|
||||||
|
vllm bench latency --help
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
Alternatively, you can run the new command directly with:
|
||||||
print(args)
|
python -m vllm.entrypoints.cli.main bench latency --help
|
||||||
|
""")
|
||||||
# NOTE(woosuk): If the request cannot be processed in a single batch,
|
sys.exit(1)
|
||||||
# the engine will automatically process the request in multiple batches.
|
|
||||||
llm = LLM(model=args.model,
|
|
||||||
tokenizer=args.tokenizer,
|
|
||||||
quantization=args.quantization,
|
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
|
||||||
trust_remote_code=args.trust_remote_code,
|
|
||||||
dtype=args.dtype,
|
|
||||||
enforce_eager=args.enforce_eager,
|
|
||||||
kv_cache_dtype=args.kv_cache_dtype,
|
|
||||||
quantization_param_path=args.quantization_param_path,
|
|
||||||
device=args.device,
|
|
||||||
ray_workers_use_nsight=args.ray_workers_use_nsight,
|
|
||||||
enable_chunked_prefill=args.enable_chunked_prefill,
|
|
||||||
download_dir=args.download_dir,
|
|
||||||
block_size=args.block_size)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
n=args.n,
|
|
||||||
temperature=0.0 if args.use_beam_search else 1.0,
|
|
||||||
top_p=1.0,
|
|
||||||
use_beam_search=args.use_beam_search,
|
|
||||||
ignore_eos=True,
|
|
||||||
max_tokens=args.output_len,
|
|
||||||
)
|
|
||||||
print(sampling_params)
|
|
||||||
dummy_prompt_token_ids = np.random.randint(10000,
|
|
||||||
size=(args.batch_size,
|
|
||||||
args.input_len))
|
|
||||||
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
|
|
||||||
|
|
||||||
def run_to_completion(profile_dir: Optional[str] = None):
|
|
||||||
if profile_dir:
|
|
||||||
with torch.profiler.profile(
|
|
||||||
activities=[
|
|
||||||
torch.profiler.ProfilerActivity.CPU,
|
|
||||||
torch.profiler.ProfilerActivity.CUDA,
|
|
||||||
],
|
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
|
||||||
str(profile_dir))) as p:
|
|
||||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
use_tqdm=False)
|
|
||||||
print(p.key_averages())
|
|
||||||
else:
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
|
||||||
sampling_params=sampling_params,
|
|
||||||
use_tqdm=False)
|
|
||||||
end_time = time.perf_counter()
|
|
||||||
latency = end_time - start_time
|
|
||||||
return latency
|
|
||||||
|
|
||||||
print("Warming up...")
|
|
||||||
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
|
|
||||||
run_to_completion(profile_dir=None)
|
|
||||||
|
|
||||||
if args.profile:
|
|
||||||
profile_dir = args.profile_result_dir
|
|
||||||
if not profile_dir:
|
|
||||||
profile_dir = Path(
|
|
||||||
"."
|
|
||||||
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
|
||||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
|
||||||
run_to_completion(profile_dir=profile_dir)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Benchmark.
|
|
||||||
latencies = []
|
|
||||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
|
||||||
latencies.append(run_to_completion(profile_dir=None))
|
|
||||||
latencies = np.array(latencies)
|
|
||||||
percentages = [10, 25, 50, 75, 90]
|
|
||||||
percentiles = np.percentile(latencies, percentages)
|
|
||||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
|
||||||
for percentage, percentile in zip(percentages, percentiles):
|
|
||||||
print(f'{percentage}% percentile latency: {percentile} seconds')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description='Benchmark the latency of processing a single batch of '
|
|
||||||
'requests till completion.')
|
|
||||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
|
||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
|
||||||
parser.add_argument('--quantization',
|
|
||||||
'-q',
|
|
||||||
choices=[*QUANTIZATION_METHODS, None],
|
|
||||||
default=None)
|
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
|
||||||
parser.add_argument('--output-len', type=int, default=128)
|
|
||||||
parser.add_argument('--batch-size', type=int, default=8)
|
|
||||||
parser.add_argument('--n',
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help='Number of generated sequences per prompt.')
|
|
||||||
parser.add_argument('--use-beam-search', action='store_true')
|
|
||||||
parser.add_argument('--num-iters-warmup',
|
|
||||||
type=int,
|
|
||||||
default=10,
|
|
||||||
help='Number of iterations to run for warmup.')
|
|
||||||
parser.add_argument('--num-iters',
|
|
||||||
type=int,
|
|
||||||
default=30,
|
|
||||||
help='Number of iterations to run.')
|
|
||||||
parser.add_argument('--trust-remote-code',
|
|
||||||
action='store_true',
|
|
||||||
help='trust remote code from huggingface')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dtype',
|
|
||||||
type=str,
|
|
||||||
default='auto',
|
|
||||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
|
||||||
help='data type for model weights and activations. '
|
|
||||||
'The "auto" option will use FP16 precision '
|
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
|
||||||
'for BF16 models.')
|
|
||||||
parser.add_argument('--enforce-eager',
|
|
||||||
action='store_true',
|
|
||||||
help='enforce eager mode and disable CUDA graph')
|
|
||||||
parser.add_argument(
|
|
||||||
"--kv-cache-dtype",
|
|
||||||
type=str,
|
|
||||||
choices=['auto', 'fp8'],
|
|
||||||
default='auto',
|
|
||||||
help=
|
|
||||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
|
||||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
|
||||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
|
||||||
'common inference criteria.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--quantization-param-path',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
|
||||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
|
||||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
|
||||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
|
||||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
|
||||||
'instead supported for common inference criteria.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--profile',
|
|
||||||
action='store_true',
|
|
||||||
help='profile the generation process of a single batch')
|
|
||||||
parser.add_argument(
|
|
||||||
'--profile-result-dir',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help=('path to save the pytorch profiler output. Can be visualized '
|
|
||||||
'with ui.perfetto.dev or Tensorboard.'))
|
|
||||||
parser.add_argument(
|
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
choices=["cuda", "cpu"],
|
|
||||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
|
||||||
parser.add_argument('--block-size',
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help='block size of key/value cache')
|
|
||||||
parser.add_argument(
|
|
||||||
'--enable-chunked-prefill',
|
|
||||||
action='store_true',
|
|
||||||
help='If True, the prefill requests can be chunked based on the '
|
|
||||||
'max_num_batched_tokens')
|
|
||||||
parser.add_argument(
|
|
||||||
"--ray-workers-use-nsight",
|
|
||||||
action='store_true',
|
|
||||||
help="If specified, use nsight to profile ray workers",
|
|
||||||
)
|
|
||||||
parser.add_argument('--download-dir',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='directory to download and load the weights, '
|
|
||||||
'default to the default cache dir of huggingface')
|
|
||||||
args = parser.parse_args()
|
|
||||||
main(args)
|
|
||||||
|
|||||||
202
benchmarks/benchmark_long_document_qa_throughput.py
Normal file
202
benchmarks/benchmark_long_document_qa_throughput.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Offline benchmark to test the long document QA throughput.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
# This workload samples 8 different prompts with a default input
|
||||||
|
# length of 20000 tokens, then replicates each prompt 2 times
|
||||||
|
# in random order.
|
||||||
|
python benchmark_long_document_qa_throughput.py \
|
||||||
|
--model meta-llama/Llama-2-7b-chat-hf \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--num-documents 8 \
|
||||||
|
--repeat-count 2
|
||||||
|
|
||||||
|
Commandline arguments:
|
||||||
|
--num-documents: The number of documents to sample prompts from.
|
||||||
|
|
||||||
|
--document-length: The length of each document in tokens.
|
||||||
|
(Optional, default: 20000)
|
||||||
|
|
||||||
|
--output-len: The number of tokens to generate for each prompt.
|
||||||
|
(Optional, default: 10)
|
||||||
|
|
||||||
|
--repeat-count: The number of times to repeat each prompt.
|
||||||
|
(Optional, default: 2)
|
||||||
|
|
||||||
|
--repeat-mode: The mode to repeat prompts. The supported modes are:
|
||||||
|
- 'random': shuffle the prompts randomly. (Default)
|
||||||
|
- 'tile': the entire prompt list is repeated in sequence. (Potentially
|
||||||
|
lowest cache hit)
|
||||||
|
- 'interleave': each prompt is repeated consecutively before
|
||||||
|
moving to the next element. (Highest cache hit)
|
||||||
|
|
||||||
|
--shuffle-seed: Random seed when the repeat mode is "random".
|
||||||
|
(Optional, default: 0)
|
||||||
|
|
||||||
|
In the meantime, it also supports all the vLLM engine args to initialize the
|
||||||
|
LLM engine. You can refer to the `vllm.engine.arg_utils.EngineArgs` for more
|
||||||
|
details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
|
||||||
|
"""
|
||||||
|
Test long document QA with the given prompts and sampling parameters.
|
||||||
|
Print the time spent in processing all the prompts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm: The language model used for generating responses.
|
||||||
|
sampling_params: Sampling parameter used to generate the response.
|
||||||
|
prompts: A list of prompt strings to be processed by the LLM.
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
llm.generate(prompts, sampling_params=sampling_params)
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Time to execute all requests: {end_time - start_time:.4f} secs")
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_prompts(prompts, repeat_count, mode: str):
|
||||||
|
"""
|
||||||
|
Repeat each prompt in the list for a specified number of times.
|
||||||
|
The order of prompts in the output list depends on the mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: A list of prompts to be repeated.
|
||||||
|
repeat_count: The number of times each prompt is repeated.
|
||||||
|
mode: The mode of repetition. Supported modes are:
|
||||||
|
- 'random': Shuffle the prompts randomly after repetition.
|
||||||
|
- 'tile': Repeat the entire prompt list in sequence.
|
||||||
|
Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
||||||
|
- 'interleave': Repeat each prompt consecutively before moving to
|
||||||
|
the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of repeated prompts in the specified order.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If an invalid mode is provided.
|
||||||
|
"""
|
||||||
|
print("Repeat mode: ", mode)
|
||||||
|
if mode == "random":
|
||||||
|
repeated_prompts = prompts * repeat_count
|
||||||
|
random.shuffle(repeated_prompts)
|
||||||
|
return repeated_prompts
|
||||||
|
elif mode == "tile":
|
||||||
|
return prompts * repeat_count
|
||||||
|
elif mode == "interleave":
|
||||||
|
repeated_prompts = []
|
||||||
|
for prompt in prompts:
|
||||||
|
repeated_prompts.extend([prompt] * repeat_count)
|
||||||
|
return repeated_prompts
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
random.seed(args.shuffle_seed)
|
||||||
|
|
||||||
|
# Prepare the prompts:
|
||||||
|
# we append the document id at the beginning to avoid any of the document
|
||||||
|
# being the prefix of other documents
|
||||||
|
prompts = [
|
||||||
|
str(i) + " ".join(["hi"] * args.document_length)
|
||||||
|
for i in range(args.num_documents)
|
||||||
|
]
|
||||||
|
|
||||||
|
prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode)
|
||||||
|
|
||||||
|
warmup_prompts = [
|
||||||
|
"This is warm up request " + str(i) + " ".join(["hi"] * args.document_length)
|
||||||
|
for i in range(args.num_documents)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create the LLM engine
|
||||||
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
|
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||||
|
|
||||||
|
print("------warm up------")
|
||||||
|
test_long_document_qa(
|
||||||
|
llm=llm,
|
||||||
|
prompts=warmup_prompts,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("------start generating------")
|
||||||
|
test_long_document_qa(
|
||||||
|
llm=llm,
|
||||||
|
prompts=prompts,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_argument_parser():
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the performance with or "
|
||||||
|
"without automatic prefix caching."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--document-length",
|
||||||
|
type=int,
|
||||||
|
# Roughly the number of tokens for a system paper,
|
||||||
|
# excluding images
|
||||||
|
default=20000,
|
||||||
|
help="Range of input lengths for sampling prompts, "
|
||||||
|
'specified as "min:max" (e.g., "128:256").',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-documents",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Range of input lengths for sampling prompts, "
|
||||||
|
'specified as "min:max" (e.g., "128:256").',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--output-len", type=int, default=10)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat-count",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Number of times to repeat each prompt",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat-mode",
|
||||||
|
type=str,
|
||||||
|
default="random",
|
||||||
|
help="The mode to repeat prompts. The supported "
|
||||||
|
'modes are "random", "tile", and "interleave". '
|
||||||
|
"See repeat_prompts() in the source code for details.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--shuffle-seed",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help='Random seed when the repeat mode is "random"',
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = create_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
215
benchmarks/benchmark_ngram_proposer.py
Normal file
215
benchmarks/benchmark_ngram_proposer.py
Normal file
@@ -0,0 +1,215 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import gc
|
||||||
|
import time
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from benchmark_utils import TimeCollector
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
from vllm.config import (
|
||||||
|
CacheConfig,
|
||||||
|
DeviceConfig,
|
||||||
|
LoadConfig,
|
||||||
|
ModelConfig,
|
||||||
|
ParallelConfig,
|
||||||
|
SchedulerConfig,
|
||||||
|
SpeculativeConfig,
|
||||||
|
VllmConfig,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
|
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||||
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_propose(args):
|
||||||
|
rows = []
|
||||||
|
for max_ngram in args.max_ngram:
|
||||||
|
collector = TimeCollector(TimeCollector.US)
|
||||||
|
|
||||||
|
model_config = ModelConfig(
|
||||||
|
model="facebook/opt-125m",
|
||||||
|
max_model_len=args.num_token + args.num_spec_token,
|
||||||
|
tokenizer="facebook/opt-125m",
|
||||||
|
tokenizer_mode="auto",
|
||||||
|
dtype="auto",
|
||||||
|
seed=0,
|
||||||
|
trust_remote_code=False,
|
||||||
|
)
|
||||||
|
proposer = NgramProposer(
|
||||||
|
vllm_config=VllmConfig(
|
||||||
|
model_config=model_config,
|
||||||
|
speculative_config=SpeculativeConfig(
|
||||||
|
prompt_lookup_min=args.min_ngram,
|
||||||
|
prompt_lookup_max=max_ngram,
|
||||||
|
num_speculative_tokens=args.num_spec_token,
|
||||||
|
method="ngram",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
proposer.propose(np.random.randint(0, 20, (args.num_token,)))
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
for _ in range(args.num_iteration):
|
||||||
|
tokens = np.random.randint(0, 20, (args.num_req, args.num_token))
|
||||||
|
with collector:
|
||||||
|
for i in range(args.num_req):
|
||||||
|
proposer.propose(tokens[i, :])
|
||||||
|
rows.append(
|
||||||
|
[args.num_req, args.num_token, args.min_ngram, max_ngram]
|
||||||
|
+ collector.dump_avg_max()
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
tabulate(
|
||||||
|
rows,
|
||||||
|
headers=[
|
||||||
|
"# Request",
|
||||||
|
"# Token",
|
||||||
|
"Min Ngram",
|
||||||
|
"Max Ngram",
|
||||||
|
"Avg (us)",
|
||||||
|
"Max (us)",
|
||||||
|
],
|
||||||
|
tablefmt="grid",
|
||||||
|
floatfmt=".3f",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_batched_propose(args):
|
||||||
|
NUM_SPECULATIVE_TOKENS_NGRAM = 10
|
||||||
|
PROMPT_LOOKUP_MIN = 5
|
||||||
|
PROMPT_LOOKUP_MAX = 15
|
||||||
|
MAX_MODEL_LEN = int(1e7)
|
||||||
|
DEVICE = current_platform.device_type
|
||||||
|
|
||||||
|
model_config = ModelConfig(model="facebook/opt-125m", runner="generate")
|
||||||
|
|
||||||
|
speculative_config = SpeculativeConfig(
|
||||||
|
target_model_config=model_config,
|
||||||
|
target_parallel_config=ParallelConfig(),
|
||||||
|
method="ngram",
|
||||||
|
num_speculative_tokens=NUM_SPECULATIVE_TOKENS_NGRAM,
|
||||||
|
prompt_lookup_max=PROMPT_LOOKUP_MAX,
|
||||||
|
prompt_lookup_min=PROMPT_LOOKUP_MIN,
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
model_config=model_config,
|
||||||
|
cache_config=CacheConfig(),
|
||||||
|
speculative_config=speculative_config,
|
||||||
|
device_config=DeviceConfig(device=current_platform.device_type),
|
||||||
|
parallel_config=ParallelConfig(),
|
||||||
|
load_config=LoadConfig(),
|
||||||
|
scheduler_config=SchedulerConfig(
|
||||||
|
max_model_len=model_config.max_model_len,
|
||||||
|
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group
|
||||||
|
mock_pp_group = mock.MagicMock()
|
||||||
|
mock_pp_group.world_size = 1
|
||||||
|
with mock.patch(
|
||||||
|
"vllm.v1.worker.gpu_model_runner.get_pp_group", return_value=mock_pp_group
|
||||||
|
):
|
||||||
|
runner = GPUModelRunner(vllm_config, DEVICE)
|
||||||
|
|
||||||
|
# hack max model len
|
||||||
|
runner.max_model_len = MAX_MODEL_LEN
|
||||||
|
runner.drafter.max_model_len = MAX_MODEL_LEN
|
||||||
|
|
||||||
|
dummy_input_batch = InputBatch(
|
||||||
|
max_num_reqs=args.num_req,
|
||||||
|
max_model_len=MAX_MODEL_LEN,
|
||||||
|
max_num_batched_tokens=args.num_req * args.num_token,
|
||||||
|
device=DEVICE,
|
||||||
|
pin_memory=False,
|
||||||
|
vocab_size=256000,
|
||||||
|
block_sizes=[16],
|
||||||
|
)
|
||||||
|
dummy_input_batch._req_ids = list(str(id) for id in range(args.num_req))
|
||||||
|
dummy_input_batch.spec_decode_unsupported_reqs = ()
|
||||||
|
dummy_input_batch.num_tokens_no_spec = [args.num_token] * args.num_req
|
||||||
|
dummy_input_batch.token_ids_cpu = np.random.randint(
|
||||||
|
0, 20, (args.num_req, args.num_token)
|
||||||
|
)
|
||||||
|
|
||||||
|
runner.input_batch = dummy_input_batch
|
||||||
|
|
||||||
|
sampled_token_ids = [[0]] * args.num_req
|
||||||
|
|
||||||
|
print("Starting benchmark")
|
||||||
|
# first run is warmup so ignore it
|
||||||
|
for _ in range(args.num_iteration):
|
||||||
|
start = time.time()
|
||||||
|
runner.drafter.propose(
|
||||||
|
sampled_token_ids,
|
||||||
|
dummy_input_batch.req_ids,
|
||||||
|
dummy_input_batch.num_tokens_no_spec,
|
||||||
|
dummy_input_batch.token_ids_cpu,
|
||||||
|
dummy_input_batch.spec_decode_unsupported_reqs,
|
||||||
|
)
|
||||||
|
end = time.time()
|
||||||
|
print(f"Iteration time (s): {end - start}")
|
||||||
|
|
||||||
|
|
||||||
|
def invoke_main() -> None:
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the performance of N-gram speculative decode drafting"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batched", action="store_true", help="consider time to prepare batch"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-iteration",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of iterations to run to stabilize final data readings",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-req", type=int, default=128, help="Number of requests in the batch"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-token", type=int, default=1500, help="Number of tokens for each request"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-ngram",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Minimum n-gram to match",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-ngram",
|
||||||
|
type=int,
|
||||||
|
nargs="*",
|
||||||
|
default=[5, 7, 10, 15, 20],
|
||||||
|
help="Maximum n-gram to match",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-spec-token",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
help="Number of speculative tokens to generate",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not args.batched:
|
||||||
|
benchmark_propose(args)
|
||||||
|
else:
|
||||||
|
benchmark_batched_propose(args)
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
# Example command lines:
|
||||||
|
# time python3 benchmarks/benchmark_ngram_proposer.py
|
||||||
|
# time python3 benchmarks/benchmark_ngram_proposer.py --batched --num-iteration 4 --num-token 1000000 --num-req 128
|
||||||
|
""" # noqa: E501
|
||||||
|
if __name__ == "__main__":
|
||||||
|
invoke_main() # pragma: no cover
|
||||||
110
benchmarks/benchmark_prefix_block_hash.py
Normal file
110
benchmarks/benchmark_prefix_block_hash.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
Simple benchmark to compare prefix-cache block hashing algorithms.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
python benchmark_prefix_block_hash.py --num-blocks 20000 --block-size 32
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import random
|
||||||
|
import statistics
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Iterable, Sequence
|
||||||
|
|
||||||
|
from vllm.utils.hashing import get_hash_fn_by_name
|
||||||
|
from vllm.v1.core.kv_cache_utils import BlockHash, hash_block_tokens, init_none_hash
|
||||||
|
|
||||||
|
SUPPORTED_ALGOS = ("sha256", "sha256_cbor", "xxhash", "xxhash_cbor")
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_blocks(
|
||||||
|
num_blocks: int, block_size: int, vocab_size: int, seed: int
|
||||||
|
) -> list[list[int]]:
|
||||||
|
rng = random.Random(seed)
|
||||||
|
return [
|
||||||
|
[rng.randrange(vocab_size) for _ in range(block_size)]
|
||||||
|
for _ in range(num_blocks)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_all_blocks(
|
||||||
|
hash_fn: Callable[[object], bytes],
|
||||||
|
blocks: Iterable[Sequence[int]],
|
||||||
|
) -> float:
|
||||||
|
parent_hash: BlockHash | None = None
|
||||||
|
start = time.perf_counter()
|
||||||
|
for block in blocks:
|
||||||
|
parent_hash = hash_block_tokens(hash_fn, parent_hash, block, extra_keys=None)
|
||||||
|
end = time.perf_counter()
|
||||||
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
|
def _benchmark(
|
||||||
|
hash_algo: str,
|
||||||
|
blocks: list[list[int]],
|
||||||
|
trials: int,
|
||||||
|
) -> tuple[float, float, float] | None:
|
||||||
|
try:
|
||||||
|
hash_fn = get_hash_fn_by_name(hash_algo)
|
||||||
|
init_none_hash(hash_fn)
|
||||||
|
timings = [_hash_all_blocks(hash_fn, blocks) for _ in range(trials)]
|
||||||
|
except ModuleNotFoundError as exc:
|
||||||
|
print(f"Skipping {hash_algo}: {exc}", file=sys.stderr)
|
||||||
|
return None
|
||||||
|
|
||||||
|
avg = statistics.mean(timings)
|
||||||
|
best = min(timings)
|
||||||
|
# throughput: tokens / second
|
||||||
|
tokens_hashed = len(blocks) * len(blocks[0])
|
||||||
|
throughput = tokens_hashed / best
|
||||||
|
return avg, best, throughput
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(description=__doc__)
|
||||||
|
parser.add_argument("--num-blocks", type=int, default=10000, help="Block count.")
|
||||||
|
parser.add_argument("--block-size", type=int, default=32, help="Tokens per block.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--vocab-size", type=int, default=32000, help="Token id range [0, vocab_size)."
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0, help="Random seed.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--trials", type=int, default=5, help="Number of timed trials per algorithm."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--algorithms",
|
||||||
|
nargs="+",
|
||||||
|
default=SUPPORTED_ALGOS,
|
||||||
|
choices=SUPPORTED_ALGOS,
|
||||||
|
help="Hash algorithms to benchmark.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
blocks = _generate_blocks(
|
||||||
|
args.num_blocks, args.block_size, args.vocab_size, args.seed
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Benchmarking {len(args.algorithms)} algorithms on "
|
||||||
|
f"{args.num_blocks} blocks (block size={args.block_size})."
|
||||||
|
)
|
||||||
|
|
||||||
|
for algo in args.algorithms:
|
||||||
|
result = _benchmark(algo, blocks, args.trials)
|
||||||
|
if result is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
avg, best, throughput = result
|
||||||
|
print(
|
||||||
|
f"{algo:14s} avg: {avg:.6f}s best: {best:.6f}s "
|
||||||
|
f"throughput: {throughput / 1e6:.2f}M tokens/s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,7 +1,48 @@
|
|||||||
import argparse
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark the efficiency of prefix caching.
|
||||||
|
|
||||||
|
This script allows you to benchmark the performance of
|
||||||
|
a model with and without prefix caching using either fixed prompts
|
||||||
|
or prompts sampled from the ShareGPT dataset.
|
||||||
|
|
||||||
|
Fixed example usage:
|
||||||
|
python benchmark_prefix_caching.py \
|
||||||
|
--model meta-llama/Llama-2-7b-chat-hf \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--num-prompts 1 \
|
||||||
|
--repeat-count 100 \
|
||||||
|
--input-length-range 128:256
|
||||||
|
|
||||||
|
ShareGPT example usage:
|
||||||
|
# This command samples 20 prompts with input lengths
|
||||||
|
# between 128 and 256 tokens from the ShareGPT dataset,
|
||||||
|
# then replicates each prompt 5 times.
|
||||||
|
python benchmark_prefix_caching.py \
|
||||||
|
--model meta-llama/Llama-2-7b-chat-hf \
|
||||||
|
--dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||||
|
--enable-prefix-caching \
|
||||||
|
--num-prompts 20 \
|
||||||
|
--repeat-count 5 \
|
||||||
|
--input-length-range 128:256
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vllm.tokenizers import get_tokenizer
|
||||||
|
except ImportError:
|
||||||
|
from backend_request_func import get_tokenizer
|
||||||
|
|
||||||
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
|
||||||
|
|
||||||
@@ -15,24 +56,157 @@ def test_prefix(llm=None, sampling_params=None, prompts=None):
|
|||||||
print(f"cost time {end_time - start_time}")
|
print(f"cost time {end_time - start_time}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class Request:
|
||||||
|
prompt: str
|
||||||
|
prompt_len: int
|
||||||
|
output_len: int
|
||||||
|
|
||||||
|
|
||||||
|
def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]:
|
||||||
|
vocab = tokenizer.get_vocab()
|
||||||
|
all_special_ids = set(tokenizer.all_special_ids)
|
||||||
|
|
||||||
|
# Remove the special tokens.
|
||||||
|
return random.choices(
|
||||||
|
[v for v in vocab.values() if v not in all_special_ids],
|
||||||
|
k=length,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def sample_requests_from_dataset(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
input_length_range: tuple[int, int],
|
||||||
|
fixed_output_len: int | None,
|
||||||
|
) -> list[Request]:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
|
# Load the dataset.
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
dataset = json.load(f)
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
|
# Only keep the first two turns of each conversation.
|
||||||
|
dataset = [
|
||||||
|
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||||
|
for data in dataset
|
||||||
|
]
|
||||||
|
|
||||||
|
# Shuffle the dataset.
|
||||||
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
min_len, max_len = input_length_range
|
||||||
|
assert min_len >= 0 and max_len >= min_len, "input_length_range too small"
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_requests: list[Request] = []
|
||||||
|
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
if len(filtered_requests) == num_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt_token_ids = tokenizer(dataset[i][0]).input_ids
|
||||||
|
prompt = tokenizer.decode(prompt_token_ids)
|
||||||
|
completion = dataset[i][1]
|
||||||
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
output_len = (
|
||||||
|
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
||||||
|
)
|
||||||
|
if min_len <= prompt_len <= max_len:
|
||||||
|
filtered_requests.append(Request(prompt, prompt_len, output_len))
|
||||||
|
|
||||||
|
return filtered_requests
|
||||||
|
|
||||||
|
|
||||||
|
def sample_requests_from_random(
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
input_length_range: tuple[int, int],
|
||||||
|
fixed_output_len: int | None,
|
||||||
|
prefix_len: int,
|
||||||
|
) -> list[Request]:
|
||||||
|
requests = []
|
||||||
|
prefix_token_ids = sample_tokens(tokenizer, prefix_len)
|
||||||
|
min_len, max_len = input_length_range
|
||||||
|
|
||||||
|
for i in range(num_requests):
|
||||||
|
unique_part_token_ids = sample_tokens(
|
||||||
|
tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len)
|
||||||
|
)
|
||||||
|
prompt_token_ids = prefix_token_ids + unique_part_token_ids
|
||||||
|
prompt = tokenizer.decode(prompt_token_ids)
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
assert min_len <= prompt_len <= max_len, (
|
||||||
|
f"prompt_len {prompt_len} out of range {min_len}:{max_len}"
|
||||||
|
)
|
||||||
|
requests.append(Request(prompt, prompt_len, fixed_output_len))
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_and_sort_requests(
|
||||||
|
requests: list[Request], repeat_count: int, sort: bool = False
|
||||||
|
) -> list[str]:
|
||||||
|
repeated_requests = requests * repeat_count
|
||||||
|
if sort:
|
||||||
|
repeated_requests.sort(key=lambda x: x[1])
|
||||||
|
else:
|
||||||
|
random.shuffle(repeated_requests)
|
||||||
|
return [req.prompt for req in repeated_requests]
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
llm = LLM(model=args.model,
|
tokenizer = get_tokenizer(args.model, trust_remote_code=True)
|
||||||
tokenizer_mode='auto',
|
input_length_range = tuple(map(int, args.input_length_range.split(":")))
|
||||||
trust_remote_code=True,
|
random.seed(args.seed)
|
||||||
enforce_eager=True,
|
if args.dataset_path is not None:
|
||||||
use_v2_block_manager=args.use_v2_block_manager,
|
if args.prefix_len > 0:
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
raise ValueError(
|
||||||
enable_prefix_caching=args.enable_prefix_caching)
|
"prefix-len is not supported when dataset-path is provided."
|
||||||
|
)
|
||||||
|
print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}")
|
||||||
|
filtered_requests = sample_requests_from_dataset(
|
||||||
|
dataset_path=args.dataset_path,
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_length_range=input_length_range,
|
||||||
|
fixed_output_len=args.output_len,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(f"Start to sample {args.num_prompts} prompts from random")
|
||||||
|
filtered_requests = sample_requests_from_random(
|
||||||
|
num_requests=args.num_prompts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
input_length_range=input_length_range,
|
||||||
|
fixed_output_len=args.output_len,
|
||||||
|
prefix_len=args.prefix_len,
|
||||||
|
)
|
||||||
|
|
||||||
num_prompts = 100
|
# Print some helpful stats of the requests.
|
||||||
prompts = [PROMPT] * num_prompts
|
print(f"Sampled {len(filtered_requests)} requests.")
|
||||||
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
prompt_lens = [req.prompt_len for req in filtered_requests]
|
||||||
|
print(f"Average input length: {sum(prompt_lens) / len(prompt_lens)}")
|
||||||
|
print(f"P50 input length: {sorted(prompt_lens)[len(prompt_lens) // 2]}")
|
||||||
|
print(f"Min Prompt Length: {min(prompt_lens)}")
|
||||||
|
print(f"Max Prompt Length: {max(prompt_lens)}")
|
||||||
|
|
||||||
print("------warm up------")
|
engine_args = EngineArgs.from_cli_args(args)
|
||||||
test_prefix(
|
|
||||||
llm=llm,
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
prompts=prompts,
|
|
||||||
sampling_params=sampling_params,
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=args.output_len,
|
||||||
|
detokenize=not args.disable_detokenize,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Testing filtered requests")
|
||||||
|
prompts = repeat_and_sort_requests(
|
||||||
|
filtered_requests, repeat_count=args.repeat_count, sort=args.sort
|
||||||
)
|
)
|
||||||
|
|
||||||
print("------start generating------")
|
print("------start generating------")
|
||||||
@@ -43,20 +217,61 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_argument_parser():
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the performance with or without "
|
||||||
|
"automatic prefix caching."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-path", type=str, default=None, help="Path to the dataset."
|
||||||
|
)
|
||||||
|
parser.add_argument("--output-len", type=int, default=10)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-prompts",
|
||||||
|
type=int,
|
||||||
|
required=True,
|
||||||
|
help="Number of the prompts sampled from dataset",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeat-count",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Number of times to repeat each prompt",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sort", action="store_true", help="Sort prompts by input length"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-length-range",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Range of input lengths for sampling prompts,"
|
||||||
|
'specified as "min:max" (e.g., "128:256").',
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefix-len",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Specifies the length of a common prefix to be "
|
||||||
|
"added to the input prompt. The input-length-range will "
|
||||||
|
"subtract this length when filtering prompts. Only used "
|
||||||
|
"when dataset-path is not provided.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-detokenize",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Do not detokenize responses (i.e. do not include "
|
||||||
|
"detokenization time in the latency measurement)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = create_argument_parser()
|
||||||
description='Benchmark the performance with or without automatic '
|
|
||||||
'prefix caching.')
|
|
||||||
parser.add_argument('--model',
|
|
||||||
type=str,
|
|
||||||
default='baichuan-inc/Baichuan2-13B-Chat')
|
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
|
||||||
parser.add_argument('--output-len', type=int, default=10)
|
|
||||||
parser.add_argument('--enable-prefix-caching',
|
|
||||||
action='store_true',
|
|
||||||
help='enable prefix caching')
|
|
||||||
parser.add_argument('--use-v2-block-manager',
|
|
||||||
action='store_true',
|
|
||||||
help='Use BlockSpaceMangerV2')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
221
benchmarks/benchmark_prioritization.py
Normal file
221
benchmarks/benchmark_prioritization.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""Benchmark offline prioritization."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from vllm.engine.arg_utils import EngineArgs
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
|
||||||
|
# Select a equi-probable random priority
|
||||||
|
def get_random_flag():
|
||||||
|
return 0 if random.random() < 0.5 else 1
|
||||||
|
|
||||||
|
|
||||||
|
def sample_requests(
|
||||||
|
dataset_path: str,
|
||||||
|
num_requests: int,
|
||||||
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
|
fixed_output_len: int | None,
|
||||||
|
) -> list[tuple[str, int, int, int]]:
|
||||||
|
if fixed_output_len is not None and fixed_output_len < 4:
|
||||||
|
raise ValueError("output_len too small")
|
||||||
|
|
||||||
|
# Load the dataset.
|
||||||
|
with open(dataset_path) as f:
|
||||||
|
dataset = json.load(f)
|
||||||
|
# Filter out the conversations with less than 2 turns.
|
||||||
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
|
# Only keep the first two turns of each conversation.
|
||||||
|
dataset = [
|
||||||
|
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||||
|
for data in dataset
|
||||||
|
]
|
||||||
|
|
||||||
|
# Shuffle the dataset.
|
||||||
|
random.shuffle(dataset)
|
||||||
|
|
||||||
|
# Filter out sequences that are too long or too short
|
||||||
|
filtered_dataset: list[tuple[str, int, int]] = []
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
if len(filtered_dataset) == num_requests:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Tokenize the prompts and completions.
|
||||||
|
prompt = dataset[i][0]
|
||||||
|
prompt_token_ids = tokenizer(prompt).input_ids
|
||||||
|
completion = dataset[i][1]
|
||||||
|
completion_token_ids = tokenizer(completion).input_ids
|
||||||
|
prompt_len = len(prompt_token_ids)
|
||||||
|
output_len = (
|
||||||
|
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
||||||
|
)
|
||||||
|
if prompt_len < 4 or output_len < 4:
|
||||||
|
# Prune too short sequences.
|
||||||
|
continue
|
||||||
|
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||||
|
# Prune too long sequences.
|
||||||
|
continue
|
||||||
|
|
||||||
|
priority = get_random_flag()
|
||||||
|
|
||||||
|
filtered_dataset.append((prompt, prompt_len, output_len, priority))
|
||||||
|
|
||||||
|
return filtered_dataset
|
||||||
|
|
||||||
|
|
||||||
|
def run_vllm(
|
||||||
|
requests: list[tuple[str, int, int]],
|
||||||
|
n: int,
|
||||||
|
engine_args: EngineArgs,
|
||||||
|
disable_detokenize: bool = False,
|
||||||
|
) -> float:
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
llm = LLM(**dataclasses.asdict(engine_args))
|
||||||
|
|
||||||
|
assert all(
|
||||||
|
llm.llm_engine.model_config.max_model_len >= (request[1] + request[2])
|
||||||
|
for request in requests
|
||||||
|
), (
|
||||||
|
"Please ensure that max_model_len is greater than the sum of"
|
||||||
|
" input_len and output_len for all requests."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add the requests to the engine.
|
||||||
|
prompts = []
|
||||||
|
sampling_params = []
|
||||||
|
priority = []
|
||||||
|
for prompt, _, output_len, _priority in requests:
|
||||||
|
prompts.append(prompt)
|
||||||
|
priority.append(_priority)
|
||||||
|
sampling_params.append(
|
||||||
|
SamplingParams(
|
||||||
|
n=n,
|
||||||
|
temperature=1.0,
|
||||||
|
top_p=1.0,
|
||||||
|
ignore_eos=True,
|
||||||
|
max_tokens=output_len,
|
||||||
|
detokenize=not disable_detokenize,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
start = time.perf_counter()
|
||||||
|
llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
|
||||||
|
end = time.perf_counter()
|
||||||
|
return end - start
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
|
||||||
|
print(args)
|
||||||
|
random.seed(args.seed)
|
||||||
|
|
||||||
|
# Sample the requests.
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
args.tokenizer, trust_remote_code=args.trust_remote_code
|
||||||
|
)
|
||||||
|
if args.dataset is None:
|
||||||
|
# Synthesize a prompt with the given input length.
|
||||||
|
prompt = "hi" * (args.input_len - 1)
|
||||||
|
requests = [
|
||||||
|
(prompt, args.input_len, args.output_len, get_random_flag())
|
||||||
|
for _ in range(args.num_prompts)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
requests = sample_requests(
|
||||||
|
args.dataset, args.num_prompts, tokenizer, args.output_len
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.backend == "vllm":
|
||||||
|
elapsed_time = run_vllm(
|
||||||
|
requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
|
total_num_tokens = sum(
|
||||||
|
prompt_len + output_len for _, prompt_len, output_len, priority in requests
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
|
f"{total_num_tokens / elapsed_time:.2f} tokens/s"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Output JSON results if specified
|
||||||
|
if args.output_json:
|
||||||
|
results = {
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"num_requests": len(requests),
|
||||||
|
"total_num_tokens": total_num_tokens,
|
||||||
|
"requests_per_second": len(requests) / elapsed_time,
|
||||||
|
"tokens_per_second": total_num_tokens / elapsed_time,
|
||||||
|
}
|
||||||
|
with open(args.output_json, "w") as f:
|
||||||
|
json.dump(results, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def create_argument_parser():
|
||||||
|
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset", type=str, default=None, help="Path to the dataset."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Input prompt length for each request",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Output length for each request. Overrides the "
|
||||||
|
"output length from the dataset.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--n", type=int, default=1, help="Number of generated sequences per prompt."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-prompts", type=int, default=200, help="Number of prompts to process."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-json",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to save the throughput results in JSON format.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-detokenize",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Do not detokenize responses (i.e. do not include "
|
||||||
|
"detokenization time in the latency measurement)"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
parser = EngineArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = create_argument_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.tokenizer is None:
|
||||||
|
args.tokenizer = args.model
|
||||||
|
if args.dataset is None:
|
||||||
|
assert args.input_len is not None
|
||||||
|
assert args.output_len is not None
|
||||||
|
else:
|
||||||
|
assert args.input_len is None
|
||||||
|
|
||||||
|
main(args)
|
||||||
@@ -1,596 +1,17 @@
|
|||||||
"""Benchmark online serving throughput.
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
On the server side, run one of the following commands:
|
import sys
|
||||||
vLLM OpenAI API server
|
|
||||||
python -m vllm.entrypoints.openai.api_server \
|
|
||||||
--model <your_model> --swap-space 16 \
|
|
||||||
--disable-log-requests
|
|
||||||
|
|
||||||
(TGI backend)
|
|
||||||
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
|
|
||||||
|
|
||||||
On the client side, run:
|
|
||||||
python benchmarks/benchmark_serving.py \
|
|
||||||
--backend <backend> \
|
|
||||||
--model <your_model> \
|
|
||||||
--dataset-name sharegpt \
|
|
||||||
--dataset-path <path to dataset> \
|
|
||||||
--request-rate <request_rate> \ # By default <request_rate> is inf
|
|
||||||
--num-prompts <num_prompts> # By default <num_prompts> is 1000
|
|
||||||
"""
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import time
|
|
||||||
import warnings
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import AsyncGenerator, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput,
|
|
||||||
RequestFuncOutput)
|
|
||||||
from tqdm.asyncio import tqdm
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
|
||||||
|
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BenchmarkMetrics:
|
|
||||||
completed: int
|
|
||||||
total_input: int
|
|
||||||
total_output: int
|
|
||||||
request_throughput: float
|
|
||||||
input_throughput: float
|
|
||||||
output_throughput: float
|
|
||||||
mean_ttft_ms: float
|
|
||||||
median_ttft_ms: float
|
|
||||||
p99_ttft_ms: float
|
|
||||||
mean_tpot_ms: float
|
|
||||||
median_tpot_ms: float
|
|
||||||
p99_tpot_ms: float
|
|
||||||
|
|
||||||
|
|
||||||
def sample_sharegpt_requests(
|
|
||||||
dataset_path: str,
|
|
||||||
num_requests: int,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
fixed_output_len: Optional[int] = None,
|
|
||||||
) -> List[Tuple[str, int, int]]:
|
|
||||||
if fixed_output_len is not None and fixed_output_len < 4:
|
|
||||||
raise ValueError("output_len too small")
|
|
||||||
|
|
||||||
# Load the dataset.
|
|
||||||
with open(dataset_path) as f:
|
|
||||||
dataset = json.load(f)
|
|
||||||
# Filter out the conversations with less than 2 turns.
|
|
||||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
|
||||||
# Only keep the first two turns of each conversation.
|
|
||||||
dataset = [(data["conversations"][0]["value"],
|
|
||||||
data["conversations"][1]["value"]) for data in dataset]
|
|
||||||
|
|
||||||
# Shuffle the dataset.
|
|
||||||
random.shuffle(dataset)
|
|
||||||
|
|
||||||
# Filter out sequences that are too long or too short
|
|
||||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
if len(filtered_dataset) == num_requests:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
|
||||||
prompt = dataset[i][0]
|
|
||||||
prompt_token_ids = tokenizer(prompt).input_ids
|
|
||||||
completion = dataset[i][1]
|
|
||||||
completion_token_ids = tokenizer(completion).input_ids
|
|
||||||
prompt_len = len(prompt_token_ids)
|
|
||||||
output_len = len(completion_token_ids
|
|
||||||
) if fixed_output_len is None else fixed_output_len
|
|
||||||
if prompt_len < 4 or output_len < 4:
|
|
||||||
# Prune too short sequences.
|
|
||||||
continue
|
|
||||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
|
||||||
# Prune too long sequences.
|
|
||||||
continue
|
|
||||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
|
||||||
|
|
||||||
return filtered_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def sample_sonnet_requests(
|
|
||||||
dataset_path: str,
|
|
||||||
num_requests: int,
|
|
||||||
input_len: int,
|
|
||||||
output_len: int,
|
|
||||||
prefix_len: int,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
) -> List[Tuple[str, str, int, int]]:
|
|
||||||
assert (
|
|
||||||
input_len > prefix_len
|
|
||||||
), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'."
|
|
||||||
|
|
||||||
# Load the dataset.
|
|
||||||
with open(dataset_path) as f:
|
|
||||||
poem_lines = f.readlines()
|
|
||||||
|
|
||||||
# Tokenize the poem lines.
|
|
||||||
poem_token_ids = tokenizer(poem_lines).input_ids
|
|
||||||
average_poem_len = sum(
|
|
||||||
len(token_ids) for token_ids in poem_token_ids) / len(poem_token_ids)
|
|
||||||
|
|
||||||
# Base prefix for all requests.
|
|
||||||
base_prompt = "Pick as many lines as you can from these poem lines:\n"
|
|
||||||
base_message = [{
|
|
||||||
"role": "user",
|
|
||||||
"content": base_prompt,
|
|
||||||
}]
|
|
||||||
base_prompt_formatted = tokenizer.apply_chat_template(
|
|
||||||
base_message, add_generation_prompt=True, tokenize=False)
|
|
||||||
base_prompt_offset = len(tokenizer(base_prompt_formatted).input_ids)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
input_len > base_prompt_offset
|
|
||||||
), f"Please set 'args.sonnet-input-len' higher than {base_prompt_offset}."
|
|
||||||
num_input_lines = round(
|
|
||||||
(input_len - base_prompt_offset) / average_poem_len)
|
|
||||||
|
|
||||||
# First approximately `prefix_len` number of tokens in the
|
|
||||||
# prompt are fixed poem lines.
|
|
||||||
assert (
|
|
||||||
prefix_len > base_prompt_offset
|
|
||||||
), f"Please set 'args.sonnet-prefix-len' higher than {base_prompt_offset}."
|
|
||||||
|
|
||||||
num_prefix_lines = round(
|
|
||||||
(prefix_len - base_prompt_offset) / average_poem_len)
|
|
||||||
prefix_lines = poem_lines[:num_prefix_lines]
|
|
||||||
|
|
||||||
# Sample the rest of lines per request.
|
|
||||||
sampled_requests: List[Tuple[str, int, int]] = []
|
|
||||||
for _ in range(num_requests):
|
|
||||||
sampled_lines = "".join(
|
|
||||||
prefix_lines +
|
|
||||||
random.sample(poem_lines, num_input_lines - num_prefix_lines))
|
|
||||||
|
|
||||||
prompt = f"{base_prompt}{sampled_lines}"
|
|
||||||
message = [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": prompt,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
prompt_formatted = tokenizer.apply_chat_template(
|
|
||||||
message, add_generation_prompt=True, tokenize=False)
|
|
||||||
prompt_len = len(tokenizer(prompt_formatted).input_ids)
|
|
||||||
sampled_requests.append(
|
|
||||||
(prompt, prompt_formatted, prompt_len, output_len))
|
|
||||||
|
|
||||||
return sampled_requests
|
|
||||||
|
|
||||||
|
|
||||||
async def get_request(
|
|
||||||
input_requests: List[Tuple[str, int, int]],
|
|
||||||
request_rate: float,
|
|
||||||
) -> AsyncGenerator[Tuple[str, int, int], None]:
|
|
||||||
input_requests = iter(input_requests)
|
|
||||||
for request in input_requests:
|
|
||||||
yield request
|
|
||||||
|
|
||||||
if request_rate == float("inf"):
|
|
||||||
# If the request rate is infinity, then we don't need to wait.
|
|
||||||
continue
|
|
||||||
# Sample the request interval from the exponential distribution.
|
|
||||||
interval = np.random.exponential(1.0 / request_rate)
|
|
||||||
# The next request will be sent after the interval.
|
|
||||||
await asyncio.sleep(interval)
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_metrics(
|
|
||||||
input_requests: List[Tuple[str, int, int]],
|
|
||||||
outputs: List[RequestFuncOutput],
|
|
||||||
dur_s: float,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
) -> Tuple[BenchmarkMetrics, List[int]]:
|
|
||||||
actual_output_lens = []
|
|
||||||
total_input = 0
|
|
||||||
completed = 0
|
|
||||||
tpots = []
|
|
||||||
ttfts = []
|
|
||||||
for i in range(len(outputs)):
|
|
||||||
if outputs[i].success:
|
|
||||||
output_len = len(tokenizer(outputs[i].generated_text).input_ids)
|
|
||||||
actual_output_lens.append(output_len)
|
|
||||||
total_input += input_requests[i][1]
|
|
||||||
if output_len > 1:
|
|
||||||
tpots.append(
|
|
||||||
(outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
|
||||||
ttfts.append(outputs[i].ttft)
|
|
||||||
completed += 1
|
|
||||||
else:
|
|
||||||
actual_output_lens.append(0)
|
|
||||||
|
|
||||||
metrics = BenchmarkMetrics(
|
|
||||||
completed=completed,
|
|
||||||
total_input=total_input,
|
|
||||||
total_output=sum(actual_output_lens),
|
|
||||||
request_throughput=completed / dur_s,
|
|
||||||
input_throughput=total_input / dur_s,
|
|
||||||
output_throughput=sum(actual_output_lens) / dur_s,
|
|
||||||
mean_ttft_ms=np.mean(ttfts or 0) *
|
|
||||||
1000, # ttfts is empty if streaming is not supported by backend
|
|
||||||
median_ttft_ms=np.median(ttfts or 0) * 1000,
|
|
||||||
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
|
|
||||||
mean_tpot_ms=np.mean(tpots) * 1000,
|
|
||||||
median_tpot_ms=np.median(tpots) * 1000,
|
|
||||||
p99_tpot_ms=np.percentile(tpots, 99) * 1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
return metrics, actual_output_lens
|
|
||||||
|
|
||||||
|
|
||||||
async def benchmark(
|
|
||||||
backend: str,
|
|
||||||
api_url: str,
|
|
||||||
model_id: str,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
input_requests: List[Tuple[str, int, int]],
|
|
||||||
best_of: int,
|
|
||||||
use_beam_search: bool,
|
|
||||||
request_rate: float,
|
|
||||||
disable_tqdm: bool,
|
|
||||||
):
|
|
||||||
if backend in ASYNC_REQUEST_FUNCS:
|
|
||||||
request_func = ASYNC_REQUEST_FUNCS.get(backend)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown backend: {backend}")
|
|
||||||
|
|
||||||
print(f"Traffic request rate: {request_rate}")
|
|
||||||
|
|
||||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
|
||||||
tasks = []
|
|
||||||
async for request in get_request(input_requests, request_rate):
|
|
||||||
prompt, prompt_len, output_len = request
|
|
||||||
request_func_input = RequestFuncInput(
|
|
||||||
model=model_id,
|
|
||||||
prompt=prompt,
|
|
||||||
api_url=api_url,
|
|
||||||
prompt_len=prompt_len,
|
|
||||||
output_len=output_len,
|
|
||||||
best_of=best_of,
|
|
||||||
use_beam_search=use_beam_search,
|
|
||||||
)
|
|
||||||
tasks.append(
|
|
||||||
asyncio.create_task(
|
|
||||||
request_func(request_func_input=request_func_input,
|
|
||||||
pbar=pbar)))
|
|
||||||
outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
if not disable_tqdm:
|
|
||||||
pbar.close()
|
|
||||||
|
|
||||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
|
||||||
|
|
||||||
metrics, actual_output_lens = calculate_metrics(
|
|
||||||
input_requests=input_requests,
|
|
||||||
outputs=outputs,
|
|
||||||
dur_s=benchmark_duration,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='='))
|
|
||||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
|
||||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):",
|
|
||||||
benchmark_duration))
|
|
||||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
|
||||||
print("{:<40} {:<10}".format("Total generated tokens:",
|
|
||||||
metrics.total_output))
|
|
||||||
print("{:<40} {:<10.2f}".format("Request throughput (req/s):",
|
|
||||||
metrics.request_throughput))
|
|
||||||
print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):",
|
|
||||||
metrics.input_throughput))
|
|
||||||
print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):",
|
|
||||||
metrics.output_throughput))
|
|
||||||
print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-'))
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("Median TTFT (ms):",
|
|
||||||
metrics.median_ttft_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
|
|
||||||
print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)',
|
|
||||||
n=50,
|
|
||||||
c='-'))
|
|
||||||
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("Median TPOT (ms):",
|
|
||||||
metrics.median_tpot_ms))
|
|
||||||
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
|
|
||||||
print("=" * 50)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"duration": benchmark_duration,
|
|
||||||
"completed": metrics.completed,
|
|
||||||
"total_input_tokens": metrics.total_input,
|
|
||||||
"total_output_tokens": metrics.total_output,
|
|
||||||
"request_throughput": metrics.request_throughput,
|
|
||||||
"input_throughput": metrics.input_throughput,
|
|
||||||
"output_throughput": metrics.output_throughput,
|
|
||||||
"mean_ttft_ms": metrics.mean_ttft_ms,
|
|
||||||
"median_ttft_ms": metrics.median_ttft_ms,
|
|
||||||
"p99_ttft_ms": metrics.p99_ttft_ms,
|
|
||||||
"mean_tpot_ms": metrics.mean_tpot_ms,
|
|
||||||
"median_tpot_ms": metrics.median_tpot_ms,
|
|
||||||
"p99_tpot_ms": metrics.p99_tpot_ms,
|
|
||||||
"input_lens": [output.prompt_len for output in outputs],
|
|
||||||
"output_lens": actual_output_lens,
|
|
||||||
"ttfts": [output.ttft for output in outputs],
|
|
||||||
"itls": [output.itl for output in outputs],
|
|
||||||
"generated_texts": [output.generated_text for output in outputs],
|
|
||||||
"errors": [output.error for output in outputs],
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
|
||||||
print(args)
|
|
||||||
random.seed(args.seed)
|
|
||||||
np.random.seed(args.seed)
|
|
||||||
|
|
||||||
backend = args.backend
|
|
||||||
model_id = args.model
|
|
||||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
|
||||||
|
|
||||||
if args.base_url is not None:
|
|
||||||
api_url = f"{args.base_url}{args.endpoint}"
|
|
||||||
else:
|
|
||||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
|
||||||
|
|
||||||
tokenizer = get_tokenizer(tokenizer_id,
|
|
||||||
trust_remote_code=args.trust_remote_code)
|
|
||||||
|
|
||||||
if args.dataset is not None:
|
|
||||||
warnings.warn(
|
|
||||||
"The '--dataset' argument will be deprecated in the next "
|
|
||||||
"release. Please use '--dataset-name' and "
|
|
||||||
"'--dataset-path' in the future runs.",
|
|
||||||
stacklevel=2)
|
|
||||||
input_requests = sample_sharegpt_requests(
|
|
||||||
dataset_path=args.dataset,
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
fixed_output_len=args.sharegpt_output_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif args.dataset_name == "sharegpt":
|
|
||||||
input_requests = sample_sharegpt_requests(
|
|
||||||
dataset_path=args.dataset_path,
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
fixed_output_len=args.sharegpt_output_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif args.dataset_name == "sonnet":
|
|
||||||
# Do not format the prompt, pass to message directly
|
|
||||||
if args.backend == "openai-chat":
|
|
||||||
input_requests = sample_sonnet_requests(
|
|
||||||
dataset_path=args.dataset_path,
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
input_len=args.sonnet_input_len,
|
|
||||||
output_len=args.sonnet_output_len,
|
|
||||||
prefix_len=args.sonnet_prefix_len,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
input_requests = [(prompt, prompt_len, output_len)
|
|
||||||
for prompt, prompt_formatted, prompt_len,
|
|
||||||
output_len in input_requests]
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
tokenizer.chat_template or tokenizer.default_chat_template
|
|
||||||
), "Tokenizer/model must have chat template for sonnet dataset."
|
|
||||||
input_requests = sample_sonnet_requests(
|
|
||||||
dataset_path=args.dataset_path,
|
|
||||||
num_requests=args.num_prompts,
|
|
||||||
input_len=args.sonnet_input_len,
|
|
||||||
output_len=args.sonnet_output_len,
|
|
||||||
prefix_len=args.sonnet_prefix_len,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
input_requests = [(prompt_formatted, prompt_len, output_len)
|
|
||||||
for prompt, prompt_formatted, prompt_len,
|
|
||||||
output_len in input_requests]
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
|
||||||
|
|
||||||
benchmark_result = asyncio.run(
|
|
||||||
benchmark(
|
|
||||||
backend=backend,
|
|
||||||
api_url=api_url,
|
|
||||||
model_id=model_id,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
input_requests=input_requests,
|
|
||||||
best_of=args.best_of,
|
|
||||||
use_beam_search=args.use_beam_search,
|
|
||||||
request_rate=args.request_rate,
|
|
||||||
disable_tqdm=args.disable_tqdm,
|
|
||||||
))
|
|
||||||
|
|
||||||
# Save config and results to json
|
|
||||||
if args.save_result:
|
|
||||||
result_json = {}
|
|
||||||
|
|
||||||
# Setup
|
|
||||||
current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
|
|
||||||
result_json["date"] = current_dt
|
|
||||||
result_json["backend"] = backend
|
|
||||||
result_json["model_id"] = model_id
|
|
||||||
result_json["tokenizer_id"] = tokenizer_id
|
|
||||||
result_json["best_of"] = args.best_of
|
|
||||||
result_json["use_beam_search"] = args.use_beam_search
|
|
||||||
result_json["num_prompts"] = args.num_prompts
|
|
||||||
|
|
||||||
# Metadata
|
|
||||||
if args.metadata:
|
|
||||||
for item in args.metadata:
|
|
||||||
if "=" in item:
|
|
||||||
kvstring = item.split("=")
|
|
||||||
result_json[kvstring[0].strip()] = kvstring[1].strip()
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"Invalid metadata format. Please use KEY=VALUE format."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Traffic
|
|
||||||
result_json["request_rate"] = (
|
|
||||||
args.request_rate if args.request_rate < float("inf") else "inf")
|
|
||||||
|
|
||||||
# Merge with benchmark result
|
|
||||||
result_json = {**result_json, **benchmark_result}
|
|
||||||
|
|
||||||
# Save to file
|
|
||||||
base_model_id = model_id.split("/")[-1]
|
|
||||||
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" #noqa
|
|
||||||
if args.result_dir:
|
|
||||||
file_name = os.path.join(args.result_dir, file_name)
|
|
||||||
with open(file_name, "w") as outfile:
|
|
||||||
json.dump(result_json, outfile)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||||
description="Benchmark the online serving throughput.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--backend",
|
|
||||||
type=str,
|
|
||||||
default="vllm",
|
|
||||||
choices=list(ASYNC_REQUEST_FUNCS.keys()),
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--base-url",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Server or API base url if not using http host and port.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
|
||||||
parser.add_argument(
|
|
||||||
"--endpoint",
|
|
||||||
type=str,
|
|
||||||
default="/v1/completions",
|
|
||||||
help="API endpoint.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to the ShareGPT dataset, will be deprecated in the "
|
|
||||||
"next release.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--dataset-name",
|
|
||||||
type=str,
|
|
||||||
default="sharegpt",
|
|
||||||
choices=["sharegpt", "sonnet"],
|
|
||||||
help="Name of the dataset to benchmark on.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--dataset-path",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to the dataset.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
type=str,
|
|
||||||
required=True,
|
|
||||||
help="Name of the model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokenizer",
|
|
||||||
type=str,
|
|
||||||
help=
|
|
||||||
"Name or path of the tokenizer, if not using the default tokenizer.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--best-of",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Generates `best_of` sequences per prompt and "
|
|
||||||
"returns the best one.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-prompts",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Number of prompts to process.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--sharegpt-output-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Output length for each request. Overrides the output length "
|
|
||||||
"from the ShareGPT dataset.")
|
|
||||||
parser.add_argument(
|
|
||||||
"--sonnet-input-len",
|
|
||||||
type=int,
|
|
||||||
default=550,
|
|
||||||
help=
|
|
||||||
"Number of input tokens per request, used only for sonnet dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--sonnet-output-len",
|
|
||||||
type=int,
|
|
||||||
default=150,
|
|
||||||
help=
|
|
||||||
"Number of output tokens per request, used only for sonnet dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--sonnet-prefix-len",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help=
|
|
||||||
"Number of prefix tokens per request, used only for sonnet dataset.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--request-rate",
|
|
||||||
type=float,
|
|
||||||
default=float("inf"),
|
|
||||||
help="Number of requests per second. If this is inf, "
|
|
||||||
"then all the requests are sent at time 0. "
|
|
||||||
"Otherwise, we use Poisson process to synthesize "
|
|
||||||
"the request arrival times.",
|
|
||||||
)
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
parser.add_argument(
|
|
||||||
"--trust-remote-code",
|
|
||||||
action="store_true",
|
|
||||||
help="Trust remote code from huggingface",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--disable-tqdm",
|
|
||||||
action="store_true",
|
|
||||||
help="Specify to disable tqdm progress bar.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--save-result",
|
|
||||||
action="store_true",
|
|
||||||
help="Specify to save benchmark results to a json file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--metadata",
|
|
||||||
metavar="KEY=VALUE",
|
|
||||||
nargs="*",
|
|
||||||
help="Key-value pairs (e.g, --metadata version=0.3.3 tp=1) "
|
|
||||||
"for metadata of this run to be saved in the result JSON file "
|
|
||||||
"for record keeping purposes.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--result-dir",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Specify directory to save benchmark json results."
|
|
||||||
"If not specified, results are saved in the current directory.",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
Please use the following command instead:
|
||||||
main(args)
|
vllm bench serve
|
||||||
|
|
||||||
|
For help with the new command, run:
|
||||||
|
vllm bench serve --help
|
||||||
|
|
||||||
|
Alternatively, you can run the new command directly with:
|
||||||
|
python -m vllm.entrypoints.cli.main bench serve --help
|
||||||
|
""")
|
||||||
|
sys.exit(1)
|
||||||
|
|||||||
1040
benchmarks/benchmark_serving_structured_output.py
Normal file
1040
benchmarks/benchmark_serving_structured_output.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,387 +1,17 @@
|
|||||||
"""Benchmark offline inference throughput."""
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
import argparse
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import json
|
import sys
|
||||||
import random
|
|
||||||
import time
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from tqdm import tqdm
|
|
||||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
|
||||||
PreTrainedTokenizerBase)
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
|
||||||
|
|
||||||
|
|
||||||
def sample_requests(
|
|
||||||
dataset_path: str,
|
|
||||||
num_requests: int,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
fixed_output_len: Optional[int],
|
|
||||||
) -> List[Tuple[str, int, int]]:
|
|
||||||
if fixed_output_len is not None and fixed_output_len < 4:
|
|
||||||
raise ValueError("output_len too small")
|
|
||||||
|
|
||||||
# Load the dataset.
|
|
||||||
with open(dataset_path) as f:
|
|
||||||
dataset = json.load(f)
|
|
||||||
# Filter out the conversations with less than 2 turns.
|
|
||||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
|
||||||
# Only keep the first two turns of each conversation.
|
|
||||||
dataset = [(data["conversations"][0]["value"],
|
|
||||||
data["conversations"][1]["value"]) for data in dataset]
|
|
||||||
|
|
||||||
# Shuffle the dataset.
|
|
||||||
random.shuffle(dataset)
|
|
||||||
|
|
||||||
# Filter out sequences that are too long or too short
|
|
||||||
filtered_dataset: List[Tuple[str, int, int]] = []
|
|
||||||
for i in range(len(dataset)):
|
|
||||||
if len(filtered_dataset) == num_requests:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
|
||||||
prompt = dataset[i][0]
|
|
||||||
prompt_token_ids = tokenizer(prompt).input_ids
|
|
||||||
completion = dataset[i][1]
|
|
||||||
completion_token_ids = tokenizer(completion).input_ids
|
|
||||||
prompt_len = len(prompt_token_ids)
|
|
||||||
output_len = len(completion_token_ids
|
|
||||||
) if fixed_output_len is None else fixed_output_len
|
|
||||||
if prompt_len < 4 or output_len < 4:
|
|
||||||
# Prune too short sequences.
|
|
||||||
continue
|
|
||||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
|
||||||
# Prune too long sequences.
|
|
||||||
continue
|
|
||||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
|
||||||
|
|
||||||
return filtered_dataset
|
|
||||||
|
|
||||||
|
|
||||||
def run_vllm(
|
|
||||||
requests: List[Tuple[str, int, int]],
|
|
||||||
model: str,
|
|
||||||
tokenizer: str,
|
|
||||||
quantization: Optional[str],
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
seed: int,
|
|
||||||
n: int,
|
|
||||||
use_beam_search: bool,
|
|
||||||
trust_remote_code: bool,
|
|
||||||
dtype: str,
|
|
||||||
max_model_len: Optional[int],
|
|
||||||
enforce_eager: bool,
|
|
||||||
kv_cache_dtype: str,
|
|
||||||
quantization_param_path: Optional[str],
|
|
||||||
device: str,
|
|
||||||
enable_prefix_caching: bool,
|
|
||||||
enable_chunked_prefill: bool,
|
|
||||||
max_num_batched_tokens: int,
|
|
||||||
gpu_memory_utilization: float = 0.9,
|
|
||||||
download_dir: Optional[str] = None,
|
|
||||||
) -> float:
|
|
||||||
from vllm import LLM, SamplingParams
|
|
||||||
llm = LLM(
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
quantization=quantization,
|
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
|
||||||
seed=seed,
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
dtype=dtype,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
gpu_memory_utilization=gpu_memory_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
kv_cache_dtype=kv_cache_dtype,
|
|
||||||
quantization_param_path=quantization_param_path,
|
|
||||||
device=device,
|
|
||||||
enable_prefix_caching=enable_prefix_caching,
|
|
||||||
download_dir=download_dir,
|
|
||||||
enable_chunked_prefill=enable_chunked_prefill,
|
|
||||||
max_num_batched_tokens=max_num_batched_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add the requests to the engine.
|
|
||||||
prompts = []
|
|
||||||
sampling_params = []
|
|
||||||
for prompt, _, output_len in requests:
|
|
||||||
prompts.append(prompt)
|
|
||||||
sampling_params.append(
|
|
||||||
SamplingParams(
|
|
||||||
n=n,
|
|
||||||
temperature=0.0 if use_beam_search else 1.0,
|
|
||||||
top_p=1.0,
|
|
||||||
use_beam_search=use_beam_search,
|
|
||||||
ignore_eos=True,
|
|
||||||
max_tokens=output_len,
|
|
||||||
))
|
|
||||||
|
|
||||||
start = time.perf_counter()
|
|
||||||
llm.generate(prompts, sampling_params, use_tqdm=True)
|
|
||||||
end = time.perf_counter()
|
|
||||||
return end - start
|
|
||||||
|
|
||||||
|
|
||||||
def run_hf(
|
|
||||||
requests: List[Tuple[str, int, int]],
|
|
||||||
model: str,
|
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
|
||||||
n: int,
|
|
||||||
use_beam_search: bool,
|
|
||||||
max_batch_size: int,
|
|
||||||
trust_remote_code: bool,
|
|
||||||
) -> float:
|
|
||||||
assert not use_beam_search
|
|
||||||
llm = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
|
||||||
if llm.config.model_type == "llama":
|
|
||||||
# To enable padding in the HF backend.
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
|
||||||
llm = llm.cuda()
|
|
||||||
|
|
||||||
pbar = tqdm(total=len(requests))
|
|
||||||
start = time.perf_counter()
|
|
||||||
batch: List[str] = []
|
|
||||||
max_prompt_len = 0
|
|
||||||
max_output_len = 0
|
|
||||||
for i in range(len(requests)):
|
|
||||||
prompt, prompt_len, output_len = requests[i]
|
|
||||||
# Add the prompt to the batch.
|
|
||||||
batch.append(prompt)
|
|
||||||
max_prompt_len = max(max_prompt_len, prompt_len)
|
|
||||||
max_output_len = max(max_output_len, output_len)
|
|
||||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
|
||||||
# Check if we can add more requests to the batch.
|
|
||||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
|
||||||
if (max(max_prompt_len, next_prompt_len) +
|
|
||||||
max(max_output_len, next_output_len)) <= 2048:
|
|
||||||
# We can add more requests to the batch.
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Generate the sequences.
|
|
||||||
input_ids = tokenizer(batch, return_tensors="pt",
|
|
||||||
padding=True).input_ids
|
|
||||||
llm_outputs = llm.generate(
|
|
||||||
input_ids=input_ids.cuda(),
|
|
||||||
do_sample=not use_beam_search,
|
|
||||||
num_return_sequences=n,
|
|
||||||
temperature=1.0,
|
|
||||||
top_p=1.0,
|
|
||||||
use_cache=True,
|
|
||||||
max_new_tokens=max_output_len,
|
|
||||||
)
|
|
||||||
# Include the decoding time.
|
|
||||||
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
|
|
||||||
pbar.update(len(batch))
|
|
||||||
|
|
||||||
# Clear the batch.
|
|
||||||
batch = []
|
|
||||||
max_prompt_len = 0
|
|
||||||
max_output_len = 0
|
|
||||||
end = time.perf_counter()
|
|
||||||
return end - start
|
|
||||||
|
|
||||||
|
|
||||||
def run_mii(
|
|
||||||
requests: List[Tuple[str, int, int]],
|
|
||||||
model: str,
|
|
||||||
tensor_parallel_size: int,
|
|
||||||
output_len: int,
|
|
||||||
) -> float:
|
|
||||||
from mii import client, serve
|
|
||||||
llm = serve(model, tensor_parallel=tensor_parallel_size)
|
|
||||||
prompts = [prompt for prompt, _, _ in requests]
|
|
||||||
|
|
||||||
start = time.perf_counter()
|
|
||||||
llm.generate(prompts, max_new_tokens=output_len)
|
|
||||||
end = time.perf_counter()
|
|
||||||
client = client(model)
|
|
||||||
client.terminate_server()
|
|
||||||
return end - start
|
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace):
|
|
||||||
print(args)
|
|
||||||
random.seed(args.seed)
|
|
||||||
|
|
||||||
# Sample the requests.
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
|
||||||
if args.dataset is None:
|
|
||||||
# Synthesize a prompt with the given input length.
|
|
||||||
prompt = "hi" * (args.input_len - 1)
|
|
||||||
requests = [(prompt, args.input_len, args.output_len)
|
|
||||||
for _ in range(args.num_prompts)]
|
|
||||||
else:
|
|
||||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
|
|
||||||
args.output_len)
|
|
||||||
|
|
||||||
if args.backend == "vllm":
|
|
||||||
elapsed_time = run_vllm(
|
|
||||||
requests, args.model, args.tokenizer, args.quantization,
|
|
||||||
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
|
|
||||||
args.trust_remote_code, args.dtype, args.max_model_len,
|
|
||||||
args.enforce_eager, args.kv_cache_dtype,
|
|
||||||
args.quantization_param_path, args.device,
|
|
||||||
args.enable_prefix_caching, args.enable_chunked_prefill,
|
|
||||||
args.max_num_batched_tokens, args.gpu_memory_utilization,
|
|
||||||
args.download_dir)
|
|
||||||
elif args.backend == "hf":
|
|
||||||
assert args.tensor_parallel_size == 1
|
|
||||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
|
||||||
args.use_beam_search, args.hf_max_batch_size,
|
|
||||||
args.trust_remote_code)
|
|
||||||
elif args.backend == "mii":
|
|
||||||
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
|
|
||||||
args.output_len)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
|
||||||
total_num_tokens = sum(prompt_len + output_len
|
|
||||||
for _, prompt_len, output_len in requests)
|
|
||||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
|
||||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
print("""DEPRECATED: This script has been moved to the vLLM CLI.
|
||||||
parser.add_argument("--backend",
|
|
||||||
type=str,
|
|
||||||
choices=["vllm", "hf", "mii"],
|
|
||||||
default="vllm")
|
|
||||||
parser.add_argument("--dataset",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="Path to the dataset.")
|
|
||||||
parser.add_argument("--input-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Input prompt length for each request")
|
|
||||||
parser.add_argument("--output-len",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Output length for each request. Overrides the "
|
|
||||||
"output length from the dataset.")
|
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
|
||||||
parser.add_argument('--quantization',
|
|
||||||
'-q',
|
|
||||||
choices=[*QUANTIZATION_METHODS, None],
|
|
||||||
default=None)
|
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
|
||||||
parser.add_argument("--n",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of generated sequences per prompt.")
|
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
|
||||||
parser.add_argument("--num-prompts",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Number of prompts to process.")
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
parser.add_argument("--hf-max-batch-size",
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help="Maximum batch size for HF backend.")
|
|
||||||
parser.add_argument('--trust-remote-code',
|
|
||||||
action='store_true',
|
|
||||||
help='trust remote code from huggingface')
|
|
||||||
parser.add_argument(
|
|
||||||
'--max-model-len',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='Maximum length of a sequence (including prompt and output). '
|
|
||||||
'If None, will be derived from the model.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--dtype',
|
|
||||||
type=str,
|
|
||||||
default='auto',
|
|
||||||
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
|
|
||||||
help='data type for model weights and activations. '
|
|
||||||
'The "auto" option will use FP16 precision '
|
|
||||||
'for FP32 and FP16 models, and BF16 precision '
|
|
||||||
'for BF16 models.')
|
|
||||||
parser.add_argument('--gpu-memory-utilization',
|
|
||||||
type=float,
|
|
||||||
default=0.9,
|
|
||||||
help='the fraction of GPU memory to be used for '
|
|
||||||
'the model executor, which can range from 0 to 1.'
|
|
||||||
'If unspecified, will use the default value of 0.9.')
|
|
||||||
parser.add_argument("--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
help="enforce eager execution")
|
|
||||||
parser.add_argument(
|
|
||||||
"--kv-cache-dtype",
|
|
||||||
type=str,
|
|
||||||
choices=["auto", "fp8"],
|
|
||||||
default="auto",
|
|
||||||
help=
|
|
||||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
|
||||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
|
||||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
|
||||||
'common inference criteria.')
|
|
||||||
parser.add_argument(
|
|
||||||
'--quantization-param-path',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='Path to the JSON file containing the KV cache scaling factors. '
|
|
||||||
'This should generally be supplied, when KV cache dtype is FP8. '
|
|
||||||
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
|
|
||||||
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
|
|
||||||
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
|
|
||||||
'instead supported for common inference criteria.')
|
|
||||||
parser.add_argument(
|
|
||||||
"--device",
|
|
||||||
type=str,
|
|
||||||
default="cuda",
|
|
||||||
choices=["cuda", "cpu"],
|
|
||||||
help='device type for vLLM execution, supporting CUDA and CPU.')
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-prefix-caching",
|
|
||||||
action='store_true',
|
|
||||||
help="enable automatic prefix caching for vLLM backend.")
|
|
||||||
parser.add_argument("--enable-chunked-prefill",
|
|
||||||
action='store_true',
|
|
||||||
help="enable chunked prefill for vLLM backend.")
|
|
||||||
parser.add_argument('--max-num-batched-tokens',
|
|
||||||
type=int,
|
|
||||||
default=None,
|
|
||||||
help='maximum number of batched tokens per '
|
|
||||||
'iteration')
|
|
||||||
parser.add_argument('--download-dir',
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help='directory to download and load the weights, '
|
|
||||||
'default to the default cache dir of huggingface')
|
|
||||||
args = parser.parse_args()
|
|
||||||
if args.tokenizer is None:
|
|
||||||
args.tokenizer = args.model
|
|
||||||
if args.dataset is None:
|
|
||||||
assert args.input_len is not None
|
|
||||||
assert args.output_len is not None
|
|
||||||
else:
|
|
||||||
assert args.input_len is None
|
|
||||||
|
|
||||||
if args.backend == "vllm":
|
Please use the following command instead:
|
||||||
if args.hf_max_batch_size is not None:
|
vllm bench throughput
|
||||||
raise ValueError("HF max batch size is only for HF backend.")
|
|
||||||
elif args.backend == "hf":
|
For help with the new command, run:
|
||||||
if args.hf_max_batch_size is None:
|
vllm bench throughput --help
|
||||||
raise ValueError("HF max batch size is required for HF backend.")
|
|
||||||
if args.quantization is not None:
|
Alternatively, you can run the new command directly with:
|
||||||
raise ValueError("Quantization is only for vLLM backend.")
|
python -m vllm.entrypoints.cli.main bench throughput --help
|
||||||
elif args.backend == "mii":
|
""")
|
||||||
if args.dtype != "auto":
|
sys.exit(1)
|
||||||
raise ValueError("dtype must be auto for MII backend.")
|
|
||||||
if args.n != 1:
|
|
||||||
raise ValueError("n must be 1 for MII backend.")
|
|
||||||
if args.use_beam_search:
|
|
||||||
raise ValueError("Beam search is not supported for MII backend.")
|
|
||||||
if args.quantization is not None:
|
|
||||||
raise ValueError("Quantization is only for vLLM backend.")
|
|
||||||
if args.hf_max_batch_size is not None:
|
|
||||||
raise ValueError("HF max batch size is only for HF backend.")
|
|
||||||
if args.tokenizer != args.model:
|
|
||||||
raise ValueError("Tokenizer must be the same as the model for MII "
|
|
||||||
"backend.")
|
|
||||||
main(args)
|
|
||||||
|
|||||||
125
benchmarks/benchmark_utils.py
Normal file
125
benchmarks/benchmark_utils.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from types import TracebackType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_pytorch_benchmark_format(
|
||||||
|
args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any]
|
||||||
|
) -> list:
|
||||||
|
"""
|
||||||
|
Save the benchmark results in the format used by PyTorch OSS benchmark with
|
||||||
|
on metric per record
|
||||||
|
https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
|
||||||
|
"""
|
||||||
|
records = []
|
||||||
|
if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False):
|
||||||
|
return records
|
||||||
|
|
||||||
|
for name, benchmark_values in metrics.items():
|
||||||
|
record = {
|
||||||
|
"benchmark": {
|
||||||
|
"name": "vLLM benchmark",
|
||||||
|
"extra_info": {
|
||||||
|
"args": vars(args),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"name": args.model,
|
||||||
|
},
|
||||||
|
"metric": {
|
||||||
|
"name": name,
|
||||||
|
"benchmark_values": benchmark_values,
|
||||||
|
"extra_info": extra_info,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size")
|
||||||
|
# Save tensor_parallel_size parameter if it's part of the metadata
|
||||||
|
if not tp and "tensor_parallel_size" in extra_info:
|
||||||
|
record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = (
|
||||||
|
extra_info["tensor_parallel_size"]
|
||||||
|
)
|
||||||
|
|
||||||
|
records.append(record)
|
||||||
|
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
class InfEncoder(json.JSONEncoder):
|
||||||
|
def clear_inf(self, o: Any):
|
||||||
|
if isinstance(o, dict):
|
||||||
|
return {k: self.clear_inf(v) for k, v in o.items()}
|
||||||
|
elif isinstance(o, list):
|
||||||
|
return [self.clear_inf(v) for v in o]
|
||||||
|
elif isinstance(o, float) and math.isinf(o):
|
||||||
|
return "inf"
|
||||||
|
return o
|
||||||
|
|
||||||
|
def iterencode(self, o: Any, *args, **kwargs) -> Any:
|
||||||
|
return super().iterencode(self.clear_inf(o), *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def write_to_json(filename: str, records: list) -> None:
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump(
|
||||||
|
records,
|
||||||
|
f,
|
||||||
|
cls=InfEncoder,
|
||||||
|
default=lambda o: f"<{type(o).__name__} object is not JSON serializable>",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Collect time and generate time metrics
|
||||||
|
#
|
||||||
|
# Example Usage:
|
||||||
|
# collector = TimeCollector(TimeCollector.US)
|
||||||
|
# for _ in range(total_iteration):
|
||||||
|
# with collector:
|
||||||
|
# ...
|
||||||
|
# collector.dump_avg_max()
|
||||||
|
class TimeCollector:
|
||||||
|
NS: int = 1
|
||||||
|
US: int = NS * 1000
|
||||||
|
MS: int = US * 1000
|
||||||
|
S: int = MS * 1000
|
||||||
|
|
||||||
|
def __init__(self, scale: int) -> None:
|
||||||
|
self.cnt: int = 0
|
||||||
|
self._sum: int = 0
|
||||||
|
self._max: int | None = None
|
||||||
|
self.scale = scale
|
||||||
|
self.start_time: int = time.monotonic_ns()
|
||||||
|
|
||||||
|
def collect(self, v: int) -> None:
|
||||||
|
self.cnt += 1
|
||||||
|
self._sum += v
|
||||||
|
if self._max is None:
|
||||||
|
self._max = v
|
||||||
|
else:
|
||||||
|
self._max = max(self._max, v)
|
||||||
|
|
||||||
|
def avg(self) -> float | str:
|
||||||
|
return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
|
||||||
|
|
||||||
|
def max(self) -> float | str:
|
||||||
|
return self._max / self.scale if self._max else "N/A"
|
||||||
|
|
||||||
|
def dump_avg_max(self) -> list[float | str]:
|
||||||
|
return [self.avg(), self.max()]
|
||||||
|
|
||||||
|
def __enter__(self) -> None:
|
||||||
|
self.start_time = time.monotonic_ns()
|
||||||
|
|
||||||
|
def __exit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_value: BaseException | None,
|
||||||
|
exc_traceback: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
self.collect(time.monotonic_ns() - self.start_time)
|
||||||
515
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
Normal file
515
benchmarks/cutlass_benchmarks/sparse_benchmarks.py
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import pickle as pkl
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
from utils import make_rand_sparse_tensors
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
|
||||||
|
# bench
|
||||||
|
def bench_fn(
|
||||||
|
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
||||||
|
) -> TMeasurement:
|
||||||
|
min_run_time = 1
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
"args": args,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
"fn": fn,
|
||||||
|
}
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt="fn(*args, **kwargs)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description=description,
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_int8(
|
||||||
|
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
assert dtype == torch.int8
|
||||||
|
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
|
||||||
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
out = ops.cutlass_scaled_sparse_mm(
|
||||||
|
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
||||||
|
)
|
||||||
|
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||||
|
|
||||||
|
if not torch.allclose(out, out_ref):
|
||||||
|
print("Incorrect results")
|
||||||
|
print(out)
|
||||||
|
print(out_ref)
|
||||||
|
else:
|
||||||
|
print("Correct results")
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
# pytorch impl - bfloat16
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||||
|
torch.mm,
|
||||||
|
a.to(dtype=torch.bfloat16),
|
||||||
|
b.to(dtype=torch.bfloat16),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# pytorch impl - float16
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_fp16_fp16_fp16_matmul-no-scales",
|
||||||
|
torch.mm,
|
||||||
|
a.to(dtype=torch.float16),
|
||||||
|
b.to(dtype=torch.float16),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass impl
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm",
|
||||||
|
ops.cutlass_scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm_bias",
|
||||||
|
ops.cutlass_scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass sparse impl
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_i8_i8_bf16_scaled_sparse_mm",
|
||||||
|
ops.cutlass_scaled_sparse_mm,
|
||||||
|
a,
|
||||||
|
b_compressed,
|
||||||
|
e,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass sparse with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
|
||||||
|
ops.cutlass_scaled_sparse_mm,
|
||||||
|
a,
|
||||||
|
b_compressed,
|
||||||
|
e,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
def bench_fp8(
|
||||||
|
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
assert dtype == torch.float8_e4m3fn
|
||||||
|
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
|
||||||
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
out = ops.cutlass_scaled_sparse_mm(
|
||||||
|
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
|
||||||
|
)
|
||||||
|
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
|
||||||
|
|
||||||
|
if not torch.allclose(out, out_ref):
|
||||||
|
print("Incorrect results")
|
||||||
|
print(out)
|
||||||
|
print(out_ref)
|
||||||
|
else:
|
||||||
|
print("Correct results")
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
|
||||||
|
# pytorch impl w. bf16
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_bf16_bf16_bf16_matmul-no-scales",
|
||||||
|
torch.mm,
|
||||||
|
a.to(dtype=torch.bfloat16, device="cuda"),
|
||||||
|
b.to(dtype=torch.bfloat16, device="cuda"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# pytorch impl: bf16 output, without fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_fp8_fp8_bf16_scaled_mm",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# pytorch impl: bf16 output, with fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
use_fast_accum=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# pytorch impl: fp16 output, without fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_fp8_fp8_fp16_scaled_mm",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.float16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# pytorch impl: fp16 output, with fp8 fast accum
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
|
||||||
|
torch._scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=torch.float16,
|
||||||
|
use_fast_accum=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass impl: bf16 output
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_fp8_fp8_bf16_scaled_mm",
|
||||||
|
ops.cutlass_scaled_mm,
|
||||||
|
a,
|
||||||
|
b,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass impl: bf16 output
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
|
||||||
|
ops.cutlass_scaled_sparse_mm,
|
||||||
|
a,
|
||||||
|
b_compressed,
|
||||||
|
e,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass impl: fp16 output
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
|
||||||
|
ops.cutlass_scaled_sparse_mm,
|
||||||
|
a,
|
||||||
|
b_compressed,
|
||||||
|
e,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.float16,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass impl: bf16 output, with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
|
||||||
|
ops.cutlass_scaled_sparse_mm,
|
||||||
|
a,
|
||||||
|
b_compressed,
|
||||||
|
e,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.bfloat16,
|
||||||
|
bias,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass impl: fp16 output, with bias
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
|
||||||
|
ops.cutlass_scaled_sparse_mm,
|
||||||
|
a,
|
||||||
|
b_compressed,
|
||||||
|
e,
|
||||||
|
scale_a,
|
||||||
|
scale_b,
|
||||||
|
torch.float16,
|
||||||
|
bias.to(dtype=torch.float16),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
def bench(
|
||||||
|
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
if dtype == torch.int8:
|
||||||
|
return bench_int8(dtype, m, k, n, label, sub_label)
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return bench_fp8(dtype, m, k, n, label, sub_label)
|
||||||
|
raise ValueError("unsupported type")
|
||||||
|
|
||||||
|
|
||||||
|
# runner
|
||||||
|
def print_timers(timers: Iterable[TMeasurement]):
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
def run(
|
||||||
|
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
results = []
|
||||||
|
for m, k, n in MKNs:
|
||||||
|
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
|
||||||
|
print_timers(timers)
|
||||||
|
results.extend(timers)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# output makers
|
||||||
|
def make_output(
|
||||||
|
data: Iterable[TMeasurement],
|
||||||
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
|
base_description: str,
|
||||||
|
timestamp=None,
|
||||||
|
):
|
||||||
|
print(f"== All Results {base_description} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
# pickle all the results
|
||||||
|
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||||
|
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
|
# argparse runners
|
||||||
|
|
||||||
|
|
||||||
|
def run_square_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
|
data = run(args.dtype, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_range_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||||
|
n = len(dim_sizes)
|
||||||
|
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||||
|
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||||
|
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||||
|
MKNs = list(zip(Ms, Ks, Ns))
|
||||||
|
data = run(args.dtype, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_bench(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||||
|
KNs = []
|
||||||
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||||
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
KNs.append(KN)
|
||||||
|
return KNs
|
||||||
|
|
||||||
|
model_bench_data = []
|
||||||
|
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||||
|
for model, tp_size in models_tps:
|
||||||
|
Ms = args.batch_sizes
|
||||||
|
KNs = model_shapes(model, tp_size)
|
||||||
|
MKNs = []
|
||||||
|
for m in Ms:
|
||||||
|
for k, n in KNs:
|
||||||
|
MKNs.append((m, k, n))
|
||||||
|
|
||||||
|
data = run(args.dtype, MKNs)
|
||||||
|
model_bench_data.append(data)
|
||||||
|
|
||||||
|
# Print all results
|
||||||
|
for data, model_tp in zip(model_bench_data, models_tps):
|
||||||
|
model, tp_size = model_tp
|
||||||
|
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
timestamp = int(time.time())
|
||||||
|
|
||||||
|
all_data = []
|
||||||
|
for d in model_bench_data:
|
||||||
|
all_data.extend(d)
|
||||||
|
# pickle all data
|
||||||
|
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(all_data, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
if dt == "int8":
|
||||||
|
return torch.int8
|
||||||
|
if dt == "fp8":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""
|
||||||
|
Benchmark Cutlass GEMM.
|
||||||
|
|
||||||
|
To run square GEMMs:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||||
|
|
||||||
|
To run constant N and K and sweep M:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||||
|
|
||||||
|
To run dimensions from a model:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
|
""", # noqa: E501
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=to_torch_dtype,
|
||||||
|
required=True,
|
||||||
|
help="Available options are ['int8', 'fp8']",
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest="cmd")
|
||||||
|
|
||||||
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
|
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
square_parser.set_defaults(func=run_square_bench)
|
||||||
|
|
||||||
|
range_parser = subparsers.add_parser("range_bench")
|
||||||
|
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||||
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
|
)
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||||
|
)
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.func(args)
|
||||||
100
benchmarks/cutlass_benchmarks/utils.py
Normal file
100
benchmarks/cutlass_benchmarks/utils.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Cutlass bench utils
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm._custom_ops as ops
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||||
|
dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||||
|
|
||||||
|
|
||||||
|
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
return tensor.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
return tensor.to(dtype=torch.float16)
|
||||||
|
|
||||||
|
|
||||||
|
def make_rand_tensors(
|
||||||
|
dtype: torch.dtype, m: int, n: int, k: int
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
a = torch.randn((m, k), device="cuda") * 5
|
||||||
|
b = torch.randn((n, k), device="cuda").t() * 5
|
||||||
|
|
||||||
|
if dtype == torch.int8:
|
||||||
|
return to_int8(a), to_int8(b)
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return to_fp8(a), to_fp8(b)
|
||||||
|
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
|
||||||
|
def prune_to_2_4(tensor):
|
||||||
|
# Reshape tensor to [N, 4] where N is number of groups of 4
|
||||||
|
original_shape = tensor.shape
|
||||||
|
reshaped = tensor.reshape(-1, 4)
|
||||||
|
|
||||||
|
# Get indices of top 2 absolute values in each group of 4
|
||||||
|
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
|
||||||
|
|
||||||
|
# Create binary mask
|
||||||
|
mask = torch.zeros_like(reshaped)
|
||||||
|
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
|
||||||
|
|
||||||
|
# Apply mask and reshape back
|
||||||
|
pruned = reshaped * mask
|
||||||
|
|
||||||
|
# Turn all -0.0 to 0.0
|
||||||
|
pruned[pruned == -0.0] = 0.0
|
||||||
|
|
||||||
|
return pruned.reshape(original_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def make_rand_sparse_tensors(
|
||||||
|
dtype: torch.dtype, m: int, n: int, k: int
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
a = torch.randn((m, k), device="cuda") * 5
|
||||||
|
b = torch.randn((n, k), device="cuda").t() * 5
|
||||||
|
|
||||||
|
b = prune_to_2_4(b.t()).t()
|
||||||
|
|
||||||
|
if dtype == torch.int8:
|
||||||
|
a, b = to_int8(a), to_int8(b)
|
||||||
|
elif dtype == torch.float8_e4m3fn:
|
||||||
|
a, b = to_fp8(a), to_fp8(b)
|
||||||
|
elif dtype == torch.float16:
|
||||||
|
a, b = to_fp16(a), to_fp16(b)
|
||||||
|
elif dtype == torch.bfloat16:
|
||||||
|
a, b = to_bf16(a), to_bf16(b)
|
||||||
|
else:
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
b_compressed, e = ops.cutlass_sparse_compress(b.t())
|
||||||
|
|
||||||
|
# Compressed B, Metadata, Original A, B
|
||||||
|
return b_compressed, e, a, b
|
||||||
|
|
||||||
|
|
||||||
|
def make_n_rand_sparse_tensors(
|
||||||
|
num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int
|
||||||
|
) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]:
|
||||||
|
ABs = []
|
||||||
|
for _ in range(num_tensors):
|
||||||
|
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
|
||||||
|
if b_comp is not None:
|
||||||
|
ABs.append(make_rand_sparse_tensors(dtype, m, n, k))
|
||||||
|
BComps, Es, As, Bs = zip(*ABs)
|
||||||
|
return list(BComps), list(Es), list(As), list(Bs)
|
||||||
372
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Normal file
372
benchmarks/cutlass_benchmarks/w8a8_benchmarks.py
Normal file
@@ -0,0 +1,372 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import pickle as pkl
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
from utils import make_rand_tensors
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
w8a8_triton_block_scaled_mm,
|
||||||
|
)
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
|
|
||||||
|
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
|
||||||
|
# bench
|
||||||
|
def bench_fn(
|
||||||
|
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
|
||||||
|
) -> TMeasurement:
|
||||||
|
min_run_time = 1
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
"args": args,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
"fn": fn,
|
||||||
|
}
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt="fn(*args, **kwargs)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description=description,
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_int8(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
m: int,
|
||||||
|
k: int,
|
||||||
|
n: int,
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
bench_kernels: list[str] | None = None,
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
"""Benchmark INT8-based kernels."""
|
||||||
|
assert dtype == torch.int8
|
||||||
|
a, b = make_rand_tensors(torch.int8, m, n, k)
|
||||||
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||||
|
azp = torch.zeros((m,), device="cuda", dtype=torch.int32)
|
||||||
|
azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32)
|
||||||
|
|
||||||
|
bench_fns = {
|
||||||
|
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
||||||
|
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||||
|
),
|
||||||
|
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
||||||
|
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
||||||
|
),
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16
|
||||||
|
),
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16, bias
|
||||||
|
),
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16, azp_adj
|
||||||
|
),
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias
|
||||||
|
),
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp
|
||||||
|
),
|
||||||
|
"cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
for name, fn in bench_fns.items():
|
||||||
|
# If bench_kernels is None, run all. Otherwise, run only exact matches.
|
||||||
|
if bench_kernels is None or name in bench_kernels:
|
||||||
|
print(f"Running {name}")
|
||||||
|
timers.append(bench_fn(label, sub_label, name, fn))
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
def bench_fp8(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
m: int,
|
||||||
|
k: int,
|
||||||
|
n: int,
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
bench_kernels: list[str] | None = None,
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
"""Benchmark FP8-based kernels."""
|
||||||
|
assert dtype == torch.float8_e4m3fn
|
||||||
|
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
|
||||||
|
a_cont = a.contiguous()
|
||||||
|
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||||
|
|
||||||
|
block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
|
||||||
|
block_scale_b = torch.rand(
|
||||||
|
cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
|
||||||
|
)
|
||||||
|
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||||
|
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||||
|
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
print(m, k, n)
|
||||||
|
|
||||||
|
bench_fns = {
|
||||||
|
"pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm(
|
||||||
|
a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16)
|
||||||
|
),
|
||||||
|
"pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm(
|
||||||
|
a.to(dtype=torch.float16), b.to(dtype=torch.float16)
|
||||||
|
),
|
||||||
|
"pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, out_dtype=torch.float16
|
||||||
|
),
|
||||||
|
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True
|
||||||
|
),
|
||||||
|
"pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, out_dtype=torch.bfloat16
|
||||||
|
),
|
||||||
|
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True
|
||||||
|
),
|
||||||
|
"cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16
|
||||||
|
),
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, torch.float16
|
||||||
|
),
|
||||||
|
"cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, torch.bfloat16, bias
|
||||||
|
),
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16)
|
||||||
|
),
|
||||||
|
"triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_triton_block_scaled_mm(
|
||||||
|
a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128)
|
||||||
|
),
|
||||||
|
"cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm(
|
||||||
|
a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
for name, fn in bench_fns.items():
|
||||||
|
# If bench_kernels is None, run all. Otherwise, run only exact matches.
|
||||||
|
if bench_kernels is None or name in bench_kernels:
|
||||||
|
print(f"Running {name}")
|
||||||
|
timers.append(bench_fn(label, sub_label, name, fn))
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
def bench(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
m: int,
|
||||||
|
k: int,
|
||||||
|
n: int,
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
bench_kernels: list[str] | None = None,
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
if dtype == torch.int8:
|
||||||
|
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
return bench_fp8(dtype, m, k, n, label, sub_label, bench_kernels)
|
||||||
|
raise ValueError("unsupported type")
|
||||||
|
|
||||||
|
|
||||||
|
# runner
|
||||||
|
def print_timers(timers: Iterable[TMeasurement]):
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
def run(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
|
bench_kernels: list[str] | None = None,
|
||||||
|
) -> Iterable[TMeasurement]:
|
||||||
|
results = []
|
||||||
|
for m, k, n in MKNs:
|
||||||
|
timers = bench(
|
||||||
|
dtype,
|
||||||
|
m,
|
||||||
|
k,
|
||||||
|
n,
|
||||||
|
f"scaled-{dtype}-gemm",
|
||||||
|
f"MKN=({m}x{k}x{n})",
|
||||||
|
bench_kernels=bench_kernels,
|
||||||
|
)
|
||||||
|
print_timers(timers)
|
||||||
|
results.extend(timers)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def make_output(
|
||||||
|
data: Iterable[TMeasurement],
|
||||||
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
|
base_description: str,
|
||||||
|
timestamp=None,
|
||||||
|
):
|
||||||
|
print(f"== All Results {base_description} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
# pickle all the results
|
||||||
|
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||||
|
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
|
def run_square_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
|
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||||
|
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_range_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
|
||||||
|
n = len(dim_sizes)
|
||||||
|
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
|
||||||
|
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
|
||||||
|
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
|
||||||
|
MKNs = list(zip(Ms, Ks, Ns))
|
||||||
|
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||||
|
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_bench(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||||
|
KNs = []
|
||||||
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||||
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
KNs.append(KN)
|
||||||
|
return KNs
|
||||||
|
|
||||||
|
model_bench_data = []
|
||||||
|
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||||
|
for model, tp_size in models_tps:
|
||||||
|
Ms = args.batch_sizes
|
||||||
|
KNs = model_shapes(model, tp_size)
|
||||||
|
MKNs = []
|
||||||
|
for m in Ms:
|
||||||
|
for k, n in KNs:
|
||||||
|
MKNs.append((m, k, n))
|
||||||
|
|
||||||
|
data = run(args.dtype, MKNs, bench_kernels=args.kernels)
|
||||||
|
model_bench_data.append(data)
|
||||||
|
|
||||||
|
# Print all results
|
||||||
|
for data, model_tp in zip(model_bench_data, models_tps):
|
||||||
|
model, tp_size = model_tp
|
||||||
|
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
timestamp = int(time.time())
|
||||||
|
|
||||||
|
all_data = []
|
||||||
|
for d in model_bench_data:
|
||||||
|
all_data.extend(d)
|
||||||
|
# pickle all data
|
||||||
|
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(all_data, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
if dt == "int8":
|
||||||
|
return torch.int8
|
||||||
|
if dt == "fp8":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
raise ValueError("unsupported dtype")
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""
|
||||||
|
Benchmark Cutlass GEMM.
|
||||||
|
|
||||||
|
To run square GEMMs:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||||
|
|
||||||
|
To run constant N and K and sweep M:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||||
|
|
||||||
|
To run dimensions from a model:
|
||||||
|
python3 ./benchmarks/cutlass_benchmarks/w8a8_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
|
""", # noqa: E501
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=to_torch_dtype,
|
||||||
|
required=True,
|
||||||
|
help="Available options are ['int8', 'fp8']",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kernels",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Exact names of the kernels to benchmark. If not set, runs all kernels.",
|
||||||
|
)
|
||||||
|
|
||||||
|
subparsers = parser.add_subparsers(dest="cmd")
|
||||||
|
|
||||||
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
|
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
square_parser.set_defaults(func=run_square_bench)
|
||||||
|
|
||||||
|
range_parser = subparsers.add_parser("range_bench")
|
||||||
|
range_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
range_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
range_parser.add_argument("--m-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--n-constant", type=int, default=None)
|
||||||
|
range_parser.add_argument("--k-constant", type=int, default=None)
|
||||||
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
|
)
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||||
|
)
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.func(args)
|
||||||
46
benchmarks/cutlass_benchmarks/weight_shapes.py
Normal file
46
benchmarks/cutlass_benchmarks/weight_shapes.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Weight Shapes are in the format
|
||||||
|
# ([K, N], TP_SPLIT_DIM)
|
||||||
|
# Example:
|
||||||
|
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 14336, N = 4096
|
||||||
|
# - TP2 : K = 7168, N = 4096
|
||||||
|
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 4096, N = 6144
|
||||||
|
# - TP4 : K = 4096, N = 1536
|
||||||
|
|
||||||
|
# TP1 shapes
|
||||||
|
WEIGHT_SHAPES = {
|
||||||
|
"mistralai/Mistral-7B-v0.1": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-7b-hf": [
|
||||||
|
([4096, 12288], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 22016], 1),
|
||||||
|
([11008, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3-8b": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-13b-hf": [
|
||||||
|
([5120, 15360], 1),
|
||||||
|
([5120, 5120], 0),
|
||||||
|
([5120, 27648], 1),
|
||||||
|
([13824, 5120], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-70b-hf": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 57344], 1),
|
||||||
|
([28672, 8192], 0),
|
||||||
|
],
|
||||||
|
}
|
||||||
143
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
Normal file
143
benchmarks/disagg_benchmarks/disagg_overhead_benchmark.sh
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# benchmark the overhead of disaggregated prefill.
|
||||||
|
# methodology:
|
||||||
|
# - send all request to prefill vLLM instance. It will buffer KV cache.
|
||||||
|
# - then send all request to decode instance.
|
||||||
|
# - The TTFT of decode instance is the overhead.
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
kill_gpu_processes() {
|
||||||
|
# kill all processes on GPU.
|
||||||
|
pgrep pt_main_thread | xargs -r kill -9
|
||||||
|
pgrep python3 | xargs -r kill -9
|
||||||
|
# vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
|
||||||
|
pgrep VLLM | xargs -r kill -9
|
||||||
|
sleep 10
|
||||||
|
|
||||||
|
# remove vllm config file
|
||||||
|
rm -rf ~/.config/vllm
|
||||||
|
|
||||||
|
# Print the GPU memory usage
|
||||||
|
# so that we know if all GPU processes are killed.
|
||||||
|
gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0)
|
||||||
|
# The memory usage should be 0 MB.
|
||||||
|
echo "GPU 0 Memory Usage: $gpu_memory_usage MB"
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_server() {
|
||||||
|
# wait for vllm server to start
|
||||||
|
# return 1 if vllm server crashes
|
||||||
|
local port=$1
|
||||||
|
timeout 1200 bash -c "
|
||||||
|
until curl -s localhost:${port}/v1/completions > /dev/null; do
|
||||||
|
sleep 1
|
||||||
|
done" && return 0 || return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
benchmark() {
|
||||||
|
|
||||||
|
export VLLM_LOGGING_LEVEL=DEBUG
|
||||||
|
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
||||||
|
|
||||||
|
# compare chunked prefill with disaggregated prefill
|
||||||
|
|
||||||
|
results_folder="./results"
|
||||||
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
|
dataset_name="sonnet"
|
||||||
|
dataset_path="../sonnet_4x.txt"
|
||||||
|
num_prompts=10
|
||||||
|
qps=$1
|
||||||
|
prefix_len=50
|
||||||
|
input_len=2048
|
||||||
|
output_len=$2
|
||||||
|
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
|
||||||
|
--port 8100 \
|
||||||
|
--max-model-len 10000 \
|
||||||
|
--gpu-memory-utilization 0.6 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
|
||||||
|
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
|
||||||
|
--port 8200 \
|
||||||
|
--max-model-len 10000 \
|
||||||
|
--gpu-memory-utilization 0.6 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
|
||||||
|
|
||||||
|
wait_for_server 8100
|
||||||
|
wait_for_server 8200
|
||||||
|
|
||||||
|
# let the prefill instance finish prefill
|
||||||
|
vllm bench serve \
|
||||||
|
--backend vllm \
|
||||||
|
--model $model \
|
||||||
|
--dataset-name $dataset_name \
|
||||||
|
--dataset-path $dataset_path \
|
||||||
|
--sonnet-input-len $input_len \
|
||||||
|
--sonnet-output-len "$output_len" \
|
||||||
|
--sonnet-prefix-len $prefix_len \
|
||||||
|
--num-prompts $num_prompts \
|
||||||
|
--port 8100 \
|
||||||
|
--save-result \
|
||||||
|
--result-dir $results_folder \
|
||||||
|
--result-filename disagg_prefill_tp1.json \
|
||||||
|
--request-rate "inf"
|
||||||
|
|
||||||
|
|
||||||
|
# send the request to decode.
|
||||||
|
# The TTFT of this command will be the overhead of disagg prefill impl.
|
||||||
|
vllm bench serve \
|
||||||
|
--backend vllm \
|
||||||
|
--model $model \
|
||||||
|
--dataset-name $dataset_name \
|
||||||
|
--dataset-path $dataset_path \
|
||||||
|
--sonnet-input-len $input_len \
|
||||||
|
--sonnet-output-len "$output_len" \
|
||||||
|
--sonnet-prefix-len $prefix_len \
|
||||||
|
--num-prompts $num_prompts \
|
||||||
|
--port 8200 \
|
||||||
|
--save-result \
|
||||||
|
--result-dir $results_folder \
|
||||||
|
--result-filename disagg_prefill_tp1_overhead.json \
|
||||||
|
--request-rate "$qps"
|
||||||
|
kill_gpu_processes
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
main() {
|
||||||
|
|
||||||
|
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||||
|
(which jq) || (apt-get -y install jq)
|
||||||
|
(which socat) || (apt-get -y install socat)
|
||||||
|
|
||||||
|
pip install quart httpx datasets
|
||||||
|
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
# create sonnet-4x.txt
|
||||||
|
echo "" > sonnet_4x.txt
|
||||||
|
for _ in {1..4}
|
||||||
|
do
|
||||||
|
cat sonnet.txt >> sonnet_4x.txt
|
||||||
|
done
|
||||||
|
cd disagg_benchmarks
|
||||||
|
|
||||||
|
rm -rf results
|
||||||
|
mkdir results
|
||||||
|
|
||||||
|
default_qps=1
|
||||||
|
default_output_len=1
|
||||||
|
benchmark $default_qps $default_output_len
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
main "$@"
|
||||||
157
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
Normal file
157
benchmarks/disagg_benchmarks/disagg_performance_benchmark.sh
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Requirement: 2x GPUs.
|
||||||
|
|
||||||
|
|
||||||
|
# Model: meta-llama/Meta-Llama-3.1-8B-Instruct
|
||||||
|
# Query: 1024 input tokens, 6 output tokens, QPS 2/4/6/8, 100 requests
|
||||||
|
# Resource: 2x GPU
|
||||||
|
# Approaches:
|
||||||
|
# 2. Chunked prefill: 2 vllm instance with tp=4, equivalent to 1 tp=4 instance with QPS 4
|
||||||
|
# 3. Disaggregated prefill: 1 prefilling instance and 1 decoding instance
|
||||||
|
# Prefilling instance: max_output_token=1
|
||||||
|
# Decoding instance: force the input tokens be the same across requests to bypass prefilling
|
||||||
|
|
||||||
|
set -ex
|
||||||
|
|
||||||
|
kill_gpu_processes() {
|
||||||
|
# kill all processes on GPU.
|
||||||
|
pgrep pt_main_thread | xargs -r kill -9
|
||||||
|
pgrep python3 | xargs -r kill -9
|
||||||
|
# vLLM now names the process with VLLM prefix after https://github.com/vllm-project/vllm/pull/21445
|
||||||
|
pgrep VLLM | xargs -r kill -9
|
||||||
|
for port in 8000 8100 8200; do lsof -t -i:$port | xargs -r kill -9; done
|
||||||
|
sleep 1
|
||||||
|
}
|
||||||
|
|
||||||
|
wait_for_server() {
|
||||||
|
# wait for vllm server to start
|
||||||
|
# return 1 if vllm server crashes
|
||||||
|
local port=$1
|
||||||
|
timeout 1200 bash -c "
|
||||||
|
until curl -s localhost:${port}/v1/completions > /dev/null; do
|
||||||
|
sleep 1
|
||||||
|
done" && return 0 || return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
launch_chunked_prefill() {
|
||||||
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
|
# disagg prefill
|
||||||
|
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
|
||||||
|
--port 8100 \
|
||||||
|
--max-model-len 10000 \
|
||||||
|
--enable-chunked-prefill \
|
||||||
|
--gpu-memory-utilization 0.6 &
|
||||||
|
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
|
||||||
|
--port 8200 \
|
||||||
|
--max-model-len 10000 \
|
||||||
|
--enable-chunked-prefill \
|
||||||
|
--gpu-memory-utilization 0.6 &
|
||||||
|
wait_for_server 8100
|
||||||
|
wait_for_server 8200
|
||||||
|
python3 round_robin_proxy.py &
|
||||||
|
sleep 1
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
launch_disagg_prefill() {
|
||||||
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
|
# disagg prefill
|
||||||
|
CUDA_VISIBLE_DEVICES=0 vllm serve $model \
|
||||||
|
--port 8100 \
|
||||||
|
--max-model-len 10000 \
|
||||||
|
--gpu-memory-utilization 0.6 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=1 vllm serve $model \
|
||||||
|
--port 8200 \
|
||||||
|
--max-model-len 10000 \
|
||||||
|
--gpu-memory-utilization 0.6 \
|
||||||
|
--kv-transfer-config \
|
||||||
|
'{"kv_connector":"P2pNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2,"kv_buffer_size":5e9}' &
|
||||||
|
|
||||||
|
wait_for_server 8100
|
||||||
|
wait_for_server 8200
|
||||||
|
python3 disagg_prefill_proxy_server.py &
|
||||||
|
sleep 1
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
benchmark() {
|
||||||
|
results_folder="./results"
|
||||||
|
model="meta-llama/Meta-Llama-3.1-8B-Instruct"
|
||||||
|
dataset_name="sonnet"
|
||||||
|
dataset_path="../sonnet_4x.txt"
|
||||||
|
num_prompts=100
|
||||||
|
qps=$1
|
||||||
|
prefix_len=50
|
||||||
|
input_len=1024
|
||||||
|
output_len=$2
|
||||||
|
tag=$3
|
||||||
|
|
||||||
|
vllm bench serve \
|
||||||
|
--backend vllm \
|
||||||
|
--model $model \
|
||||||
|
--dataset-name $dataset_name \
|
||||||
|
--dataset-path $dataset_path \
|
||||||
|
--sonnet-input-len $input_len \
|
||||||
|
--sonnet-output-len "$output_len" \
|
||||||
|
--sonnet-prefix-len $prefix_len \
|
||||||
|
--num-prompts $num_prompts \
|
||||||
|
--port 8000 \
|
||||||
|
--save-result \
|
||||||
|
--result-dir $results_folder \
|
||||||
|
--result-filename "$tag"-qps-"$qps".json \
|
||||||
|
--request-rate "$qps"
|
||||||
|
|
||||||
|
sleep 2
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
main() {
|
||||||
|
|
||||||
|
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||||
|
(which jq) || (apt-get -y install jq)
|
||||||
|
(which socat) || (apt-get -y install socat)
|
||||||
|
(which lsof) || (apt-get -y install lsof)
|
||||||
|
|
||||||
|
pip install quart httpx matplotlib aiohttp datasets
|
||||||
|
|
||||||
|
cd "$(dirname "$0")"
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
# create sonnet-4x.txt so that we can sample 2048 tokens for input
|
||||||
|
echo "" > sonnet_4x.txt
|
||||||
|
for _ in {1..4}
|
||||||
|
do
|
||||||
|
cat sonnet.txt >> sonnet_4x.txt
|
||||||
|
done
|
||||||
|
cd disagg_benchmarks
|
||||||
|
|
||||||
|
rm -rf results
|
||||||
|
mkdir results
|
||||||
|
|
||||||
|
default_output_len=6
|
||||||
|
|
||||||
|
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
|
||||||
|
|
||||||
|
launch_chunked_prefill
|
||||||
|
for qps in 2 4 6 8; do
|
||||||
|
benchmark $qps $default_output_len chunked_prefill
|
||||||
|
done
|
||||||
|
kill_gpu_processes
|
||||||
|
|
||||||
|
launch_disagg_prefill
|
||||||
|
for qps in 2 4 6 8; do
|
||||||
|
benchmark $qps $default_output_len disagg_prefill
|
||||||
|
done
|
||||||
|
kill_gpu_processes
|
||||||
|
|
||||||
|
python3 visualize_benchmark_results.py
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
main "$@"
|
||||||
260
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Normal file
260
benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from quart import Quart, Response, make_response, request
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
"""parse command line arguments"""
|
||||||
|
parser = argparse.ArgumentParser(description="vLLM P/D disaggregation proxy server")
|
||||||
|
|
||||||
|
# Add args
|
||||||
|
parser.add_argument(
|
||||||
|
"--timeout",
|
||||||
|
type=float,
|
||||||
|
default=6 * 60 * 60,
|
||||||
|
help="Timeout for backend service requests in seconds (default: 21600)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port",
|
||||||
|
type=int,
|
||||||
|
default=8000,
|
||||||
|
help="Port to run the server on (default: 8000)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8100",
|
||||||
|
help="Prefill service base URL (protocol + host[:port])",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-url",
|
||||||
|
type=str,
|
||||||
|
default="http://localhost:8200",
|
||||||
|
help="Decode service base URL (protocol + host[:port])",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-host",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Hostname or IP used by KV transfer (default: localhost)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prefill-kv-port",
|
||||||
|
type=int,
|
||||||
|
default=14579,
|
||||||
|
help="Prefill KV port (default: 14579)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--decode-kv-port",
|
||||||
|
type=int,
|
||||||
|
default=14580,
|
||||||
|
help="Decode KV port (default: 14580)",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""parse command line arguments"""
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Initialize configuration using command line parameters
|
||||||
|
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=args.timeout)
|
||||||
|
PREFILL_SERVICE_URL = args.prefill_url
|
||||||
|
DECODE_SERVICE_URL = args.decode_url
|
||||||
|
PORT = args.port
|
||||||
|
|
||||||
|
PREFILL_KV_ADDR = f"{args.kv_host}:{args.prefill_kv_port}"
|
||||||
|
DECODE_KV_ADDR = f"{args.kv_host}:{args.decode_kv_port}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Proxy resolved KV addresses -> prefill: %s, decode: %s",
|
||||||
|
PREFILL_KV_ADDR,
|
||||||
|
DECODE_KV_ADDR,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = Quart(__name__)
|
||||||
|
|
||||||
|
# Attach the configuration object to the application instance so helper
|
||||||
|
# coroutines can read the resolved backend URLs and timeouts without using
|
||||||
|
# globals.
|
||||||
|
app.config.update(
|
||||||
|
{
|
||||||
|
"AIOHTTP_TIMEOUT": AIOHTTP_TIMEOUT,
|
||||||
|
"PREFILL_SERVICE_URL": PREFILL_SERVICE_URL,
|
||||||
|
"DECODE_SERVICE_URL": DECODE_SERVICE_URL,
|
||||||
|
"PREFILL_KV_ADDR": PREFILL_KV_ADDR,
|
||||||
|
"DECODE_KV_ADDR": DECODE_KV_ADDR,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def _normalize_base_url(url: str) -> str:
|
||||||
|
"""Remove any trailing slash so path joins behave predictably."""
|
||||||
|
return url.rstrip("/")
|
||||||
|
|
||||||
|
def _get_host_port(url: str) -> str:
|
||||||
|
"""Return the hostname:port portion for logging and KV headers."""
|
||||||
|
parsed = urlparse(url)
|
||||||
|
host = parsed.hostname or "localhost"
|
||||||
|
port = parsed.port
|
||||||
|
if port is None:
|
||||||
|
port = 80 if parsed.scheme == "http" else 443
|
||||||
|
return f"{host}:{port}"
|
||||||
|
|
||||||
|
PREFILL_BASE = _normalize_base_url(PREFILL_SERVICE_URL)
|
||||||
|
DECODE_BASE = _normalize_base_url(DECODE_SERVICE_URL)
|
||||||
|
KV_TARGET = _get_host_port(DECODE_SERVICE_URL)
|
||||||
|
|
||||||
|
def _build_headers(request_id: str) -> dict[str, str]:
|
||||||
|
"""Construct the headers expected by vLLM's P2P disagg connector."""
|
||||||
|
headers: dict[str, str] = {"X-Request-Id": request_id, "X-KV-Target": KV_TARGET}
|
||||||
|
api_key = os.environ.get("OPENAI_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
headers["Authorization"] = f"Bearer {api_key}"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
async def _run_prefill(
|
||||||
|
request_path: str,
|
||||||
|
payload: dict,
|
||||||
|
headers: dict[str, str],
|
||||||
|
request_id: str,
|
||||||
|
):
|
||||||
|
url = f"{PREFILL_BASE}{request_path}"
|
||||||
|
start_ts = time.perf_counter()
|
||||||
|
logger.info("[prefill] start request_id=%s url=%s", request_id, url)
|
||||||
|
try:
|
||||||
|
async with (
|
||||||
|
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
||||||
|
session.post(url=url, json=payload, headers=headers) as resp,
|
||||||
|
):
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Prefill backend error {resp.status}: {error_text}"
|
||||||
|
)
|
||||||
|
await resp.read()
|
||||||
|
logger.info(
|
||||||
|
"[prefill] done request_id=%s status=%s elapsed=%.2fs",
|
||||||
|
request_id,
|
||||||
|
resp.status,
|
||||||
|
time.perf_counter() - start_ts,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
raise RuntimeError(f"Prefill service timeout at {url}") from exc
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
raise RuntimeError(f"Prefill service unavailable at {url}") from exc
|
||||||
|
|
||||||
|
async def _stream_decode(
|
||||||
|
request_path: str,
|
||||||
|
payload: dict,
|
||||||
|
headers: dict[str, str],
|
||||||
|
request_id: str,
|
||||||
|
):
|
||||||
|
url = f"{DECODE_BASE}{request_path}"
|
||||||
|
# Stream tokens from the decode service once the prefill stage has
|
||||||
|
# materialized KV caches on the target workers.
|
||||||
|
logger.info("[decode] start request_id=%s url=%s", request_id, url)
|
||||||
|
try:
|
||||||
|
async with (
|
||||||
|
aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session,
|
||||||
|
session.post(url=url, json=payload, headers=headers) as resp,
|
||||||
|
):
|
||||||
|
if resp.status != 200:
|
||||||
|
error_text = await resp.text()
|
||||||
|
logger.error(
|
||||||
|
"Decode backend error %s - %s", resp.status, error_text
|
||||||
|
)
|
||||||
|
err_msg = (
|
||||||
|
'{"error": "Decode backend error ' + str(resp.status) + '"}'
|
||||||
|
)
|
||||||
|
yield err_msg.encode()
|
||||||
|
return
|
||||||
|
logger.info(
|
||||||
|
"[decode] streaming response request_id=%s status=%s",
|
||||||
|
request_id,
|
||||||
|
resp.status,
|
||||||
|
)
|
||||||
|
async for chunk_bytes in resp.content.iter_chunked(1024):
|
||||||
|
yield chunk_bytes
|
||||||
|
logger.info("[decode] finished streaming request_id=%s", request_id)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error("Decode service timeout at %s", url)
|
||||||
|
yield b'{"error": "Decode service timeout"}'
|
||||||
|
except aiohttp.ClientError as exc:
|
||||||
|
logger.error("Decode service error at %s: %s", url, exc)
|
||||||
|
yield b'{"error": "Decode service unavailable"}'
|
||||||
|
|
||||||
|
async def process_request():
|
||||||
|
"""Process a single request through prefill and decode stages"""
|
||||||
|
try:
|
||||||
|
original_request_data = await request.get_json()
|
||||||
|
|
||||||
|
# Create prefill request (max_tokens=1)
|
||||||
|
prefill_request = original_request_data.copy()
|
||||||
|
prefill_request["max_tokens"] = 1
|
||||||
|
if "max_completion_tokens" in prefill_request:
|
||||||
|
prefill_request["max_completion_tokens"] = 1
|
||||||
|
|
||||||
|
# Execute prefill stage
|
||||||
|
# The request id encodes both KV socket addresses so the backend can
|
||||||
|
# shuttle tensors directly via NCCL once the prefill response
|
||||||
|
# completes.
|
||||||
|
request_id = (
|
||||||
|
f"___prefill_addr_{PREFILL_KV_ADDR}___decode_addr_"
|
||||||
|
f"{DECODE_KV_ADDR}_{uuid.uuid4().hex}"
|
||||||
|
)
|
||||||
|
|
||||||
|
headers = _build_headers(request_id)
|
||||||
|
await _run_prefill(request.path, prefill_request, headers, request_id)
|
||||||
|
|
||||||
|
# Execute decode stage and stream response
|
||||||
|
# Pass the unmodified user request so the decode phase can continue
|
||||||
|
# sampling with the already-populated KV cache.
|
||||||
|
generator = _stream_decode(
|
||||||
|
request.path, original_request_data, headers, request_id
|
||||||
|
)
|
||||||
|
response = await make_response(generator)
|
||||||
|
response.timeout = None # Disable timeout for streaming response
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error processing request")
|
||||||
|
return Response(
|
||||||
|
response=b'{"error": "Internal server error"}',
|
||||||
|
status=500,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.route("/v1/completions", methods=["POST"])
|
||||||
|
async def handle_request():
|
||||||
|
"""Handle incoming API requests with concurrency and rate limiting"""
|
||||||
|
try:
|
||||||
|
return await process_request()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.warning("Request cancelled")
|
||||||
|
return Response(
|
||||||
|
response=b'{"error": "Request cancelled"}',
|
||||||
|
status=503,
|
||||||
|
content_type="application/json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the Quart server with host can be set to 0.0.0.0
|
||||||
|
app.run(port=PORT)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
45
benchmarks/disagg_benchmarks/rate_limiter.py
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Token bucket rate limiter implementation"""
|
||||||
|
|
||||||
|
def __init__(self, rate_limit):
|
||||||
|
self.rate_limit = rate_limit # Requests per second
|
||||||
|
self.num_available_tokens = rate_limit # Available tokens
|
||||||
|
self.last_refill = time.monotonic() # Last token refill time
|
||||||
|
self.lock = asyncio.Lock() # Synchronization lock
|
||||||
|
|
||||||
|
async def acquire(self):
|
||||||
|
"""Acquire a token from the rate limiter"""
|
||||||
|
while True:
|
||||||
|
async with self.lock:
|
||||||
|
current_time = time.monotonic()
|
||||||
|
elapsed = current_time - self.last_refill
|
||||||
|
|
||||||
|
# Refill num_available_tokens if more than 1 second has passed
|
||||||
|
if elapsed > 1.0:
|
||||||
|
self.num_available_tokens = self.rate_limit
|
||||||
|
self.last_refill = current_time
|
||||||
|
|
||||||
|
# Check if num_available_tokens are available
|
||||||
|
if self.num_available_tokens > 0:
|
||||||
|
self.num_available_tokens -= 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Calculate wait time if no num_available_tokens available
|
||||||
|
wait_time = 1.0 - elapsed
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Enter async context manager - acquire token"""
|
||||||
|
await self.acquire()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_value, traceback):
|
||||||
|
"""Exit async context manager - no cleanup needed"""
|
||||||
|
pass
|
||||||
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
39
benchmarks/disagg_benchmarks/request_queue.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
|
||||||
|
class RequestQueue:
|
||||||
|
"""Request queue manager with concurrency control"""
|
||||||
|
|
||||||
|
def __init__(self, max_concurrent, max_queue_size):
|
||||||
|
# Maximum concurrent requests
|
||||||
|
self.max_concurrent = max_concurrent
|
||||||
|
self.max_queue_size = max_queue_size # Maximum queue size
|
||||||
|
# Concurrency control
|
||||||
|
self.semaphore = asyncio.Semaphore(max_concurrent)
|
||||||
|
self.queue = deque() # Request queue
|
||||||
|
self.queue_size = 0 # Current queue size
|
||||||
|
self.lock = asyncio.Lock() # Sync queue Lock
|
||||||
|
|
||||||
|
async def enqueue(self, task):
|
||||||
|
"""Add a request task to the queue"""
|
||||||
|
async with self.lock:
|
||||||
|
if self.queue_size >= self.max_queue_size:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.queue.append(task)
|
||||||
|
self.queue_size += 1
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def process(self):
|
||||||
|
"""Process queued requests using semaphore for concurrency control"""
|
||||||
|
while True:
|
||||||
|
if self.queue:
|
||||||
|
async with self.semaphore, self.lock:
|
||||||
|
task = self.queue.popleft()
|
||||||
|
self.queue_size -= 1
|
||||||
|
await task
|
||||||
|
await asyncio.sleep(0.01) # Yield control to event loop
|
||||||
63
benchmarks/disagg_benchmarks/round_robin_proxy.py
Normal file
63
benchmarks/disagg_benchmarks/round_robin_proxy.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
|
||||||
|
class RoundRobinProxy:
|
||||||
|
def __init__(self, target_ports):
|
||||||
|
self.target_ports = target_ports
|
||||||
|
self.port_cycle = itertools.cycle(self.target_ports)
|
||||||
|
|
||||||
|
async def handle_request(self, request):
|
||||||
|
target_port = next(self.port_cycle)
|
||||||
|
target_url = f"http://localhost:{target_port}{request.path_qs}"
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
# Forward the request
|
||||||
|
async with session.request(
|
||||||
|
method=request.method,
|
||||||
|
url=target_url,
|
||||||
|
headers=request.headers,
|
||||||
|
data=request.content,
|
||||||
|
) as response:
|
||||||
|
# Start sending the response
|
||||||
|
resp = web.StreamResponse(
|
||||||
|
status=response.status, headers=response.headers
|
||||||
|
)
|
||||||
|
await resp.prepare(request)
|
||||||
|
|
||||||
|
# Stream the response content
|
||||||
|
async for chunk in response.content.iter_any():
|
||||||
|
await resp.write(chunk)
|
||||||
|
|
||||||
|
await resp.write_eof()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return web.Response(text=f"Error: {str(e)}", status=500)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
proxy = RoundRobinProxy([8100, 8200])
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_route("*", "/{path:.*}", proxy.handle_request)
|
||||||
|
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, "localhost", 8000)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
print("Proxy server started on http://localhost:8000")
|
||||||
|
|
||||||
|
# Keep the server running
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
47
benchmarks/disagg_benchmarks/visualize_benchmark_results.py
Normal file
47
benchmarks/disagg_benchmarks/visualize_benchmark_results.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
data = []
|
||||||
|
for name in ["disagg_prefill", "chunked_prefill"]:
|
||||||
|
for qps in [2, 4, 6, 8]:
|
||||||
|
with open(f"results/{name}-qps-{qps}.json") as f:
|
||||||
|
x = json.load(f)
|
||||||
|
x["name"] = name
|
||||||
|
x["qps"] = qps
|
||||||
|
data.append(x)
|
||||||
|
|
||||||
|
df = pd.DataFrame.from_dict(data)
|
||||||
|
dis_df = df[df["name"] == "disagg_prefill"]
|
||||||
|
chu_df = df[df["name"] == "chunked_prefill"]
|
||||||
|
|
||||||
|
plt.style.use("bmh")
|
||||||
|
plt.rcParams["font.size"] = 20
|
||||||
|
|
||||||
|
for key in [
|
||||||
|
"mean_ttft_ms",
|
||||||
|
"median_ttft_ms",
|
||||||
|
"p99_ttft_ms",
|
||||||
|
"mean_itl_ms",
|
||||||
|
"median_itl_ms",
|
||||||
|
"p99_itl_ms",
|
||||||
|
]:
|
||||||
|
fig, ax = plt.subplots(figsize=(11, 7))
|
||||||
|
plt.plot(
|
||||||
|
dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4
|
||||||
|
)
|
||||||
|
plt.plot(
|
||||||
|
chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4
|
||||||
|
)
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
ax.set_xlabel("QPS")
|
||||||
|
ax.set_ylabel(key)
|
||||||
|
ax.set_ylim(bottom=0)
|
||||||
|
fig.savefig(f"results/{key}.png")
|
||||||
|
plt.close(fig)
|
||||||
310
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
Normal file
310
benchmarks/fused_kernels/layernorm_rms_benchmarks.py
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import pickle as pkl
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import vllm._custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class bench_params_t:
|
||||||
|
num_tokens: int
|
||||||
|
hidden_size: int
|
||||||
|
add_residual: bool
|
||||||
|
dtype: torch.dtype
|
||||||
|
group_size: list[int]
|
||||||
|
|
||||||
|
def description(self):
|
||||||
|
return (
|
||||||
|
f"N {self.num_tokens} "
|
||||||
|
f"x D {self.hidden_size} "
|
||||||
|
f"x R {self.add_residual} "
|
||||||
|
f"x DT {self.dtype}"
|
||||||
|
f"x GS {self.group_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bench_params() -> list[bench_params_t]:
|
||||||
|
## Test Fixtures
|
||||||
|
NUM_TOKENS = [2**x for x in range(11)]
|
||||||
|
HIDDEN_SIZES = list(range(1024, 8129, 1024))
|
||||||
|
ADD_RESIDUAL = [True, False]
|
||||||
|
DTYPES = [torch.bfloat16, torch.float]
|
||||||
|
GROUP_SIZES = [[1, 64], [1, 128]]
|
||||||
|
|
||||||
|
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES, GROUP_SIZES)
|
||||||
|
bench_params = list(
|
||||||
|
map(lambda x: bench_params_t(x[0], x[1], x[2], x[3], x[4]), combinations)
|
||||||
|
)
|
||||||
|
return bench_params
|
||||||
|
|
||||||
|
|
||||||
|
# Reference impls
|
||||||
|
def unfused_int8_impl(
|
||||||
|
rms_norm_layer: RMSNorm,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_size: list[int],
|
||||||
|
):
|
||||||
|
# Norm
|
||||||
|
torch_out = None
|
||||||
|
if residual is None:
|
||||||
|
torch_out = rms_norm_layer.forward_cuda(x, residual)
|
||||||
|
else:
|
||||||
|
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
|
||||||
|
|
||||||
|
# Quant
|
||||||
|
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
|
||||||
|
|
||||||
|
|
||||||
|
def unfused_fp8_impl(
|
||||||
|
rms_norm_layer: RMSNorm,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_size: list[int],
|
||||||
|
):
|
||||||
|
# Norm
|
||||||
|
torch_out = None
|
||||||
|
if residual is None:
|
||||||
|
torch_out = rms_norm_layer.forward_cuda(x, residual)
|
||||||
|
else:
|
||||||
|
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
|
||||||
|
|
||||||
|
# Quant
|
||||||
|
torch_out, _ = ops.scaled_fp8_quant(torch_out)
|
||||||
|
|
||||||
|
|
||||||
|
def unfused_groupwise_fp8_impl(
|
||||||
|
rms_norm_layer: RMSNorm,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_size: list[int],
|
||||||
|
):
|
||||||
|
# Norm
|
||||||
|
torch_out = None
|
||||||
|
if residual is None:
|
||||||
|
torch_out = rms_norm_layer.forward_cuda(x, residual)
|
||||||
|
else:
|
||||||
|
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
|
||||||
|
|
||||||
|
# Quant
|
||||||
|
torch_out, _ = per_token_group_quant_fp8(
|
||||||
|
torch_out, group_size=group_size[1], use_ue8m0=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fused_impl(
|
||||||
|
rms_norm_layer: RMSNorm, # this stores the weights
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_size: list[int],
|
||||||
|
):
|
||||||
|
out, _ = ops.rms_norm_dynamic_per_token_quant(
|
||||||
|
x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def fused_groupwise_impl(
|
||||||
|
rms_norm_layer: RMSNorm, # this stores the weights
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_size: list[int],
|
||||||
|
):
|
||||||
|
out, _ = ops.rms_norm_per_block_quant(
|
||||||
|
x,
|
||||||
|
rms_norm_layer.weight,
|
||||||
|
1e-6,
|
||||||
|
quant_dtype,
|
||||||
|
group_size,
|
||||||
|
residual=residual,
|
||||||
|
is_scale_transposed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Bench functions
|
||||||
|
def bench_fn(
|
||||||
|
rms_norm_layer: RMSNorm,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
group_size: list[int],
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
fn: Callable,
|
||||||
|
description: str,
|
||||||
|
) -> TMeasurement:
|
||||||
|
min_run_time = 1
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
"rms_norm_layer": rms_norm_layer,
|
||||||
|
"x": x,
|
||||||
|
"residual": residual,
|
||||||
|
"quant_dtype": quant_dtype,
|
||||||
|
"group_size": group_size,
|
||||||
|
"fn": fn,
|
||||||
|
}
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt="fn(rms_norm_layer, x, residual, quant_dtype, group_size)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description=description,
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
|
||||||
|
def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]:
|
||||||
|
# Make inputs
|
||||||
|
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
|
||||||
|
# Make weights
|
||||||
|
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||||
|
# Make inputs
|
||||||
|
scale = 1 / params.hidden_size
|
||||||
|
x = (
|
||||||
|
torch.randn(
|
||||||
|
params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda"
|
||||||
|
)
|
||||||
|
* scale
|
||||||
|
)
|
||||||
|
residual = (
|
||||||
|
(torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None
|
||||||
|
)
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
|
||||||
|
# unfused int8 impl.
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
torch.int8,
|
||||||
|
params.group_size,
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
unfused_int8_impl,
|
||||||
|
"unfused_int8_impl",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# unfused fp8 impl.
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
params.group_size,
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
unfused_fp8_impl,
|
||||||
|
"unfused_fp8_impl",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused int8 impl.
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
torch.int8,
|
||||||
|
params.group_size,
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
fused_impl,
|
||||||
|
"fused_int8_impl",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused fp8 impl.
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
params.group_size,
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
fused_impl,
|
||||||
|
"fused_fp8_impl",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# unfused groupwise fp8 impl.
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
params.group_size,
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
unfused_groupwise_fp8_impl,
|
||||||
|
"unfused_groupwise_fp8_impl",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# fused groupwise fp8 impl.
|
||||||
|
timers.append(
|
||||||
|
bench_fn(
|
||||||
|
layer,
|
||||||
|
x,
|
||||||
|
residual,
|
||||||
|
torch.float8_e4m3fn,
|
||||||
|
params.group_size,
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
fused_groupwise_impl,
|
||||||
|
"fused_groupwise_fp8_impl",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print_timers(timers)
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
# launch bench
|
||||||
|
# runner
|
||||||
|
def print_timers(timers: Iterable[TMeasurement]):
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
bench_params = get_bench_params()
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
for bp in tqdm(bench_params):
|
||||||
|
timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
|
||||||
|
print_timers(timers)
|
||||||
|
|
||||||
|
# pickle all the results
|
||||||
|
timestamp = int(time.time())
|
||||||
|
with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(timers, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
160
benchmarks/kernels/bench_block_fp8_gemm.py
Normal file
160
benchmarks/kernels/bench_block_fp8_gemm.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Disable DeepGEMM for this benchmark to use CUTLASS
|
||||||
|
os.environ["VLLM_USE_DEEP_GEMM"] = "0"
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
W8A8BlockFp8LinearOp,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
GroupShape,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import triton as vllm_triton
|
||||||
|
|
||||||
|
assert current_platform.is_cuda(), (
|
||||||
|
"Only support benchmarking w8a8 block fp8 kernel on CUDA device."
|
||||||
|
)
|
||||||
|
|
||||||
|
# DeepSeek-V3 weight shapes
|
||||||
|
DEEPSEEK_V3_SHAPES = [
|
||||||
|
(512 + 64, 7168),
|
||||||
|
(2112, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(7168, 16384),
|
||||||
|
(7168, 18432),
|
||||||
|
(18432 * 2, 7168),
|
||||||
|
(24576, 1536),
|
||||||
|
(12288, 7168),
|
||||||
|
(4096, 7168),
|
||||||
|
(7168, 2048),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
|
||||||
|
"""Build runner function for w8a8 block fp8 matmul."""
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
|
||||||
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
|
# Create random input tensor (bfloat16, will be quantized by W8A8BlockFp8LinearOp)
|
||||||
|
A_ref = (torch.rand(M, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||||
|
|
||||||
|
# Create quantized weight tensor
|
||||||
|
B_ref = (torch.rand(N, K, dtype=torch.bfloat16, device=device) - 0.5) * 2 * fp8_max
|
||||||
|
B = B_ref.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
# Create weight scales
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
|
||||||
|
Bs = (
|
||||||
|
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device=device)
|
||||||
|
* factor_for_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create W8A8BlockFp8LinearOp instance
|
||||||
|
weight_group_shape = GroupShape(block_n, block_k)
|
||||||
|
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization
|
||||||
|
|
||||||
|
linear_op = W8A8BlockFp8LinearOp(
|
||||||
|
weight_group_shape=weight_group_shape,
|
||||||
|
act_quant_group_shape=act_quant_group_shape,
|
||||||
|
cutlass_block_fp8_supported=use_cutlass,
|
||||||
|
use_aiter_and_is_supported=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return linear_op.apply(
|
||||||
|
input=A_ref,
|
||||||
|
weight=B,
|
||||||
|
weight_scale=Bs,
|
||||||
|
input_scale=None,
|
||||||
|
bias=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
# Determine available providers
|
||||||
|
available_providers = ["torch-bf16", "w8a8-block-fp8-triton"]
|
||||||
|
plot_title = "BF16 vs W8A8 Block FP8 GEMMs"
|
||||||
|
|
||||||
|
if CUTLASS_BLOCK_FP8_SUPPORTED:
|
||||||
|
available_providers.append("w8a8-block-fp8-cutlass")
|
||||||
|
|
||||||
|
|
||||||
|
@vllm_triton.testing.perf_report(
|
||||||
|
vllm_triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=available_providers,
|
||||||
|
line_names=available_providers,
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs W8A8 Block FP8 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark_tflops(batch_size, provider, N, K, block_size=(128, 128)):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
a = torch.randn((M, K), device=device, dtype=torch.bfloat16)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=torch.bfloat16)
|
||||||
|
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||||
|
)
|
||||||
|
elif provider == "w8a8-block-fp8-triton":
|
||||||
|
run_w8a8_triton = build_w8a8_block_fp8_runner(
|
||||||
|
M, N, K, block_size, device, use_cutlass=False
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_w8a8_triton(), quantiles=quantiles
|
||||||
|
)
|
||||||
|
elif provider == "w8a8-block-fp8-cutlass":
|
||||||
|
run_w8a8_cutlass = build_w8a8_block_fp8_runner(
|
||||||
|
M, N, K, block_size, device, use_cutlass=True
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = vllm_triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_w8a8_cutlass(), quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
block_size = (128, 128)
|
||||||
|
|
||||||
|
for N, K in DEEPSEEK_V3_SHAPES:
|
||||||
|
print(f"\nBenchmarking DeepSeek-V3, N={N} K={K}")
|
||||||
|
|
||||||
|
print(f"TFLOP/s comparison (block_size={block_size}):")
|
||||||
|
benchmark_tflops.run(
|
||||||
|
print_data=True,
|
||||||
|
# show_plots=False,
|
||||||
|
# save_path=f"bench_w8a8_block_fp8_tflops_n{N}_k{K}",
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
block_size=block_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\nBenchmark finished!")
|
||||||
159
benchmarks/kernels/bench_fp8_gemm.py
Normal file
159
benchmarks/kernels/bench_fp8_gemm.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||||
|
from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"torch-bf16": dict(enabled=True),
|
||||||
|
"fp8-tensor-w-token-a": dict(
|
||||||
|
w="tensor", a="token", no_a_quant=False, enabled=False
|
||||||
|
),
|
||||||
|
"fp8-tensor-w-tensor-a": dict(
|
||||||
|
w="tensor", a="tensor", no_a_quant=False, enabled=True
|
||||||
|
),
|
||||||
|
"fp8-channel-w-token-a": dict(
|
||||||
|
w="channel", a="token", no_a_quant=False, enabled=True
|
||||||
|
),
|
||||||
|
"fp8-channel-w-tensor-a": dict(
|
||||||
|
w="channel", a="tensor", no_a_quant=False, enabled=False
|
||||||
|
),
|
||||||
|
"fp8-tensor-w-token-a-noquant": dict(
|
||||||
|
w="tensor", a="token", no_a_quant=True, enabled=False
|
||||||
|
),
|
||||||
|
"fp8-tensor-w-tensor-a-noquant": dict(
|
||||||
|
w="tensor", a="tensor", no_a_quant=True, enabled=True
|
||||||
|
),
|
||||||
|
"fp8-channel-w-token-a-noquant": dict(
|
||||||
|
w="channel", a="token", no_a_quant=True, enabled=True
|
||||||
|
),
|
||||||
|
"fp8-channel-w-tensor-a-noquant": dict(
|
||||||
|
w="channel", a="tensor", no_a_quant=True, enabled=False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str):
|
||||||
|
if w_type == "tensor":
|
||||||
|
scale_b = torch.ones(1, device=device, dtype=torch.float32)
|
||||||
|
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
|
||||||
|
else:
|
||||||
|
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True)
|
||||||
|
return b_fp8.t(), scale_b_fp8
|
||||||
|
|
||||||
|
|
||||||
|
def build_fp8_runner(cfg, a, b, dtype, device):
|
||||||
|
b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device)
|
||||||
|
|
||||||
|
scale_a_const = (
|
||||||
|
torch.ones(1, device=device, dtype=torch.float32)
|
||||||
|
if cfg["a"] == "tensor"
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
if cfg["a"] == "tensor":
|
||||||
|
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
|
||||||
|
else:
|
||||||
|
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
if cfg["a"] == "tensor":
|
||||||
|
|
||||||
|
def run():
|
||||||
|
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const)
|
||||||
|
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run():
|
||||||
|
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||||
|
return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=_enabled,
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs FP8 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
run_quant = build_fp8_runner(cfg, a, b, dtype, device)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_quant(), quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
out = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
out.append(KN)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
show_plots=True,
|
||||||
|
save_path=f"bench_fp8_res_n{N}_k{K}",
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
||||||
169
benchmarks/kernels/bench_int8_gemm.py
Normal file
169
benchmarks/kernels/bench_int8_gemm.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
|
||||||
|
from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"torch-bf16": dict(enabled=True),
|
||||||
|
"int8-tensor-w-token-a": dict(
|
||||||
|
w="tensor", a="token", no_a_quant=False, enabled=False
|
||||||
|
),
|
||||||
|
"int8-tensor-w-tensor-a": dict(
|
||||||
|
w="tensor", a="tensor", no_a_quant=False, enabled=True
|
||||||
|
),
|
||||||
|
"int8-channel-w-token-a": dict(
|
||||||
|
w="channel", a="token", no_a_quant=False, enabled=True
|
||||||
|
),
|
||||||
|
"int8-channel-w-tensor-a": dict(
|
||||||
|
w="channel", a="tensor", no_a_quant=False, enabled=False
|
||||||
|
),
|
||||||
|
"int8-tensor-w-token-a-noquant": dict(
|
||||||
|
w="tensor", a="token", no_a_quant=True, enabled=False
|
||||||
|
),
|
||||||
|
"int8-tensor-w-tensor-a-noquant": dict(
|
||||||
|
w="tensor", a="tensor", no_a_quant=True, enabled=True
|
||||||
|
),
|
||||||
|
"int8-channel-w-token-a-noquant": dict(
|
||||||
|
w="channel", a="token", no_a_quant=True, enabled=True
|
||||||
|
),
|
||||||
|
"int8-channel-w-tensor-a-noquant": dict(
|
||||||
|
w="channel", a="tensor", no_a_quant=True, enabled=False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_weight(b, w_type, device):
|
||||||
|
if w_type == "tensor":
|
||||||
|
scale_b = torch.ones(1, device=device, dtype=torch.float32)
|
||||||
|
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b)
|
||||||
|
assert scale_b_int8.numel() == 1
|
||||||
|
else: # channel
|
||||||
|
b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b)
|
||||||
|
assert scale_b_int8.numel() == b.shape[0]
|
||||||
|
return b_int8.t(), scale_b_int8
|
||||||
|
|
||||||
|
|
||||||
|
def build_int8_runner(cfg, a, b, dtype, device):
|
||||||
|
# quant before running the kernel
|
||||||
|
b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device)
|
||||||
|
|
||||||
|
scale_a_const = None
|
||||||
|
if cfg["a"] == "tensor":
|
||||||
|
scale_a_const = torch.ones(1, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# no quant, create activation ahead
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
if cfg["a"] == "tensor":
|
||||||
|
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
|
||||||
|
else: # token
|
||||||
|
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
|
||||||
|
|
||||||
|
def run_quant():
|
||||||
|
return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
|
||||||
|
|
||||||
|
return run_quant
|
||||||
|
|
||||||
|
# dynamic quant, create activation inside
|
||||||
|
if cfg["a"] == "tensor":
|
||||||
|
|
||||||
|
def run_quant():
|
||||||
|
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const)
|
||||||
|
return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
|
||||||
|
|
||||||
|
else: # token
|
||||||
|
|
||||||
|
def run_quant():
|
||||||
|
a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a)
|
||||||
|
return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype)
|
||||||
|
|
||||||
|
return run_quant
|
||||||
|
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=[k for k in _enabled],
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs INT8 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
run_quant = build_int8_runner(cfg, a, b, dtype, device)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_quant(), quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
KN_model_names = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
KN_model_names.append(KN)
|
||||||
|
return KN_model_names
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
help="List of models to benchmark",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp-sizes",
|
||||||
|
nargs="+",
|
||||||
|
type=int,
|
||||||
|
default=[1],
|
||||||
|
help="List of tensor parallel sizes",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
show_plots=True,
|
||||||
|
save_path=f"bench_int8_res_n{N}_k{K}",
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
||||||
191
benchmarks/kernels/bench_mxfp4_qutlass.py
Normal file
191
benchmarks/kernels/bench_mxfp4_qutlass.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||||
|
# All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm._custom_ops import fusedQuantizeMx, matmul_mxf4_bf16_tn
|
||||||
|
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"torch-bf16": dict(enabled=True),
|
||||||
|
"mxfp4": dict(no_a_quant=False, enabled=True),
|
||||||
|
"mxfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||||
|
return (
|
||||||
|
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||||
|
* group_size**-0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_weight_mxfp4(
|
||||||
|
b: torch.Tensor, forward_hadamard_matrix: torch.Tensor, device: str
|
||||||
|
):
|
||||||
|
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeMx(
|
||||||
|
b, forward_hadamard_matrix, method="abs_max"
|
||||||
|
)
|
||||||
|
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton")
|
||||||
|
return weight_hf_e2m1, weight_hf_scale_block
|
||||||
|
|
||||||
|
|
||||||
|
def build_mxfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device):
|
||||||
|
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_mxfp4(
|
||||||
|
b, forward_hadamard_matrix, device
|
||||||
|
)
|
||||||
|
alpha = torch.tensor([1.0], device="cuda")
|
||||||
|
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
# Pre-quantize activation
|
||||||
|
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
|
||||||
|
a, forward_hadamard_matrix, method="abs_max"
|
||||||
|
)
|
||||||
|
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return matmul_mxf4_bf16_tn(
|
||||||
|
input_hf_e2m1,
|
||||||
|
weight_hf_e2m1,
|
||||||
|
input_hf_scale_block,
|
||||||
|
weight_hf_scale_block,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
# Quantize activation on-the-fly
|
||||||
|
def run():
|
||||||
|
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeMx(
|
||||||
|
a, forward_hadamard_matrix, method="abs_max"
|
||||||
|
)
|
||||||
|
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton")
|
||||||
|
return matmul_mxf4_bf16_tn(
|
||||||
|
input_hf_e2m1,
|
||||||
|
weight_hf_e2m1,
|
||||||
|
input_hf_scale_block,
|
||||||
|
weight_hf_scale_block,
|
||||||
|
alpha,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[
|
||||||
|
1,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
32,
|
||||||
|
64,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
2048,
|
||||||
|
4096,
|
||||||
|
8192,
|
||||||
|
16384,
|
||||||
|
24576,
|
||||||
|
32768,
|
||||||
|
],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=_enabled,
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs MXFP4 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K, had_size):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||||
|
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
run_quant = build_mxfp4_runner(
|
||||||
|
cfg, a, b, forward_hadamard_matrix, dtype, device
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_quant(), rep=200, quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
out = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
out.append(KN)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.3-70B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
for had_size in [32, 64, 128]:
|
||||||
|
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs MXFP4 GEMMs TFLOP/s:")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
show_plots=True,
|
||||||
|
save_path=f"bench_mxfp4_res_n{N}_k{K}",
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
had_size=had_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
||||||
198
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
198
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
@@ -0,0 +1,198 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
if not current_platform.has_device_capability(100):
|
||||||
|
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||||
|
|
||||||
|
|
||||||
|
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||||
|
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"torch-bf16": dict(enabled=True),
|
||||||
|
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||||
|
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||||
|
"fbgemm-nvfp4": dict(fbgemm=True, no_a_quant=False, enabled=True),
|
||||||
|
"fbgemm-nvfp4-noquant": dict(fbgemm=True, no_a_quant=True, enabled=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
_needs_fbgemm = any(
|
||||||
|
v.get("fbgemm", False) for v in PROVIDER_CFGS.values() if v.get("enabled", False)
|
||||||
|
)
|
||||||
|
if _needs_fbgemm:
|
||||||
|
try:
|
||||||
|
from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
|
||||||
|
triton_scale_nvfp4_quant,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"WARNING: FBGEMM providers are enabled but fbgemm_gpu is not installed. "
|
||||||
|
"These providers will be skipped. Please install fbgemm_gpu with: "
|
||||||
|
"'pip install fbgemm-gpu-genai' to run them."
|
||||||
|
)
|
||||||
|
# Disable FBGEMM providers so the benchmark can run.
|
||||||
|
for cfg in PROVIDER_CFGS.values():
|
||||||
|
if cfg.get("fbgemm"):
|
||||||
|
cfg["enabled"] = False
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_weight_nvfp4(b: torch.Tensor, device: str, cfg):
|
||||||
|
# Compute global scale for weight
|
||||||
|
b_amax = torch.abs(b).max().to(torch.float32)
|
||||||
|
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||||
|
if "fbgemm" in cfg and cfg["fbgemm"]:
|
||||||
|
b_fp4, scale_b_fp4 = triton_scale_nvfp4_quant(b, b_global_scale)
|
||||||
|
else:
|
||||||
|
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
|
||||||
|
return b_fp4, scale_b_fp4, b_global_scale
|
||||||
|
|
||||||
|
|
||||||
|
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||||
|
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device, cfg)
|
||||||
|
|
||||||
|
# Compute global scale for activation
|
||||||
|
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||||
|
a_amax = torch.abs(a).max().to(torch.float32)
|
||||||
|
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||||
|
|
||||||
|
# Alpha for the GEMM operation
|
||||||
|
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||||
|
if "fbgemm" in cfg and cfg["fbgemm"]:
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return torch.ops.fbgemm.f4f4bf16(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
scale_a_fp4,
|
||||||
|
scale_b_fp4,
|
||||||
|
global_scale=alpha,
|
||||||
|
use_mx=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
else:
|
||||||
|
|
||||||
|
def run():
|
||||||
|
a_fp4, scale_a_fp4 = triton_scale_nvfp4_quant(a, a_global_scale)
|
||||||
|
return torch.ops.fbgemm.f4f4bf16(
|
||||||
|
a_fp4,
|
||||||
|
b_fp4,
|
||||||
|
scale_a_fp4,
|
||||||
|
scale_b_fp4,
|
||||||
|
global_scale=alpha,
|
||||||
|
use_mx=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
# Pre-quantize activation
|
||||||
|
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return ops.cutlass_scaled_fp4_mm(
|
||||||
|
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
# Quantize activation on-the-fly
|
||||||
|
def run():
|
||||||
|
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||||
|
return ops.cutlass_scaled_fp4_mm(
|
||||||
|
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=_enabled,
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs NVFP4 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_quant(), quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
out = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
out.append(KN)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||||
|
save_dir = f"bench_nvfp4_res_n{N}_k{K}"
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
show_plots=True,
|
||||||
|
save_path=save_dir,
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
||||||
207
benchmarks/kernels/bench_nvfp4_qutlass.py
Normal file
207
benchmarks/kernels/bench_nvfp4_qutlass.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 Roberto L. Castro (Roberto.LopezCastro@ist.ac.at).
|
||||||
|
# All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.transform.utils.hadamard import deterministic_hadamard_matrix
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops # use existing nvfp4 gemm in vllm
|
||||||
|
from vllm._custom_ops import fusedQuantizeNv
|
||||||
|
from vllm.model_executor.layers.quantization.qutlass_utils import to_blocked
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
PROVIDER_CFGS = {
|
||||||
|
"torch-bf16": dict(enabled=True),
|
||||||
|
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||||
|
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||||
|
}
|
||||||
|
|
||||||
|
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||||
|
|
||||||
|
|
||||||
|
def get_hadamard_matrix(group_size: int, dtype: torch.dtype, device: torch.device):
|
||||||
|
return (
|
||||||
|
deterministic_hadamard_matrix(group_size, dtype=dtype, device=device)
|
||||||
|
* group_size**-0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _quant_weight_nvfp4(
|
||||||
|
b: torch.Tensor,
|
||||||
|
forward_hadamard_matrix: torch.Tensor,
|
||||||
|
global_scale: torch.Tensor,
|
||||||
|
device: str,
|
||||||
|
M: int,
|
||||||
|
N: int,
|
||||||
|
K: int,
|
||||||
|
):
|
||||||
|
weight_hf_e2m1, weight_hf_e8m0 = fusedQuantizeNv(
|
||||||
|
b, forward_hadamard_matrix, global_scale
|
||||||
|
)
|
||||||
|
weight_hf_scale_block = to_blocked(weight_hf_e8m0, backend="triton").view(
|
||||||
|
-1, K // 16
|
||||||
|
)
|
||||||
|
return weight_hf_e2m1, weight_hf_scale_block
|
||||||
|
|
||||||
|
|
||||||
|
def build_nvfp4_runner(cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K):
|
||||||
|
alpha = torch.tensor([1.0], device="cuda")
|
||||||
|
global_scale = torch.tensor([1.0], device="cuda")
|
||||||
|
weight_hf_e2m1, weight_hf_scale_block = _quant_weight_nvfp4(
|
||||||
|
b, forward_hadamard_matrix, global_scale, device, M, N, K
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg["no_a_quant"]:
|
||||||
|
# Pre-quantize activation
|
||||||
|
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
|
||||||
|
a, forward_hadamard_matrix, global_scale
|
||||||
|
)
|
||||||
|
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
|
||||||
|
-1, K // 16
|
||||||
|
)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
return ops.cutlass_scaled_fp4_mm(
|
||||||
|
input_hf_e2m1,
|
||||||
|
weight_hf_e2m1,
|
||||||
|
input_hf_scale_block,
|
||||||
|
weight_hf_scale_block,
|
||||||
|
alpha,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
# Quantize activation on-the-fly
|
||||||
|
def run():
|
||||||
|
input_hf_e2m1, input_hf_e8m0 = fusedQuantizeNv(
|
||||||
|
a, forward_hadamard_matrix, global_scale
|
||||||
|
)
|
||||||
|
input_hf_scale_block = to_blocked(input_hf_e8m0, backend="triton").view(
|
||||||
|
-1, K // 16
|
||||||
|
)
|
||||||
|
return ops.cutlass_scaled_fp4_mm(
|
||||||
|
input_hf_e2m1,
|
||||||
|
weight_hf_e2m1,
|
||||||
|
input_hf_scale_block,
|
||||||
|
weight_hf_scale_block,
|
||||||
|
alpha,
|
||||||
|
torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
return run
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size"],
|
||||||
|
x_vals=[
|
||||||
|
1,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
32,
|
||||||
|
64,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
2048,
|
||||||
|
4096,
|
||||||
|
8192,
|
||||||
|
16384,
|
||||||
|
24576,
|
||||||
|
32768,
|
||||||
|
],
|
||||||
|
x_log=False,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=_enabled,
|
||||||
|
line_names=_enabled,
|
||||||
|
ylabel="TFLOP/s (larger is better)",
|
||||||
|
plot_name="BF16 vs NVFP4 GEMMs",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(batch_size, provider, N, K, had_size):
|
||||||
|
M = batch_size
|
||||||
|
device = "cuda"
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||||
|
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||||
|
forward_hadamard_matrix = get_hadamard_matrix(had_size, dtype, device)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "torch-bf16":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: torch.nn.functional.linear(a, b), rep=200, quantiles=quantiles
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cfg = PROVIDER_CFGS[provider]
|
||||||
|
run_quant = build_nvfp4_runner(
|
||||||
|
cfg, a, b, forward_hadamard_matrix, dtype, device, M, N, K
|
||||||
|
)
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
lambda: run_quant(), rep=200, quantiles=quantiles
|
||||||
|
)
|
||||||
|
|
||||||
|
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||||
|
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_shapes(args):
|
||||||
|
out = []
|
||||||
|
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||||
|
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||||
|
KN[tp_dim] //= tp_size
|
||||||
|
KN.append(model)
|
||||||
|
out.append(KN)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=["meta-llama/Llama-3.3-70B-Instruct"],
|
||||||
|
choices=list(WEIGHT_SHAPES.keys()),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for K, N, model in prepare_shapes(args):
|
||||||
|
for had_size in [16, 32, 64, 128]:
|
||||||
|
print(f"{model}, N={N} K={K}, HAD={had_size}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||||
|
benchmark.run(
|
||||||
|
print_data=True,
|
||||||
|
show_plots=True,
|
||||||
|
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||||
|
N=N,
|
||||||
|
K=K,
|
||||||
|
had_size=had_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Benchmark finished!")
|
||||||
270
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
270
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import itertools
|
||||||
|
from collections.abc import Callable
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
|
||||||
|
|
||||||
|
def with_triton_mode(fn):
|
||||||
|
"""Temporarily force the Triton fallback path"""
|
||||||
|
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(luka): use standalone_compile utility
|
||||||
|
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||||
|
def inner(*args):
|
||||||
|
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||||
|
return fn(*args)
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
def bench_compile(fn: Callable):
|
||||||
|
# recompile for different shapes
|
||||||
|
fwd = torch.compile(fn, fullgraph=True, dynamic=False)
|
||||||
|
|
||||||
|
# First dim is explicitly dynamic to simulate vLLM usage
|
||||||
|
return with_dyn_arg(fwd, 0, 0)
|
||||||
|
|
||||||
|
|
||||||
|
torch._dynamo.config.recompile_limit = 8888
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(
|
||||||
|
batch_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
group_shape: GroupShape,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
"""Calculate the difference between Inductor and CUDA implementations."""
|
||||||
|
device = torch.device("cuda")
|
||||||
|
x = torch.randn((batch_size, hidden_size), dtype=dtype, device=device)
|
||||||
|
|
||||||
|
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=False)
|
||||||
|
|
||||||
|
torch_out, torch_scale = bench_compile(quant_fp8.forward_native)(x)
|
||||||
|
torch_eager_out, torch_eager_scale = quant_fp8.forward_native(x)
|
||||||
|
cuda_out, cuda_scale = quant_fp8.forward_cuda(x)
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.testing.assert_close(
|
||||||
|
cuda_out.to(torch.float32),
|
||||||
|
torch_out.to(torch.float32),
|
||||||
|
rtol=1e-3,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
cuda_out.to(torch.float32),
|
||||||
|
torch_eager_out.to(torch.float32),
|
||||||
|
rtol=1e-3,
|
||||||
|
atol=1e-5,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(cuda_scale, torch_eager_scale, rtol=1e-3, atol=1e-5)
|
||||||
|
print("✅ All implementations match")
|
||||||
|
except AssertionError as e:
|
||||||
|
print("❌ Implementations differ")
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
|
||||||
|
configs = []
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_quantization(
|
||||||
|
batch_size,
|
||||||
|
hidden_size,
|
||||||
|
provider,
|
||||||
|
group_shape: GroupShape,
|
||||||
|
col_major: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
device = torch.device("cuda")
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, hidden_size, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
quant_fp8 = QuantFP8(False, group_shape, column_major_scales=col_major)
|
||||||
|
|
||||||
|
if provider == "torch":
|
||||||
|
fn = lambda: bench_compile(quant_fp8.forward_native)(x.clone())
|
||||||
|
elif provider == "cuda":
|
||||||
|
fn = lambda: quant_fp8.forward_cuda(x.clone())
|
||||||
|
elif provider == "triton":
|
||||||
|
if not group_shape.is_per_group():
|
||||||
|
# Triton only supported for per-group
|
||||||
|
return 0, 0, 0
|
||||||
|
|
||||||
|
fn = lambda: with_triton_mode(quant_fp8.forward_cuda)(x.clone())
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(luka) extract to utils
|
||||||
|
def compute_geomean_speedups(
|
||||||
|
df: pd.DataFrame,
|
||||||
|
baseline_col: str,
|
||||||
|
speedup_cols: list[str],
|
||||||
|
groupby_cols: list[str] | None = None,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
Compute geometric mean speedups over a baseline column.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: Input dataframe
|
||||||
|
baseline_col: Column to use as baseline
|
||||||
|
speedup_cols: Columns to compute speedups for
|
||||||
|
groupby_cols: Columns to group by. If None, compute over entire df.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pd.DataFrame with geometric mean speedups
|
||||||
|
"""
|
||||||
|
from scipy.stats import gmean
|
||||||
|
|
||||||
|
def geo_speedup(group: pd.DataFrame) -> pd.Series:
|
||||||
|
ratios = {
|
||||||
|
col: (group[baseline_col] / group[col]).values for col in speedup_cols
|
||||||
|
}
|
||||||
|
return pd.Series({col: gmean(vals) for col, vals in ratios.items()})
|
||||||
|
|
||||||
|
if groupby_cols is None:
|
||||||
|
result = geo_speedup(df).to_frame().T
|
||||||
|
else:
|
||||||
|
result = (
|
||||||
|
df.groupby(groupby_cols)
|
||||||
|
.apply(geo_speedup, include_groups=False)
|
||||||
|
.reset_index()
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the various implementations of QuantFP8 (dynamic-only)"
|
||||||
|
)
|
||||||
|
parser.add_argument("-c", "--check", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden-sizes",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=[896, 1024, 2048, 4096, 7168],
|
||||||
|
help="Hidden sizes to benchmark",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-sizes",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=[1, 16, 128, 512, 1024],
|
||||||
|
help="Batch sizes to benchmark",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--group-sizes",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=None,
|
||||||
|
help="Group sizes for GroupShape(1,N) to benchmark. "
|
||||||
|
"Use 0 for PER_TENSOR, -1 for PER_TOKEN (default: 0,-1,64,128)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-column-major",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable column-major scales testing",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
assert args
|
||||||
|
|
||||||
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
||||||
|
|
||||||
|
hidden_sizes = args.hidden_sizes
|
||||||
|
batch_sizes = args.batch_sizes
|
||||||
|
|
||||||
|
if args.group_sizes is not None:
|
||||||
|
group_shapes = []
|
||||||
|
for size in args.group_sizes:
|
||||||
|
if size == 0:
|
||||||
|
group_shapes.append(GroupShape.PER_TENSOR)
|
||||||
|
elif size == -1:
|
||||||
|
group_shapes.append(GroupShape.PER_TOKEN)
|
||||||
|
else:
|
||||||
|
group_shapes.append(GroupShape(1, size))
|
||||||
|
else:
|
||||||
|
group_shapes = [
|
||||||
|
GroupShape.PER_TENSOR,
|
||||||
|
GroupShape.PER_TOKEN,
|
||||||
|
GroupShape(1, 64),
|
||||||
|
GroupShape(1, 128),
|
||||||
|
]
|
||||||
|
|
||||||
|
column_major_scales = [False] if args.no_column_major else [True, False]
|
||||||
|
|
||||||
|
config_gen = itertools.product(
|
||||||
|
group_shapes,
|
||||||
|
column_major_scales,
|
||||||
|
batch_sizes,
|
||||||
|
hidden_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# filter out column-major scales for non-group, reverse order
|
||||||
|
configs.extend(c[::-1] for c in config_gen if (c[0].is_per_group() or not c[1]))
|
||||||
|
|
||||||
|
print(f"Running {len(configs)} configurations:")
|
||||||
|
print(f" Hidden sizes: {hidden_sizes}")
|
||||||
|
print(f" Batch sizes: {batch_sizes}")
|
||||||
|
print(f" Group shapes: {[str(g) for g in group_shapes]}")
|
||||||
|
print(f" Column major scales: {column_major_scales}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if args.check:
|
||||||
|
for group_shape in group_shapes:
|
||||||
|
group_size = group_shape[1]
|
||||||
|
print(f"{group_size=}")
|
||||||
|
calculate_diff(
|
||||||
|
batch_size=4, hidden_size=4096, group_shape=group_shape, dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
benchmark = triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["hidden_size", "batch_size", "col_major", "group_shape"],
|
||||||
|
x_vals=configs,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["torch", "cuda", "triton"],
|
||||||
|
line_names=["Torch (Compiled)", "CUDA", "Triton"],
|
||||||
|
styles=[("blue", "-"), ("green", "-"), ("black", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name="QuantFP8 performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)(benchmark_quantization)
|
||||||
|
|
||||||
|
df = benchmark.run(print_data=True, dtype=dtype, return_df=True)
|
||||||
|
|
||||||
|
# Print geomean speedups
|
||||||
|
geo_table_grouped = compute_geomean_speedups(
|
||||||
|
df,
|
||||||
|
baseline_col="Torch (Compiled)",
|
||||||
|
speedup_cols=["CUDA", "Triton"],
|
||||||
|
groupby_cols=["col_major", "group_shape"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Speedup over Torch (Compiled)")
|
||||||
|
print(geo_table_grouped.to_string(index=False))
|
||||||
244
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
Normal file
244
benchmarks/kernels/benchmark_2d_silu_mul_fp8_quant.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from itertools import product
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
_per_token_group_quant_fp8_colmajor,
|
||||||
|
silu_mul_per_token_group_quant_fp8_colmajor,
|
||||||
|
)
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||||
|
|
||||||
|
from .utils import ArgPool, Bench, CudaGraphBenchParams
|
||||||
|
|
||||||
|
GROUP_SIZE = 128
|
||||||
|
FLOAT8_T = torch.float8_e4m3fn
|
||||||
|
|
||||||
|
|
||||||
|
def print_timers(timers: list[TMeasurement], cuda_graph_nops: int):
|
||||||
|
print(
|
||||||
|
f"Note : The timings reported above is for {cuda_graph_nops} "
|
||||||
|
"consecutive invocations of the benchmarking functions. "
|
||||||
|
f"Please divide by {cuda_graph_nops} for single invocation "
|
||||||
|
"timings."
|
||||||
|
)
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
class ImplType(Enum):
|
||||||
|
SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR = 1
|
||||||
|
REFERENCE = 2
|
||||||
|
|
||||||
|
def get_impl(self):
|
||||||
|
if self == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
|
||||||
|
return silu_mul_per_token_group_quant_fp8_colmajor
|
||||||
|
elif self == ImplType.REFERENCE:
|
||||||
|
return reference
|
||||||
|
raise ValueError(f"Unrecognized ImplType {self}")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkTensors:
|
||||||
|
input: torch.Tensor
|
||||||
|
output: torch.Tensor
|
||||||
|
|
||||||
|
# Reference act output tensor
|
||||||
|
ref_act_out: torch.Tensor
|
||||||
|
ref_quant_out: torch.Tensor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make(T: int, N: int) -> "BenchmarkTensors":
|
||||||
|
assert T % GROUP_SIZE == 0
|
||||||
|
assert N % (GROUP_SIZE * 2) == 0
|
||||||
|
|
||||||
|
input = torch.rand((T, N), dtype=torch.bfloat16, device="cuda")
|
||||||
|
|
||||||
|
# silu_mul_per_token_group_quant_fp8_colmajor output.
|
||||||
|
output = torch.rand((T, N // 2), dtype=torch.bfloat16, device="cuda").to(
|
||||||
|
FLOAT8_T
|
||||||
|
)
|
||||||
|
|
||||||
|
# reference output.
|
||||||
|
ref_act_out = torch.empty((T, N // 2), dtype=torch.bfloat16, device="cuda")
|
||||||
|
ref_quant_out = torch.empty(
|
||||||
|
(T, N // 2), dtype=torch.bfloat16, device="cuda"
|
||||||
|
).to(FLOAT8_T)
|
||||||
|
|
||||||
|
return BenchmarkTensors(
|
||||||
|
input=input,
|
||||||
|
output=output,
|
||||||
|
ref_act_out=ref_act_out,
|
||||||
|
ref_quant_out=ref_quant_out,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def T(self):
|
||||||
|
return self.input.size(0)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def N(self):
|
||||||
|
return self.input.size(1)
|
||||||
|
|
||||||
|
def make_impl_kwargs(self, impl_type: ImplType) -> dict[str, Any]:
|
||||||
|
if impl_type == ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR:
|
||||||
|
return {
|
||||||
|
"input": self.input,
|
||||||
|
"output": self.output,
|
||||||
|
"use_ue8m0": is_deep_gemm_e8m0_used(),
|
||||||
|
}
|
||||||
|
elif impl_type == ImplType.REFERENCE:
|
||||||
|
return {
|
||||||
|
"input": self.input,
|
||||||
|
"act_out": self.ref_act_out,
|
||||||
|
"quant_out": self.ref_quant_out,
|
||||||
|
"use_ue8m0": is_deep_gemm_e8m0_used(),
|
||||||
|
}
|
||||||
|
raise ValueError(f"Unrecognized impl_type {impl_type}")
|
||||||
|
|
||||||
|
|
||||||
|
def reference_quant(x: torch.Tensor, quant_out: torch.Tensor, use_ue8m0: bool):
|
||||||
|
"""
|
||||||
|
Reference triton quant kernel from,
|
||||||
|
vllm.model_executor.layers.quantization.utils.fp8_utils
|
||||||
|
"""
|
||||||
|
assert quant_out.size() == x.size()
|
||||||
|
# Allocate the scale tensor column-major format.
|
||||||
|
shape = (x.shape[-1] // GROUP_SIZE,) + x.shape[:-1]
|
||||||
|
x_q = quant_out
|
||||||
|
x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2)
|
||||||
|
|
||||||
|
M = x.numel() // GROUP_SIZE
|
||||||
|
N = GROUP_SIZE
|
||||||
|
BLOCK = triton.next_power_of_2(N)
|
||||||
|
# heuristics for number of warps
|
||||||
|
num_warps = min(max(BLOCK // 256, 1), 8)
|
||||||
|
num_stages = 1
|
||||||
|
|
||||||
|
finfo = torch.finfo(FLOAT8_T)
|
||||||
|
fp8_min = finfo.min
|
||||||
|
fp8_max = finfo.max
|
||||||
|
|
||||||
|
_per_token_group_quant_fp8_colmajor[(M,)](
|
||||||
|
x,
|
||||||
|
x_q,
|
||||||
|
x_s,
|
||||||
|
GROUP_SIZE,
|
||||||
|
x.shape[1],
|
||||||
|
x.stride(0),
|
||||||
|
x_s.stride(1),
|
||||||
|
eps=1e-10,
|
||||||
|
fp8_min=fp8_min,
|
||||||
|
fp8_max=fp8_max,
|
||||||
|
use_ue8m0=use_ue8m0,
|
||||||
|
BLOCK=BLOCK,
|
||||||
|
num_warps=num_warps,
|
||||||
|
num_stages=num_stages,
|
||||||
|
)
|
||||||
|
return x_q, x_s
|
||||||
|
|
||||||
|
|
||||||
|
def reference(
|
||||||
|
input: torch.Tensor,
|
||||||
|
act_out: torch.Tensor,
|
||||||
|
quant_out: torch.Tensor,
|
||||||
|
use_ue8m0: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
torch.ops._C.silu_and_mul(act_out, input)
|
||||||
|
return reference_quant(act_out, quant_out, use_ue8m0)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_impl(
|
||||||
|
bench_tensors: list[BenchmarkTensors], impl_type: ImplType
|
||||||
|
) -> TMeasurement:
|
||||||
|
T = bench_tensors[0].T
|
||||||
|
N = bench_tensors[0].N
|
||||||
|
|
||||||
|
arg_pool_size = len(bench_tensors)
|
||||||
|
kwargs_list = [bt.make_impl_kwargs(impl_type) for bt in bench_tensors]
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
for kwargs in kwargs_list:
|
||||||
|
impl_type.get_impl()(**kwargs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Merge into a single kwargs and qualify arguments as ArgPool
|
||||||
|
kwargs = {k: ArgPool([]) for k in kwargs_list[0]}
|
||||||
|
for _kwargs in kwargs_list:
|
||||||
|
for k, v in _kwargs.items():
|
||||||
|
kwargs[k].values.append(v)
|
||||||
|
|
||||||
|
cuda_graph_params = None
|
||||||
|
cuda_graph_params = CudaGraphBenchParams(arg_pool_size)
|
||||||
|
timer = None
|
||||||
|
with Bench(
|
||||||
|
cuda_graph_params,
|
||||||
|
"silu-mul-quant",
|
||||||
|
f"num_tokens={T}, N={N}",
|
||||||
|
impl_type.name,
|
||||||
|
impl_type.get_impl(),
|
||||||
|
**kwargs,
|
||||||
|
) as bench:
|
||||||
|
timer = bench.run()
|
||||||
|
return timer
|
||||||
|
|
||||||
|
|
||||||
|
def test_correctness(T: int, N: int):
|
||||||
|
print(f"Testing num_tokens={T}, N={N} ...")
|
||||||
|
|
||||||
|
bench_tensor = BenchmarkTensors.make(T, N)
|
||||||
|
|
||||||
|
def output_from_impl(impl: ImplType) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
return impl.get_impl()(**bench_tensor.make_impl_kwargs(impl))
|
||||||
|
|
||||||
|
# reference output
|
||||||
|
ref_out_q, ref_out_s = output_from_impl(ImplType.REFERENCE)
|
||||||
|
|
||||||
|
# test ouptut
|
||||||
|
out_q, out_s = output_from_impl(
|
||||||
|
ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(ref_out_q.to(torch.float32), out_q.to(torch.float32))
|
||||||
|
torch.testing.assert_close(ref_out_s, out_s)
|
||||||
|
|
||||||
|
|
||||||
|
def run(Ts: list[int], Ns: list[int], arg_pool_size: int) -> list[TMeasurement]:
|
||||||
|
timers = []
|
||||||
|
for N, T in product(Ns, Ts):
|
||||||
|
test_correctness(T, N)
|
||||||
|
|
||||||
|
bench_tensors: list[BenchmarkTensors] = [
|
||||||
|
BenchmarkTensors.make(T, N) for _ in range(arg_pool_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
silu_mul_quant_timer = bench_impl(
|
||||||
|
bench_tensors, ImplType.SILU_MUL_PER_TOKEN_GROUP_QUANT_FP8_COLMAJOR
|
||||||
|
)
|
||||||
|
timers.append(silu_mul_quant_timer)
|
||||||
|
reference_timer = bench_impl(bench_tensors, ImplType.REFERENCE)
|
||||||
|
timers.append(reference_timer)
|
||||||
|
|
||||||
|
print_timers(
|
||||||
|
[silu_mul_quant_timer, reference_timer], cuda_graph_nops=arg_pool_size
|
||||||
|
)
|
||||||
|
|
||||||
|
print_timers(timers, cuda_graph_nops=arg_pool_size)
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
T = [128 * i for i in range(1, 16)] + [2048 * i for i in range(1, 65)]
|
||||||
|
N = [2048, 4096, 8192]
|
||||||
|
|
||||||
|
print(f"T = {T}, N = {N}")
|
||||||
|
run(T, N, arg_pool_size=8)
|
||||||
105
benchmarks/kernels/benchmark_activation.py
Normal file
105
benchmarks/kernels/benchmark_activation.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# benchmark custom activation op performance
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm.model_executor.layers.activation # noqa F401
|
||||||
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
|
||||||
|
batch_size_range = [1, 16, 32, 64, 128]
|
||||||
|
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||||
|
intermediate_size = [3072, 9728, 12288]
|
||||||
|
configs = list(itertools.product(batch_size_range, seq_len_range, intermediate_size))
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_activation(
|
||||||
|
batch_size: int,
|
||||||
|
seq_len: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
provider: str,
|
||||||
|
func_name: str,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
device = "cuda"
|
||||||
|
num_tokens = batch_size * seq_len
|
||||||
|
dim = intermediate_size
|
||||||
|
current_platform.seed_everything(42)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
|
||||||
|
if func_name == "gelu_and_mul":
|
||||||
|
layer = CustomOp.op_registry[func_name](approximate="none")
|
||||||
|
elif func_name == "gelu_and_mul_tanh":
|
||||||
|
layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh")
|
||||||
|
elif func_name == "fatrelu_and_mul":
|
||||||
|
threshold = 0.5
|
||||||
|
layer = CustomOp.op_registry[func_name](threshold)
|
||||||
|
else:
|
||||||
|
layer = CustomOp.op_registry[func_name]()
|
||||||
|
|
||||||
|
x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
|
||||||
|
compiled_layer = torch.compile(layer.forward_native)
|
||||||
|
|
||||||
|
if provider == "custom":
|
||||||
|
fn = lambda: layer(x)
|
||||||
|
elif provider == "compiled":
|
||||||
|
fn = lambda: compiled_layer(x)
|
||||||
|
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||||
|
fn, quantiles=[0.5, 0.2, 0.8]
|
||||||
|
)
|
||||||
|
return ms, max_ms, min_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(description="Benchmark the custom activation op.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--func-name",
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"mul_and_silu",
|
||||||
|
"silu_and_mul",
|
||||||
|
"gelu_and_mul",
|
||||||
|
"gelu_and_mul_tanh",
|
||||||
|
"fatrelu_and_mul",
|
||||||
|
"swigluoai_and_mul",
|
||||||
|
"gelu_new",
|
||||||
|
"gelu_fast",
|
||||||
|
"quick_gelu",
|
||||||
|
],
|
||||||
|
default="silu_and_mul",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="bfloat16"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
assert args
|
||||||
|
|
||||||
|
func_name = args.func_name
|
||||||
|
dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype]
|
||||||
|
|
||||||
|
perf_report = triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["batch_size", "seq_len", "intermediate_size"],
|
||||||
|
x_vals=configs,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["custom", "compiled"],
|
||||||
|
line_names=["Custom OP", "Compiled"],
|
||||||
|
styles=[("blue", "-"), ("green", "-")],
|
||||||
|
ylabel="ms",
|
||||||
|
plot_name=f"{func_name}-op-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
perf_report(
|
||||||
|
lambda batch_size, seq_len, intermediate_size, provider: benchmark_activation(
|
||||||
|
batch_size, seq_len, intermediate_size, provider, func_name, dtype
|
||||||
|
)
|
||||||
|
).run(print_data=True)
|
||||||
@@ -1,302 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
|
||||||
from vllm.model_executor.layers.quantization.aqlm import (
|
|
||||||
dequantize_weight, generic_dequantize_gemm, get_int_dtype,
|
|
||||||
optimized_dequantize_gemm)
|
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
||||||
|
|
||||||
|
|
||||||
def torch_mult(
|
|
||||||
input: torch.Tensor, # [..., in_features]
|
|
||||||
weights: torch.Tensor,
|
|
||||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
|
||||||
) -> torch.Tensor:
|
|
||||||
output = F.linear(input, weights)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def dequant_out_scale(
|
|
||||||
input: torch.Tensor, # [..., in_features]
|
|
||||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
|
||||||
codebooks: torch.
|
|
||||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
|
||||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
|
||||||
output_partition_sizes: torch.IntTensor,
|
|
||||||
bias: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
|
||||||
|
|
||||||
if bias is None:
|
|
||||||
output = F.linear(input, weights, bias)
|
|
||||||
orig_shape = output.shape
|
|
||||||
flattened_output = output.view(-1, output.size(-1))
|
|
||||||
f_scales = scales.view(-1, scales.shape[0])
|
|
||||||
b_scales = f_scales.expand(flattened_output.shape[0], -1)
|
|
||||||
flattened_output *= b_scales
|
|
||||||
return flattened_output.view(orig_shape)
|
|
||||||
else:
|
|
||||||
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
|
||||||
-1, weights.shape[1])
|
|
||||||
weights *= b_scales
|
|
||||||
return F.linear(input, weights, bias)
|
|
||||||
|
|
||||||
|
|
||||||
def dequant_weight_scale(
|
|
||||||
input: torch.Tensor, # [..., in_features]
|
|
||||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
|
||||||
codebooks: torch.
|
|
||||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
|
||||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
|
||||||
output_partition_sizes: torch.IntTensor,
|
|
||||||
bias: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
|
||||||
|
|
||||||
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
|
|
||||||
-1, weights.shape[1])
|
|
||||||
weights *= b_scales
|
|
||||||
return F.linear(input, weights, bias)
|
|
||||||
|
|
||||||
|
|
||||||
def dequant_no_scale(
|
|
||||||
input: torch.Tensor, # [..., in_features]
|
|
||||||
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
|
|
||||||
codebooks: torch.
|
|
||||||
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
|
|
||||||
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
|
|
||||||
output_partition_sizes: torch.IntTensor,
|
|
||||||
bias: Optional[torch.Tensor],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
|
|
||||||
|
|
||||||
return F.linear(input, weights, bias)
|
|
||||||
|
|
||||||
|
|
||||||
# Compare the optimized 1x16 and 2x8 cuda decompression/dequant kernels against
|
|
||||||
# the generic pytorch version.
|
|
||||||
# Just visual comparison.
|
|
||||||
def dequant_test(k: int, parts: torch.tensor, nbooks: int, bits: int) -> None:
|
|
||||||
|
|
||||||
n = parts.sum().item()
|
|
||||||
|
|
||||||
device = torch.device('cuda:0')
|
|
||||||
|
|
||||||
code_range = (1 << bits) // 2
|
|
||||||
ingroups = 8
|
|
||||||
|
|
||||||
codes = torch.randint(-code_range,
|
|
||||||
code_range,
|
|
||||||
size=(n, k // ingroups, nbooks),
|
|
||||||
dtype=get_int_dtype(bits),
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
|
||||||
dtype=torch.float16,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
count = 0
|
|
||||||
for index in range(16):
|
|
||||||
for i in range(8):
|
|
||||||
for book in range(nbooks):
|
|
||||||
codebooks[book, index, 0, i] = count * (10**book)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
print("codes shape", codes.shape)
|
|
||||||
|
|
||||||
for i in range(16):
|
|
||||||
for book in range(nbooks):
|
|
||||||
codes[0, i, book] = i
|
|
||||||
codes[0, -i, book] = i
|
|
||||||
|
|
||||||
weights = dequantize_weight(codes, codebooks, None)
|
|
||||||
weights2 = ops.aqlm_dequant(codes, codebooks, parts)
|
|
||||||
|
|
||||||
print("weights shape:", weights.shape)
|
|
||||||
print("weights2 shape:", weights2.shape)
|
|
||||||
|
|
||||||
print("weights are:", weights)
|
|
||||||
print("weights2 are:", weights2)
|
|
||||||
|
|
||||||
print("first 128 weights are", weights[0, 0:128].to(torch.int32))
|
|
||||||
print("first 128 weights2 are:", weights2[0, 0:128].to(torch.int32))
|
|
||||||
|
|
||||||
print("last 128 weights are", weights[0, -128:])
|
|
||||||
print("last 128 weights2 are:", weights2[0, -128:])
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="Benchmark aqlm performance.")
|
|
||||||
|
|
||||||
# Add arguments
|
|
||||||
parser.add_argument("--nbooks",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of codebooks (default: 1)")
|
|
||||||
parser.add_argument("--bits",
|
|
||||||
type=int,
|
|
||||||
default=16,
|
|
||||||
help="Number of bits per code element (default: 16)")
|
|
||||||
parser.add_argument(
|
|
||||||
"--test",
|
|
||||||
type=bool,
|
|
||||||
default=False,
|
|
||||||
help="Run the decompression/dequant tester rather than benchmarking "
|
|
||||||
"(default: False)")
|
|
||||||
|
|
||||||
# Parse the arguments
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Extract values
|
|
||||||
nbooks = args.nbooks
|
|
||||||
bits = args.bits
|
|
||||||
|
|
||||||
if args.test:
|
|
||||||
dequant_test(4096, torch.tensor((4096, )), nbooks, bits)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Otherwise, benchmark.
|
|
||||||
methods = [
|
|
||||||
ops.aqlm_gemm,
|
|
||||||
dequant_out_scale,
|
|
||||||
generic_dequantize_gemm,
|
|
||||||
optimized_dequantize_gemm,
|
|
||||||
dequant_weight_scale,
|
|
||||||
torch_mult,
|
|
||||||
dequant_no_scale,
|
|
||||||
]
|
|
||||||
|
|
||||||
filename = f"./aqlm_benchmark_{nbooks}x{bits}.csv"
|
|
||||||
print(f"writing benchmarks to file {filename}")
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
sys.stdout = f
|
|
||||||
|
|
||||||
print('m | k | n | n parts', end='')
|
|
||||||
for method in methods:
|
|
||||||
print(f" | {method.__name__.replace('_', ' ')} (µs)", end='')
|
|
||||||
print('')
|
|
||||||
|
|
||||||
# These are reasonable prefill sizes.
|
|
||||||
ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )),
|
|
||||||
(4096, (11008, 11008)), (11008, (4096, )))
|
|
||||||
|
|
||||||
# reasonable ranges for m.
|
|
||||||
for m in [
|
|
||||||
1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112,
|
|
||||||
128, 256, 512, 1024, 1536, 2048, 3072, 4096
|
|
||||||
]:
|
|
||||||
print(f'{m}', file=sys.__stdout__)
|
|
||||||
for ksp in ksandpartions:
|
|
||||||
run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits,
|
|
||||||
methods)
|
|
||||||
|
|
||||||
sys.stdout = sys.__stdout__
|
|
||||||
|
|
||||||
|
|
||||||
def run_grid(m: int, k: int, parts: torch.tensor, nbooks: int, bits: int,
|
|
||||||
methods):
|
|
||||||
|
|
||||||
# I didn't see visible improvements from increasing these, but feel free :)
|
|
||||||
num_warmup_trials = 1
|
|
||||||
num_trials = 1
|
|
||||||
|
|
||||||
num_calls = 100
|
|
||||||
|
|
||||||
# warmup.
|
|
||||||
for method in methods:
|
|
||||||
for _ in range(num_warmup_trials):
|
|
||||||
run_timing(
|
|
||||||
num_calls=num_calls,
|
|
||||||
m=m,
|
|
||||||
k=k,
|
|
||||||
parts=parts,
|
|
||||||
nbooks=nbooks,
|
|
||||||
bits=bits,
|
|
||||||
method=method,
|
|
||||||
)
|
|
||||||
|
|
||||||
n = parts.sum().item()
|
|
||||||
print(f'{m} | {k} | {n} | {parts.tolist()}', end='')
|
|
||||||
|
|
||||||
for method in methods:
|
|
||||||
best_time_us = 1e20
|
|
||||||
for _ in range(num_trials):
|
|
||||||
kernel_dur_ms = run_timing(
|
|
||||||
num_calls=num_calls,
|
|
||||||
m=m,
|
|
||||||
k=k,
|
|
||||||
parts=parts,
|
|
||||||
nbooks=nbooks,
|
|
||||||
bits=bits,
|
|
||||||
method=method,
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel_dur_us = 1000 * kernel_dur_ms
|
|
||||||
|
|
||||||
if kernel_dur_us < best_time_us:
|
|
||||||
best_time_us = kernel_dur_us
|
|
||||||
|
|
||||||
print(f' | {kernel_dur_us:.0f}', end='')
|
|
||||||
|
|
||||||
print('')
|
|
||||||
|
|
||||||
|
|
||||||
def run_timing(num_calls: int, m: int, k: int, parts: torch.tensor,
|
|
||||||
nbooks: int, bits: int, method) -> float:
|
|
||||||
|
|
||||||
n = parts.sum().item()
|
|
||||||
|
|
||||||
device = torch.device('cuda:0')
|
|
||||||
|
|
||||||
input = torch.randn((1, m, k), dtype=torch.float16, device=device)
|
|
||||||
|
|
||||||
code_range = (1 << bits) // 2
|
|
||||||
ingroups = 8
|
|
||||||
|
|
||||||
codes = torch.randint(-code_range,
|
|
||||||
code_range,
|
|
||||||
size=(n, k // ingroups, nbooks),
|
|
||||||
dtype=get_int_dtype(bits),
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8),
|
|
||||||
dtype=torch.float16,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device)
|
|
||||||
|
|
||||||
# for comparison to just a pytorch mult.
|
|
||||||
weights = torch.randn((n, k), dtype=torch.float16, device=device)
|
|
||||||
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
start_event.record()
|
|
||||||
|
|
||||||
if method is torch_mult:
|
|
||||||
for i in range(num_calls):
|
|
||||||
torch_mult(input, weights, scales)
|
|
||||||
else:
|
|
||||||
for i in range(num_calls):
|
|
||||||
method(input, codes, codebooks, scales, parts, None)
|
|
||||||
|
|
||||||
end_event.record()
|
|
||||||
end_event.synchronize()
|
|
||||||
|
|
||||||
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
|
||||||
return dur_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
sys.exit(main())
|
|
||||||
244
benchmarks/kernels/benchmark_bitblas.py
Normal file
244
benchmarks/kernels/benchmark_bitblas.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Copyright (c) Microsoft Corporation.
|
||||||
|
# Licensed under the MIT License.
|
||||||
|
|
||||||
|
from packaging import version
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
|
||||||
|
MINIMUM_BITBLAS_VERSION,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import bitblas
|
||||||
|
|
||||||
|
if version.parse(bitblas.__version__) < version.parse(MINIMUM_BITBLAS_VERSION):
|
||||||
|
raise ImportError(
|
||||||
|
"bitblas version is wrong. Please "
|
||||||
|
f"install bitblas>={MINIMUM_BITBLAS_VERSION}"
|
||||||
|
)
|
||||||
|
except ImportError as e:
|
||||||
|
bitblas_import_exception = e
|
||||||
|
raise ValueError(
|
||||||
|
"Trying to use the bitblas backend, but could not import"
|
||||||
|
f"with the following error: {bitblas_import_exception}. "
|
||||||
|
"Please install bitblas through the following command: "
|
||||||
|
f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`"
|
||||||
|
) from bitblas_import_exception
|
||||||
|
|
||||||
|
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
|
||||||
|
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark BitBLAS int4 on a specific target."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add arguments to the parser
|
||||||
|
parser.add_argument(
|
||||||
|
"--target",
|
||||||
|
type=str,
|
||||||
|
default=auto_detect_nvidia_target(),
|
||||||
|
help="Specify the target device for benchmarking.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--group_size", type=int, default=None, help="Group size for grouped quantization."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--A_dtype",
|
||||||
|
type=str,
|
||||||
|
default="float16",
|
||||||
|
choices=["float16", "float32", "float64", "int32", "int8"],
|
||||||
|
help="Data type of activation A.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--W_dtype",
|
||||||
|
type=str,
|
||||||
|
default="int4",
|
||||||
|
choices=[
|
||||||
|
"float16",
|
||||||
|
"float32",
|
||||||
|
"float64",
|
||||||
|
"int32",
|
||||||
|
"int8",
|
||||||
|
"int4",
|
||||||
|
"int2",
|
||||||
|
"int1",
|
||||||
|
"nf4",
|
||||||
|
"fp4_e2m1",
|
||||||
|
],
|
||||||
|
help="Data type of weight W.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--accum_dtype",
|
||||||
|
type=str,
|
||||||
|
default="float16",
|
||||||
|
choices=["float16", "int32"],
|
||||||
|
help="Data type for accumulation.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out_dtype",
|
||||||
|
type=str,
|
||||||
|
default="float16",
|
||||||
|
choices=["float16", "float32", "int32", "int8"],
|
||||||
|
help="Data type for output.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--layout",
|
||||||
|
type=str,
|
||||||
|
default="nt",
|
||||||
|
choices=["nt", "nn"],
|
||||||
|
help="Matrix layout, 'nt' for non-transpose A and transpose W.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--with_bias", action="store_true", help="Include bias in the benchmark."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--with_scaling",
|
||||||
|
action="store_true",
|
||||||
|
help="Include scaling factor in the quantization.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--with_zeros", action="store_true", help="Include zeros in the quantization."
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--zeros_mode",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["original", "rescale", "quantized"],
|
||||||
|
help="Specify the mode for calculating zeros.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the arguments
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Assign arguments to variables
|
||||||
|
target = args.target
|
||||||
|
A_dtype = args.A_dtype
|
||||||
|
W_dtype = args.W_dtype
|
||||||
|
accum_dtype = args.accum_dtype
|
||||||
|
out_dtype = args.out_dtype
|
||||||
|
layout = args.layout
|
||||||
|
with_bias = args.with_bias
|
||||||
|
group_size = args.group_size
|
||||||
|
with_scaling = args.with_scaling
|
||||||
|
with_zeros = args.with_zeros
|
||||||
|
zeros_mode = args.zeros_mode
|
||||||
|
|
||||||
|
# Define a list of shared arguments that repeat in every config
|
||||||
|
shared_args = [
|
||||||
|
A_dtype,
|
||||||
|
W_dtype,
|
||||||
|
out_dtype,
|
||||||
|
accum_dtype,
|
||||||
|
layout,
|
||||||
|
with_bias,
|
||||||
|
group_size,
|
||||||
|
with_scaling,
|
||||||
|
with_zeros,
|
||||||
|
zeros_mode,
|
||||||
|
]
|
||||||
|
|
||||||
|
# Define just the (M, K, N) shapes in a more compact list
|
||||||
|
shapes = [
|
||||||
|
# square test
|
||||||
|
(1, 16384, 16384),
|
||||||
|
# BLOOM-176B
|
||||||
|
(1, 43008, 14336),
|
||||||
|
(1, 14336, 14336),
|
||||||
|
(1, 57344, 14336),
|
||||||
|
(1, 14336, 57344),
|
||||||
|
# OPT-65B
|
||||||
|
(1, 9216, 9216),
|
||||||
|
(1, 36864, 9216),
|
||||||
|
(1, 9216, 36864),
|
||||||
|
(1, 22016, 8192),
|
||||||
|
# LLAMA-70B/65B
|
||||||
|
(1, 8192, 22016),
|
||||||
|
(1, 8192, 8192),
|
||||||
|
(1, 28672, 8192),
|
||||||
|
(1, 8192, 28672),
|
||||||
|
# square test
|
||||||
|
(16384, 16384, 16384),
|
||||||
|
# BLOOM-176B
|
||||||
|
(8192, 43008, 14336),
|
||||||
|
(8192, 14336, 14336),
|
||||||
|
(8192, 57344, 14336),
|
||||||
|
(8192, 14336, 57344),
|
||||||
|
# OPT-65B
|
||||||
|
(8192, 9216, 9216),
|
||||||
|
(8192, 36864, 9216),
|
||||||
|
(8192, 9216, 36864),
|
||||||
|
(8192, 22016, 8192),
|
||||||
|
# LLAMA-70B/65B
|
||||||
|
(8192, 8192, 22016),
|
||||||
|
(8192, 8192, 8192),
|
||||||
|
(8192, 28672, 8192),
|
||||||
|
(8192, 8192, 28672),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Build test shapes with all the shared arguments
|
||||||
|
test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes]
|
||||||
|
|
||||||
|
benchmark_sets = []
|
||||||
|
benchmark_sets.extend(test_shapes)
|
||||||
|
|
||||||
|
benchmark_results = {}
|
||||||
|
for config_class, operator, input_args in benchmark_sets:
|
||||||
|
config = config_class(*input_args)
|
||||||
|
matmul = operator(config, target=target, enable_tuning=True)
|
||||||
|
kernel_latency = matmul.profile_latency()
|
||||||
|
|
||||||
|
print("Time cost is: {:.3f} ms".format(kernel_latency))
|
||||||
|
|
||||||
|
profile_config = {
|
||||||
|
f"{operator.__name__}-{'-'.join([str(i) for i in input_args])}": {
|
||||||
|
"BitBLAS_top20_latency": kernel_latency,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmark_results.update(profile_config)
|
||||||
|
|
||||||
|
# Define headers for the table
|
||||||
|
headers = [
|
||||||
|
"PrimFunc",
|
||||||
|
"Input Arguments",
|
||||||
|
"BitBLAS Top20 Latency",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Calculate column widths for pretty printing
|
||||||
|
col_widths = [0, 0, 0]
|
||||||
|
for config_key, values in benchmark_results.items():
|
||||||
|
args_split = config_key.split("-")
|
||||||
|
func_name = args_split[0]
|
||||||
|
input_args_str = "-".join(args_split[1:])
|
||||||
|
col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2)
|
||||||
|
col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2)
|
||||||
|
col_widths[2] = max(
|
||||||
|
col_widths[2],
|
||||||
|
len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2,
|
||||||
|
len(headers[2]) + 2,
|
||||||
|
)
|
||||||
|
# break only if you want to measure widths from a single example;
|
||||||
|
# otherwise, let it loop over all items.
|
||||||
|
|
||||||
|
# Print header
|
||||||
|
for i, header in enumerate(headers):
|
||||||
|
headers[i] = header.ljust(col_widths[i])
|
||||||
|
print("".join(headers))
|
||||||
|
print("-" * sum(col_widths))
|
||||||
|
|
||||||
|
# Print rows
|
||||||
|
for config_key, values in benchmark_results.items():
|
||||||
|
args_split = config_key.split("-")
|
||||||
|
func_name = args_split[0]
|
||||||
|
input_args_str = "-".join(args_split[1:])
|
||||||
|
row = [
|
||||||
|
func_name,
|
||||||
|
input_args_str,
|
||||||
|
f"{values['BitBLAS_top20_latency']:.3f} ms",
|
||||||
|
]
|
||||||
|
row_str = "".join(
|
||||||
|
[str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]
|
||||||
|
)
|
||||||
|
print(row_str)
|
||||||
504
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Normal file
504
benchmarks/kernels/benchmark_cutlass_fp4_moe.py
Normal file
@@ -0,0 +1,504 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe
|
||||||
|
kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit
|
||||||
|
activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8)
|
||||||
|
and 16-bit activations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import nvtx
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as benchmark
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
fp8_w8a8_moe_quant_config,
|
||||||
|
nvfp4_moe_quant_config,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||||
|
from vllm.scalar_type import scalar_types
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
WEIGHT_SHAPES_MOE = {
|
||||||
|
"nvidia/DeepSeek-R1-FP4": [
|
||||||
|
[256, 8, 2048, 7168],
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_MODELS = [
|
||||||
|
"nvidia/DeepSeek-R1-FP4",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
PER_ACT_TOKEN_OPTS = [False]
|
||||||
|
PER_OUT_CH_OPTS = [False]
|
||||||
|
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||||
|
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp8(tensor: torch.Tensor):
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||||
|
dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_run(
|
||||||
|
results: list[benchmark.Measurement],
|
||||||
|
model: str,
|
||||||
|
num_experts: int,
|
||||||
|
topk: int,
|
||||||
|
per_act_token: bool,
|
||||||
|
per_out_ch: bool,
|
||||||
|
mkn: tuple[int, int, int],
|
||||||
|
):
|
||||||
|
label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton"
|
||||||
|
|
||||||
|
sub_label = (
|
||||||
|
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
|
||||||
|
model, num_experts, topk, per_act_token, per_out_ch, mkn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Testing: {sub_label}")
|
||||||
|
|
||||||
|
(m, k, n) = mkn
|
||||||
|
|
||||||
|
dtype = torch.half
|
||||||
|
device = "cuda"
|
||||||
|
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||||
|
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
|
||||||
|
|
||||||
|
_, a_fp8_scale = ops.scaled_fp8_quant(a)
|
||||||
|
|
||||||
|
w1_fp8q = torch.empty(
|
||||||
|
(num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn)
|
||||||
|
w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||||
|
w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
for expert in range(num_experts):
|
||||||
|
w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||||
|
w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert])
|
||||||
|
|
||||||
|
w1_fp8q_notransp = w1_fp8q.clone()
|
||||||
|
w2_fp8q_notransp = w2_fp8q.clone()
|
||||||
|
w1_fp8q = w1_fp8q.transpose(1, 2)
|
||||||
|
w2_fp8q = w2_fp8q.transpose(1, 2)
|
||||||
|
|
||||||
|
score = torch.randn((m, num_experts), device=device, dtype=dtype)
|
||||||
|
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||||
|
|
||||||
|
quant_blocksize = 16
|
||||||
|
w1_blockscale = torch.empty(
|
||||||
|
(num_experts, 2 * n, k // quant_blocksize),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
)
|
||||||
|
w2_blockscale = torch.empty(
|
||||||
|
(num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
# n_b_scales = 2 * n if per_out_ch else 1
|
||||||
|
# k_b_scales = k if per_out_ch else 1
|
||||||
|
w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8)
|
||||||
|
w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8)
|
||||||
|
|
||||||
|
w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
|
||||||
|
w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32)
|
||||||
|
a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
|
||||||
|
a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
for expert in range(num_experts):
|
||||||
|
w1_e = w1[expert]
|
||||||
|
w2_e = w2[expert]
|
||||||
|
w1_amax = torch.abs(w1_e).max().to(torch.float32)
|
||||||
|
w2_amax = torch.abs(w2_e).max().to(torch.float32)
|
||||||
|
w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
|
||||||
|
w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
|
||||||
|
|
||||||
|
w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant(
|
||||||
|
w1_e, w1_gs[expert]
|
||||||
|
)
|
||||||
|
|
||||||
|
w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant(
|
||||||
|
w2_e, w2_gs[expert]
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_triton_moe(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a_fp8_scale: torch.Tensor,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_fp8_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
fused_experts(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_cutlass_moe_fp4(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1_fp4: torch.Tensor,
|
||||||
|
w2_fp4: torch.Tensor,
|
||||||
|
w1_blockscale: torch.Tensor,
|
||||||
|
w2_blockscale: torch.Tensor,
|
||||||
|
w1_gs: torch.Tensor,
|
||||||
|
w2_gs: torch.Tensor,
|
||||||
|
a1_gs: torch.Tensor,
|
||||||
|
a2_gs: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
e: int,
|
||||||
|
device: torch.device,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = nvfp4_moe_quant_config(
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w1_scale=w1_blockscale,
|
||||||
|
w2_scale=w2_blockscale,
|
||||||
|
g1_alphas=w1_gs,
|
||||||
|
g2_alphas=w2_gs,
|
||||||
|
)
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
with nvtx.annotate("cutlass_moe_fp4", color="green"):
|
||||||
|
cutlass_moe_fp4(
|
||||||
|
a=a,
|
||||||
|
w1_fp4=w1_fp4,
|
||||||
|
w2_fp4=w2_fp4,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
m=m,
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
e=num_experts,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_cutlass_from_graph(
|
||||||
|
a: torch.Tensor,
|
||||||
|
a1_gscale: torch.Tensor,
|
||||||
|
w1_fp4: torch.Tensor,
|
||||||
|
w1_blockscale: torch.Tensor,
|
||||||
|
w1_alphas: torch.Tensor,
|
||||||
|
a2_gscale: torch.Tensor,
|
||||||
|
w2_fp4: torch.Tensor,
|
||||||
|
w2_blockscale: torch.Tensor,
|
||||||
|
w2_alphas: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
e: int,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
quant_config = nvfp4_moe_quant_config(
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w1_scale=w1_blockscale,
|
||||||
|
w2_scale=w2_blockscale,
|
||||||
|
g1_alphas=w1_gs,
|
||||||
|
g2_alphas=w2_gs,
|
||||||
|
)
|
||||||
|
|
||||||
|
with set_current_vllm_config(
|
||||||
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
|
):
|
||||||
|
return cutlass_moe_fp4(
|
||||||
|
a=a,
|
||||||
|
w1_fp4=w1_fp4,
|
||||||
|
w2_fp4=w2_fp4,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
m=m,
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
e=num_experts,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_triton_from_graph(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a_fp8_scale: torch.Tensor,
|
||||||
|
):
|
||||||
|
with set_current_vllm_config(
|
||||||
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_fp8_scale,
|
||||||
|
)
|
||||||
|
return fused_experts(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def replay_graph(graph, num_repeats):
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
cutlass_stream = torch.cuda.Stream()
|
||||||
|
cutlass_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||||
|
run_cutlass_from_graph(
|
||||||
|
a=a,
|
||||||
|
a1_gscale=a1_gs,
|
||||||
|
w1_fp4=w1_fp4,
|
||||||
|
w1_blockscale=w1_blockscale,
|
||||||
|
w1_alphas=w1_gs,
|
||||||
|
a2_gscale=a2_gs,
|
||||||
|
w2_fp4=w2_fp4,
|
||||||
|
w2_blockscale=w2_blockscale,
|
||||||
|
w2_alphas=w2_gs,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
m=m,
|
||||||
|
n=n,
|
||||||
|
k=k,
|
||||||
|
e=num_experts,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
triton_stream = torch.cuda.Stream()
|
||||||
|
triton_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||||
|
run_triton_from_graph(
|
||||||
|
a,
|
||||||
|
w1_fp8q_notransp,
|
||||||
|
w2_fp8q_notransp,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
w1_fp8scale,
|
||||||
|
w2_fp8scale,
|
||||||
|
a_fp8_scale,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
min_run_time = 5
|
||||||
|
num_warmup = 5
|
||||||
|
num_runs = 25
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
# Baseline params
|
||||||
|
"w1": w1,
|
||||||
|
"w2": w2,
|
||||||
|
"score": score,
|
||||||
|
"topk": topk,
|
||||||
|
"w1_fp8q_notransp": w1_fp8q_notransp,
|
||||||
|
"w2_fp8q_notransp": w2_fp8q_notransp,
|
||||||
|
"w1_fp8scale": w1_fp8scale,
|
||||||
|
"w2_fp8scale": w2_fp8scale,
|
||||||
|
"a_fp8_scale": a_fp8_scale,
|
||||||
|
# Cutlass params
|
||||||
|
"a": a,
|
||||||
|
"a1_gscale": a1_gs,
|
||||||
|
"w1_fp4": w1_fp4,
|
||||||
|
"w1_blockscale": w1_blockscale,
|
||||||
|
"w1_alphas": w1_gs,
|
||||||
|
"a2_gscale": a2_gs,
|
||||||
|
"w2_fp4": w2_fp4,
|
||||||
|
"w2_blockscale": w2_blockscale,
|
||||||
|
"w2_alphas": w2_gs,
|
||||||
|
"topk_weights": topk_weights,
|
||||||
|
"topk_ids": topk_ids,
|
||||||
|
"m": m,
|
||||||
|
"n": n,
|
||||||
|
"k": k,
|
||||||
|
"e": num_experts,
|
||||||
|
"device": device,
|
||||||
|
# cuda graph params
|
||||||
|
"cutlass_graph": cutlass_graph,
|
||||||
|
"triton_graph": triton_graph,
|
||||||
|
# Gen params
|
||||||
|
"num_runs": num_runs,
|
||||||
|
# Kernels
|
||||||
|
"run_triton_moe": run_triton_moe,
|
||||||
|
"run_cutlass_moe_fp4": run_cutlass_moe_fp4,
|
||||||
|
"replay_graph": replay_graph,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
run_triton_moe(
|
||||||
|
a,
|
||||||
|
w1_fp8q_notransp,
|
||||||
|
w2_fp8q_notransp,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
w1_fp8scale,
|
||||||
|
w2_fp8scale,
|
||||||
|
a_fp8_scale,
|
||||||
|
num_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="triton_moe",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
replay_graph(triton_graph, num_warmup)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="replay_graph(triton_graph, num_runs)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="triton_moe_cuda_graphs",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
|
||||||
|
run_cutlass_moe_fp4(
|
||||||
|
a,
|
||||||
|
w1_fp4,
|
||||||
|
w2_fp4,
|
||||||
|
w1_blockscale,
|
||||||
|
w2_blockscale,
|
||||||
|
w1_gs,
|
||||||
|
w2_gs,
|
||||||
|
a1_gs,
|
||||||
|
a2_gs,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
num_experts,
|
||||||
|
device,
|
||||||
|
num_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="cutlass_moe_fp4",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
replay_graph(cutlass_graph, num_warmup)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="replay_graph(cutlass_graph, num_runs)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="cutlass_moe_fp4_cuda_graphs",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
results: list[benchmark.Measurement] = []
|
||||||
|
|
||||||
|
for model in args.models:
|
||||||
|
for tp in args.tp_sizes:
|
||||||
|
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||||
|
num_experts = layer[0]
|
||||||
|
topk = layer[1]
|
||||||
|
size_k = layer[2]
|
||||||
|
size_n = layer[3] // tp
|
||||||
|
|
||||||
|
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for per_act_token in PER_ACT_TOKEN_OPTS:
|
||||||
|
for per_out_ch in PER_OUT_CH_OPTS:
|
||||||
|
for size_m in args.batch_sizes:
|
||||||
|
mkn = (size_m, size_k, size_n)
|
||||||
|
bench_run(
|
||||||
|
results,
|
||||||
|
model,
|
||||||
|
num_experts,
|
||||||
|
topk,
|
||||||
|
per_act_token,
|
||||||
|
per_out_ch,
|
||||||
|
mkn,
|
||||||
|
)
|
||||||
|
|
||||||
|
compare = benchmark.Compare(results)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
406
benchmarks/kernels/benchmark_cutlass_moe_fp8.py
Normal file
@@ -0,0 +1,406 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark the performance of the cutlass_moe_fp8 kernel vs the triton_moe
|
||||||
|
kernel. Both kernels take in fp8 quantized weights and 16-bit activations,
|
||||||
|
but use different quantization strategies and backends.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import nvtx
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
# Weight shapes for different models: [num_experts, topk, hidden_size,
|
||||||
|
# intermediate_size]
|
||||||
|
WEIGHT_SHAPES_MOE = {
|
||||||
|
"mixtral-8x7b": [
|
||||||
|
[8, 2, 4096, 14336],
|
||||||
|
],
|
||||||
|
"deepseek-v2": [
|
||||||
|
[160, 6, 5120, 12288],
|
||||||
|
],
|
||||||
|
"custom-small": [
|
||||||
|
[8, 2, 2048, 7168],
|
||||||
|
],
|
||||||
|
"glm45-fp8": [
|
||||||
|
[128, 8, 4096, 1408],
|
||||||
|
],
|
||||||
|
"Llama-4-Maverick-17B-128E-Instruct-FP8": [
|
||||||
|
[128, 1, 5120, 8192],
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
DEFAULT_MODELS = [
|
||||||
|
"mixtral-8x7b",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
PER_ACT_TOKEN_OPTS = [False, True]
|
||||||
|
PER_OUT_CH_OPTS = [False, True]
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
def bench_run(
|
||||||
|
results: list,
|
||||||
|
model: str,
|
||||||
|
num_experts: int,
|
||||||
|
topk: int,
|
||||||
|
per_act_token: bool,
|
||||||
|
per_out_ch: bool,
|
||||||
|
mkn: tuple[int, int, int],
|
||||||
|
):
|
||||||
|
(m, k, n) = mkn
|
||||||
|
|
||||||
|
dtype = torch.half
|
||||||
|
device = "cuda"
|
||||||
|
|
||||||
|
# Create input activations
|
||||||
|
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||||
|
|
||||||
|
# Create weights
|
||||||
|
w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10
|
||||||
|
|
||||||
|
# Create FP8 quantized weights and scales for both kernels
|
||||||
|
w1_fp8q = torch.empty((num_experts, 2 * n, k), device=device, dtype=FP8_DTYPE)
|
||||||
|
w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=FP8_DTYPE)
|
||||||
|
|
||||||
|
# Create scales based on quantization strategy
|
||||||
|
if per_out_ch:
|
||||||
|
# Per-channel quantization
|
||||||
|
w1_scale = torch.empty(
|
||||||
|
(num_experts, 2 * n, 1), device=device, dtype=torch.float32
|
||||||
|
)
|
||||||
|
w2_scale = torch.empty((num_experts, k, 1), device=device, dtype=torch.float32)
|
||||||
|
else:
|
||||||
|
# Per-tensor quantization
|
||||||
|
w1_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||||
|
w2_scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Quantize weights
|
||||||
|
for expert in range(num_experts):
|
||||||
|
if per_out_ch:
|
||||||
|
# Per-channel quantization - not yet implemented properly
|
||||||
|
# For now, fall back to per-tensor quantization
|
||||||
|
w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
|
||||||
|
w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
|
||||||
|
# Expand scalar scales to the expected per-channel shape
|
||||||
|
w1_scale[expert] = w1_scale_temp.expand(2 * n, 1)
|
||||||
|
w2_scale[expert] = w2_scale_temp.expand(k, 1)
|
||||||
|
else:
|
||||||
|
# Per-tensor quantization
|
||||||
|
w1_fp8q[expert], w1_scale_temp = ops.scaled_fp8_quant(w1[expert])
|
||||||
|
w2_fp8q[expert], w2_scale_temp = ops.scaled_fp8_quant(w2[expert])
|
||||||
|
# Store scalar scales in [1, 1] tensors
|
||||||
|
w1_scale[expert, 0, 0] = w1_scale_temp
|
||||||
|
w2_scale[expert, 0, 0] = w2_scale_temp
|
||||||
|
|
||||||
|
# Prepare weights for CUTLASS (no transpose needed)
|
||||||
|
w1_fp8q_cutlass = w1_fp8q # Keep original [E, 2N, K]
|
||||||
|
w2_fp8q_cutlass = w2_fp8q # Keep original [E, K, N]
|
||||||
|
|
||||||
|
# Create router scores and get topk
|
||||||
|
score = torch.randn((m, num_experts), device=device, dtype=dtype)
|
||||||
|
topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)
|
||||||
|
|
||||||
|
# WORKAROUND: CUTLASS MoE FP8 has issues with per-token quantization
|
||||||
|
# Force per-tensor quantization for all cases to match working e2e setup
|
||||||
|
a1_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
|
||||||
|
a2_scale = torch.full((), 1e-2, device=device, dtype=torch.float32)
|
||||||
|
|
||||||
|
# Force per-tensor quantization for all cases
|
||||||
|
per_act_token = False
|
||||||
|
|
||||||
|
# Create stride tensors for CUTLASS
|
||||||
|
ab_strides1 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
|
||||||
|
ab_strides2 = torch.full((num_experts,), n, dtype=torch.int64, device=device)
|
||||||
|
c_strides1 = torch.full((num_experts,), 2 * n, dtype=torch.int64, device=device)
|
||||||
|
c_strides2 = torch.full((num_experts,), k, dtype=torch.int64, device=device)
|
||||||
|
|
||||||
|
def run_triton_moe(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: torch.Tensor,
|
||||||
|
a2_scale: torch.Tensor,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
fused_experts(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_cutlass_moe_fp8(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
ab_strides1: torch.Tensor,
|
||||||
|
ab_strides2: torch.Tensor,
|
||||||
|
c_strides1: torch.Tensor,
|
||||||
|
c_strides2: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a1_scale: torch.Tensor,
|
||||||
|
a2_scale: torch.Tensor,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
with nvtx.annotate("cutlass_moe_fp8", color="blue"):
|
||||||
|
cutlass_moe_fp8(
|
||||||
|
a=a,
|
||||||
|
w1_q=w1,
|
||||||
|
w2_q=w2,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
ab_strides1=ab_strides1,
|
||||||
|
ab_strides2=ab_strides2,
|
||||||
|
c_strides1=c_strides1,
|
||||||
|
c_strides2=c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
|
activation="silu",
|
||||||
|
global_num_experts=num_experts,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Pre-create quantization config to avoid creating it inside CUDA graph
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
per_out_ch_quant=per_out_ch,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create CUDA graphs for CUTLASS (match benchmark_moe.py pattern exactly)
|
||||||
|
cutlass_stream = torch.cuda.Stream()
|
||||||
|
cutlass_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||||
|
# Capture 10 invocations like benchmark_moe.py
|
||||||
|
for _ in range(10):
|
||||||
|
cutlass_moe_fp8(
|
||||||
|
a=a,
|
||||||
|
w1_q=w1_fp8q_cutlass,
|
||||||
|
w2_q=w2_fp8q_cutlass,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
ab_strides1=ab_strides1,
|
||||||
|
ab_strides2=ab_strides2,
|
||||||
|
c_strides1=c_strides1,
|
||||||
|
c_strides2=c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
|
activation="silu",
|
||||||
|
global_num_experts=num_experts,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Create CUDA graphs for Triton (match benchmark_moe.py pattern exactly)
|
||||||
|
triton_stream = torch.cuda.Stream()
|
||||||
|
triton_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||||
|
# Capture 10 invocations like benchmark_moe.py
|
||||||
|
for _ in range(10):
|
||||||
|
fused_experts(
|
||||||
|
a,
|
||||||
|
w1_fp8q,
|
||||||
|
w2_fp8q,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
def bench_cuda_graph(graph, num_warmup=5, num_iters=100):
|
||||||
|
"""Benchmark CUDA graph using events like benchmark_moe.py"""
|
||||||
|
# Warmup
|
||||||
|
for _ in range(num_warmup):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Timing
|
||||||
|
start_event = torch.Event(enable_timing=True)
|
||||||
|
end_event = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies = []
|
||||||
|
for _ in range(num_iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_event.record()
|
||||||
|
graph.replay()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
|
||||||
|
# Divide by 10 since graph contains 10 calls
|
||||||
|
return sum(latencies) / (num_iters * 10)
|
||||||
|
|
||||||
|
# Benchmark parameters
|
||||||
|
num_warmup = 5
|
||||||
|
num_iters = 100
|
||||||
|
|
||||||
|
# Benchmark only CUDA graphs (more reliable and faster)
|
||||||
|
# Benchmark Triton MoE with CUDA graphs
|
||||||
|
triton_graph_time = bench_cuda_graph(
|
||||||
|
triton_graph, num_warmup=num_warmup, num_iters=num_iters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Benchmark CUTLASS MoE with CUDA graphs
|
||||||
|
cutlass_graph_time = bench_cuda_graph(
|
||||||
|
cutlass_graph, num_warmup=num_warmup, num_iters=num_iters
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert ms to us and return results
|
||||||
|
triton_time_us = triton_graph_time * 1000
|
||||||
|
cutlass_time_us = cutlass_graph_time * 1000
|
||||||
|
|
||||||
|
return {
|
||||||
|
"batch_size": m,
|
||||||
|
"triton_time_us": triton_time_us,
|
||||||
|
"cutlass_time_us": cutlass_time_us,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
for model in args.models:
|
||||||
|
for tp in args.tp_sizes:
|
||||||
|
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||||
|
num_experts = layer[0]
|
||||||
|
topk = layer[1]
|
||||||
|
size_k = layer[2]
|
||||||
|
size_n = layer[3] // tp
|
||||||
|
|
||||||
|
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for per_act_token in args.per_act_token_opts:
|
||||||
|
for per_out_ch in args.per_out_ch_opts:
|
||||||
|
print(
|
||||||
|
f"\n=== {model}, experts={num_experts}, topk={topk},"
|
||||||
|
f"per_act={per_act_token}, per_out_ch={per_out_ch} ==="
|
||||||
|
)
|
||||||
|
|
||||||
|
config_results = []
|
||||||
|
for size_m in args.batch_sizes:
|
||||||
|
mkn = (size_m, size_k, size_n)
|
||||||
|
result = bench_run(
|
||||||
|
[], # Not used anymore
|
||||||
|
model,
|
||||||
|
num_experts,
|
||||||
|
topk,
|
||||||
|
per_act_token,
|
||||||
|
per_out_ch,
|
||||||
|
mkn,
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
config_results.append(result)
|
||||||
|
|
||||||
|
# Print results table for this configuration
|
||||||
|
if config_results:
|
||||||
|
print(
|
||||||
|
f"\n{'Batch Size':<12}"
|
||||||
|
f"{'Triton (us)':<15}"
|
||||||
|
f"{'CUTLASS (us)':<15}"
|
||||||
|
)
|
||||||
|
print("-" * 45)
|
||||||
|
for result in config_results:
|
||||||
|
print(
|
||||||
|
f"{result['batch_size']:<12}"
|
||||||
|
f"{result['triton_time_us']:<15.2f}"
|
||||||
|
f"{result['cutlass_time_us']:<15.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_results.extend(config_results)
|
||||||
|
|
||||||
|
print(f"\nTotal benchmarks completed: {len(all_results)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""Benchmark CUTLASS FP8 MOE vs Triton FP8 FUSED MOE
|
||||||
|
across specified models/shapes/batches
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
python benchmark_cutlass_moe_fp8.py \
|
||||||
|
--model "Llama-4-Maverick-17B-128E-Instruct-FP8" \
|
||||||
|
--tp-sizes 8 \
|
||||||
|
--batch-size 2 4 8 \
|
||||||
|
--per-act-token-opts false \
|
||||||
|
--per-out-ch-opts false
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-act-token-opts",
|
||||||
|
nargs="+",
|
||||||
|
type=lambda x: x.lower() == "true",
|
||||||
|
default=[False, True],
|
||||||
|
help="Per-activation token quantization options (true/false)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per-out-ch-opts",
|
||||||
|
nargs="+",
|
||||||
|
type=lambda x: x.lower() == "true",
|
||||||
|
default=[False, True],
|
||||||
|
help="Per-output channel quantization options (true/false)",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
508
benchmarks/kernels/benchmark_device_communicators.py
Normal file
508
benchmarks/kernels/benchmark_device_communicators.py
Normal file
@@ -0,0 +1,508 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
Benchmark script for device communicators:
|
||||||
|
CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
|
||||||
|
and SymmMemCommunicator (multimem, two-shot).
|
||||||
|
|
||||||
|
for NCCL symmetric memory you need to set the environment variables
|
||||||
|
NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1, otherwise NCCL does
|
||||||
|
not use fast NVLS implementation for all reduce.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
|
||||||
|
|
||||||
|
Example:
|
||||||
|
torchrun --nproc_per_node=2 benchmark_device_communicators.py
|
||||||
|
--sequence-lengths 512 1024 2048 --num-warmup 10 --num-trials 100
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
|
||||||
|
from vllm.distributed.device_communicators.pynccl import (
|
||||||
|
PyNcclCommunicator,
|
||||||
|
register_nccl_symmetric_ops,
|
||||||
|
)
|
||||||
|
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||||
|
set_graph_pool_id,
|
||||||
|
)
|
||||||
|
from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Default sequence lengths to benchmark
|
||||||
|
DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192]
|
||||||
|
|
||||||
|
# Fixed hidden size and dtype for all benchmarks
|
||||||
|
HIDDEN_SIZE = 8192
|
||||||
|
BENCHMARK_DTYPE = torch.bfloat16
|
||||||
|
|
||||||
|
# CUDA graph settings
|
||||||
|
CUDA_GRAPH_CAPTURE_CYCLES = 10
|
||||||
|
|
||||||
|
|
||||||
|
class CommunicatorBenchmark:
|
||||||
|
"""Benchmark class for testing device communicators."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
rank: int,
|
||||||
|
world_size: int,
|
||||||
|
device: torch.device,
|
||||||
|
cpu_group: ProcessGroup,
|
||||||
|
sequence_lengths: list[int],
|
||||||
|
):
|
||||||
|
self.rank = rank
|
||||||
|
self.world_size = world_size
|
||||||
|
self.device = device
|
||||||
|
self.cpu_group = cpu_group
|
||||||
|
|
||||||
|
# Calculate max_size_override based on largest sequence length
|
||||||
|
max_seq_len = max(sequence_lengths)
|
||||||
|
max_tensor_elements = max_seq_len * HIDDEN_SIZE
|
||||||
|
self.max_size_override = max_tensor_elements * BENCHMARK_DTYPE.itemsize + 1
|
||||||
|
|
||||||
|
# Initialize communicators
|
||||||
|
self.custom_allreduce = None
|
||||||
|
self.pynccl_comm = None
|
||||||
|
self.symm_mem_comm = None
|
||||||
|
self.symm_mem_comm_multimem = None
|
||||||
|
self.symm_mem_comm_two_shot = None
|
||||||
|
|
||||||
|
self._init_communicators()
|
||||||
|
|
||||||
|
def _init_communicators(self):
|
||||||
|
"""Initialize all available communicators."""
|
||||||
|
try:
|
||||||
|
self.custom_allreduce = CustomAllreduce(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
max_size=self.max_size_override,
|
||||||
|
)
|
||||||
|
if not self.custom_allreduce.disabled:
|
||||||
|
logger.info("Rank %s: CustomAllreduce initialized", self.rank)
|
||||||
|
else:
|
||||||
|
logger.info("Rank %s: CustomAllreduce disabled", self.rank)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Rank %s: Failed to initialize CustomAllreduce: %s", self.rank, e
|
||||||
|
)
|
||||||
|
self.custom_allreduce = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.pynccl_comm = PyNcclCommunicator(
|
||||||
|
group=self.cpu_group, device=self.device
|
||||||
|
)
|
||||||
|
if not self.pynccl_comm.disabled:
|
||||||
|
logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
|
||||||
|
register_nccl_symmetric_ops(self.pynccl_comm)
|
||||||
|
else:
|
||||||
|
logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
|
||||||
|
self.pynccl_comm = None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Rank %s: Failed to initialize PyNcclCommunicator: %s", self.rank, e
|
||||||
|
)
|
||||||
|
self.pynccl_comm = None
|
||||||
|
|
||||||
|
# Initialize variants for SymmMemCommunicator
|
||||||
|
try:
|
||||||
|
self.symm_mem_comm_multimem = SymmMemCommunicator(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
force_multimem=True,
|
||||||
|
max_size_override=self.max_size_override,
|
||||||
|
)
|
||||||
|
if not self.symm_mem_comm_multimem.disabled:
|
||||||
|
logger.info(
|
||||||
|
"Rank %s: SymmMemCommunicator (multimem) initialized", self.rank
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.symm_mem_comm_multimem = None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Rank %s: Failed to initialize SymmMemCommunicator (multimem): %s",
|
||||||
|
self.rank,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self.symm_mem_comm_multimem = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.symm_mem_comm_two_shot = SymmMemCommunicator(
|
||||||
|
group=self.cpu_group,
|
||||||
|
device=self.device,
|
||||||
|
force_multimem=False,
|
||||||
|
max_size_override=self.max_size_override,
|
||||||
|
)
|
||||||
|
if not self.symm_mem_comm_two_shot.disabled:
|
||||||
|
logger.info(
|
||||||
|
"Rank %s: SymmMemCommunicator (two_shot) initialized", self.rank
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.symm_mem_comm_two_shot = None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Rank %s: Failed to initialize SymmMemCommunicator (two_shot): %s",
|
||||||
|
self.rank,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self.symm_mem_comm_two_shot = None
|
||||||
|
|
||||||
|
def benchmark_allreduce(
|
||||||
|
self, sequence_length: int, num_warmup: int, num_trials: int
|
||||||
|
) -> dict[str, float]:
|
||||||
|
"""Benchmark allreduce operations for all available communicators."""
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
# Define communicators with their benchmark functions
|
||||||
|
communicators = []
|
||||||
|
|
||||||
|
if self.custom_allreduce is not None:
|
||||||
|
comm = self.custom_allreduce
|
||||||
|
# CustomAllreduce one-shot
|
||||||
|
communicators.append(
|
||||||
|
(
|
||||||
|
"ca_1stage",
|
||||||
|
lambda t, c=comm: c.custom_all_reduce(t),
|
||||||
|
lambda t, c=comm: c.should_custom_ar(t),
|
||||||
|
comm.capture(),
|
||||||
|
"1stage", # env variable value
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# CustomAllreduce two-shot
|
||||||
|
communicators.append(
|
||||||
|
(
|
||||||
|
"ca_2stage",
|
||||||
|
lambda t, c=comm: c.custom_all_reduce(t),
|
||||||
|
lambda t, c=comm: c.should_custom_ar(t),
|
||||||
|
comm.capture(),
|
||||||
|
"2stage", # env variable value
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.pynccl_comm is not None:
|
||||||
|
comm = self.pynccl_comm
|
||||||
|
communicators.append(
|
||||||
|
(
|
||||||
|
"pynccl",
|
||||||
|
lambda t, c=comm: c.all_reduce(t),
|
||||||
|
lambda t: True, # Always available if initialized
|
||||||
|
nullcontext(),
|
||||||
|
None, # no env variable needed
|
||||||
|
)
|
||||||
|
)
|
||||||
|
communicators.append(
|
||||||
|
(
|
||||||
|
"pynccl-symm",
|
||||||
|
lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
|
||||||
|
lambda t: True, # Always available if initialized
|
||||||
|
nullcontext(),
|
||||||
|
None, # no env variable needed
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.symm_mem_comm_multimem is not None:
|
||||||
|
comm = self.symm_mem_comm_multimem
|
||||||
|
communicators.append(
|
||||||
|
(
|
||||||
|
"symm_mem_multimem",
|
||||||
|
lambda t, c=comm: c.all_reduce(t),
|
||||||
|
lambda t, c=comm: c.should_use_symm_mem(t),
|
||||||
|
nullcontext(),
|
||||||
|
None, # no env variable needed
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.symm_mem_comm_two_shot is not None:
|
||||||
|
comm = self.symm_mem_comm_two_shot
|
||||||
|
communicators.append(
|
||||||
|
(
|
||||||
|
"symm_mem_two_shot",
|
||||||
|
lambda t, c=comm: c.all_reduce(t),
|
||||||
|
lambda t, c=comm: c.should_use_symm_mem(t),
|
||||||
|
nullcontext(),
|
||||||
|
None, # no env variable needed
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Benchmark each communicator
|
||||||
|
for name, allreduce_fn, should_use_fn, context, env_var in communicators:
|
||||||
|
# Set environment variable if needed
|
||||||
|
if env_var is not None:
|
||||||
|
os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var
|
||||||
|
else:
|
||||||
|
# Clear the environment variable to avoid interference
|
||||||
|
os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None)
|
||||||
|
|
||||||
|
latency = self.benchmark_allreduce_single(
|
||||||
|
sequence_length,
|
||||||
|
allreduce_fn,
|
||||||
|
should_use_fn,
|
||||||
|
context,
|
||||||
|
num_warmup,
|
||||||
|
num_trials,
|
||||||
|
)
|
||||||
|
if latency is not None:
|
||||||
|
results[name] = latency
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def benchmark_allreduce_single(
|
||||||
|
self,
|
||||||
|
sequence_length: int,
|
||||||
|
allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
|
||||||
|
should_use_fn: Callable[[torch.Tensor], bool],
|
||||||
|
context,
|
||||||
|
num_warmup: int,
|
||||||
|
num_trials: int,
|
||||||
|
) -> float | None:
|
||||||
|
"""Benchmark method with CUDA graph optimization."""
|
||||||
|
try:
|
||||||
|
# Create test tensor (2D: sequence_length x hidden_size)
|
||||||
|
tensor = torch.randn(
|
||||||
|
sequence_length, HIDDEN_SIZE, dtype=BENCHMARK_DTYPE, device=self.device
|
||||||
|
)
|
||||||
|
if not should_use_fn(tensor):
|
||||||
|
return None
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
graph_input = tensor.clone()
|
||||||
|
|
||||||
|
# Warmup before capture
|
||||||
|
for _ in range(3):
|
||||||
|
allreduce_fn(graph_input)
|
||||||
|
|
||||||
|
# Capture the graph using context manager
|
||||||
|
with context:
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
graph_pool = torch.cuda.graph_pool_handle()
|
||||||
|
set_graph_pool_id(graph_pool)
|
||||||
|
with torch.cuda.graph(graph, pool=graph_pool):
|
||||||
|
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
||||||
|
allreduce_fn(graph_input)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
for _ in range(num_warmup):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
for _ in range(num_trials):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
|
||||||
|
# Convert to ms and divide by CUDA_GRAPH_CAPTURE_CYCLES
|
||||||
|
return (
|
||||||
|
(end_time - start_time) / num_trials / CUDA_GRAPH_CAPTURE_CYCLES * 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("CUDA graph benchmark failed: %s", e)
|
||||||
|
raise RuntimeError(
|
||||||
|
f"CUDA graph benchmark failed for communicator: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_speedup_info(comm_results: dict[str, float]) -> str:
|
||||||
|
"""Calculate speedup information for a single tensor size."""
|
||||||
|
if not comm_results:
|
||||||
|
return "N/A"
|
||||||
|
|
||||||
|
# Find the fastest communicator
|
||||||
|
fastest_comm = min(comm_results.keys(), key=lambda k: comm_results[k])
|
||||||
|
fastest_time = comm_results[fastest_comm]
|
||||||
|
|
||||||
|
# Calculate speedup vs PyNccl if available
|
||||||
|
if "pynccl" in comm_results:
|
||||||
|
pynccl_time = comm_results["pynccl"]
|
||||||
|
speedup = pynccl_time / fastest_time
|
||||||
|
return f"{fastest_comm} ({speedup:.2f}x)"
|
||||||
|
else:
|
||||||
|
return f"{fastest_comm} (N/A)"
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(
|
||||||
|
results: dict[str, dict[str, float]], sequence_lengths: list[int], world_size: int
|
||||||
|
):
|
||||||
|
"""Print benchmark results in a formatted table."""
|
||||||
|
|
||||||
|
print(f"\n{'=' * 130}")
|
||||||
|
print("Device Communicator Benchmark Results")
|
||||||
|
print(
|
||||||
|
f"World Size: {world_size}, Data Type: {BENCHMARK_DTYPE}, "
|
||||||
|
f"Hidden Size: {HIDDEN_SIZE}"
|
||||||
|
)
|
||||||
|
print(f"{'=' * 130}")
|
||||||
|
|
||||||
|
# Get all communicator names
|
||||||
|
all_comms = set()
|
||||||
|
for size_results in results.values():
|
||||||
|
all_comms.update(size_results.keys())
|
||||||
|
|
||||||
|
all_comms = sorted(list(all_comms))
|
||||||
|
|
||||||
|
# Print header
|
||||||
|
header = f"{'Tensor Shape':<20}{'Tensor Size':<15}"
|
||||||
|
for comm in all_comms:
|
||||||
|
header += f"{comm:<20}"
|
||||||
|
header += f"{'Best (Speedup vs PyNccl)':<30}"
|
||||||
|
print(header)
|
||||||
|
print("-" * len(header))
|
||||||
|
|
||||||
|
# Print results for each sequence length
|
||||||
|
for seq_len in sequence_lengths:
|
||||||
|
if seq_len in results:
|
||||||
|
# Calculate tensor size in elements and bytes
|
||||||
|
tensor_elements = seq_len * HIDDEN_SIZE
|
||||||
|
tensor_bytes = tensor_elements * BENCHMARK_DTYPE.itemsize
|
||||||
|
|
||||||
|
# Format tensor size (MB)
|
||||||
|
tensor_size_mb = tensor_bytes / (1024 * 1024)
|
||||||
|
tensor_size_str = f"{tensor_size_mb:.2f} MB"
|
||||||
|
|
||||||
|
# Format tensor shape
|
||||||
|
tensor_shape = f"({seq_len}, {HIDDEN_SIZE})"
|
||||||
|
|
||||||
|
row = f"{tensor_shape:<20}{tensor_size_str:<15}"
|
||||||
|
for comm in all_comms:
|
||||||
|
if comm in results[seq_len]:
|
||||||
|
row += f"{results[seq_len][comm]:<20.3f}"
|
||||||
|
else:
|
||||||
|
row += f"{'N/A':<20}"
|
||||||
|
|
||||||
|
# Calculate speedup information
|
||||||
|
speedup_info = _calculate_speedup_info(results[seq_len])
|
||||||
|
row += f"{speedup_info:<30}"
|
||||||
|
|
||||||
|
print(row)
|
||||||
|
|
||||||
|
print(f"{'=' * 130}")
|
||||||
|
print("All times are in milliseconds (ms) per allreduce operation")
|
||||||
|
print("Speedup column shows: fastest_algorithm (speedup_vs_pynccl)")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = FlexibleArgumentParser(description="Benchmark device communicators")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--sequence-lengths",
|
||||||
|
type=int,
|
||||||
|
nargs="+",
|
||||||
|
default=DEFAULT_SEQUENCE_LENGTHS,
|
||||||
|
help="Sequence lengths to benchmark (tensor shape: seq_len x hidden_size)",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-warmup", type=int, default=5, help="Number of warmup iterations"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-trials", type=int, default=50, help="Number of benchmark trials"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--output-json", type=str, help="Output results to JSON file")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Initialize distributed
|
||||||
|
if not dist.is_initialized():
|
||||||
|
dist.init_process_group(backend="gloo")
|
||||||
|
rank = dist.get_rank()
|
||||||
|
world_size = dist.get_world_size()
|
||||||
|
|
||||||
|
# Set device
|
||||||
|
device = torch.device(f"cuda:{rank}")
|
||||||
|
torch.cuda.set_device(device)
|
||||||
|
|
||||||
|
# Get CPU process group
|
||||||
|
cpu_group = dist.new_group(backend="gloo")
|
||||||
|
|
||||||
|
# Disable USE_SYMM_MEM to avoid affecting the max_sizes
|
||||||
|
# in symm_mem and custom_all_reduce for benchmark
|
||||||
|
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"
|
||||||
|
|
||||||
|
# Initialize benchmark
|
||||||
|
benchmark = CommunicatorBenchmark(
|
||||||
|
rank, world_size, device, cpu_group, args.sequence_lengths
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run benchmarks
|
||||||
|
all_results = {}
|
||||||
|
|
||||||
|
for seq_len in args.sequence_lengths:
|
||||||
|
if rank == 0:
|
||||||
|
logger.info(
|
||||||
|
"Benchmarking sequence length: %s (tensor shape: %s x %s)",
|
||||||
|
seq_len,
|
||||||
|
seq_len,
|
||||||
|
HIDDEN_SIZE,
|
||||||
|
)
|
||||||
|
|
||||||
|
results = benchmark.benchmark_allreduce(
|
||||||
|
sequence_length=seq_len,
|
||||||
|
num_warmup=args.num_warmup,
|
||||||
|
num_trials=args.num_trials,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_results[seq_len] = results
|
||||||
|
|
||||||
|
# Synchronize between ranks
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
# Print results (only rank 0)
|
||||||
|
if rank == 0:
|
||||||
|
print_results(all_results, args.sequence_lengths, world_size)
|
||||||
|
|
||||||
|
# Save to JSON if requested
|
||||||
|
if args.output_json:
|
||||||
|
# Add speedup information to results
|
||||||
|
enhanced_results = {}
|
||||||
|
for seq_len, comm_results in all_results.items():
|
||||||
|
enhanced_results[seq_len] = {
|
||||||
|
"timings": comm_results,
|
||||||
|
"speedup_info": _calculate_speedup_info(comm_results),
|
||||||
|
}
|
||||||
|
|
||||||
|
output_data = {
|
||||||
|
"world_size": world_size,
|
||||||
|
"dtype": str(BENCHMARK_DTYPE),
|
||||||
|
"hidden_size": HIDDEN_SIZE,
|
||||||
|
"sequence_lengths": args.sequence_lengths,
|
||||||
|
"num_warmup": args.num_warmup,
|
||||||
|
"num_trials": args.num_trials,
|
||||||
|
"cuda_graph_capture_cycles": CUDA_GRAPH_CAPTURE_CYCLES,
|
||||||
|
"results": enhanced_results,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(args.output_json, "w") as f:
|
||||||
|
json.dump(output_data, f, indent=2)
|
||||||
|
|
||||||
|
logger.info("Results saved to %s", args.output_json)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if cpu_group != dist.group.WORLD:
|
||||||
|
dist.destroy_process_group(cpu_group)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1129
benchmarks/kernels/benchmark_fused_collective.py
Normal file
1129
benchmarks/kernels/benchmark_fused_collective.py
Normal file
File diff suppressed because it is too large
Load Diff
427
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
427
benchmarks/kernels/benchmark_grouped_gemm_cutlass.py
Normal file
@@ -0,0 +1,427 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as benchmark
|
||||||
|
from benchmark_shapes import WEIGHT_SHAPES_MOE
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import fp8_w8a8_moe_quant_config
|
||||||
|
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||||
|
fused_experts,
|
||||||
|
fused_topk,
|
||||||
|
)
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
DEFAULT_MODELS = [
|
||||||
|
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||||
|
"deepseek-ai/DeepSeek-V2-Lite",
|
||||||
|
"ibm-granite/granite-3.0-1b-a400m",
|
||||||
|
"ibm-granite/granite-3.0-3b-a800m",
|
||||||
|
]
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
PER_ACT_TOKEN_OPTS = [False]
|
||||||
|
PER_OUT_CH_OPTS = [False]
|
||||||
|
|
||||||
|
|
||||||
|
def to_fp8(tensor: torch.Tensor):
|
||||||
|
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to(
|
||||||
|
dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def bench_run(
|
||||||
|
results: list[benchmark.Measurement],
|
||||||
|
model: str,
|
||||||
|
num_experts: int,
|
||||||
|
topk: int,
|
||||||
|
per_act_token: bool,
|
||||||
|
per_out_ch: bool,
|
||||||
|
mkn: tuple[int, int, int],
|
||||||
|
):
|
||||||
|
label = "Quant Matmul"
|
||||||
|
|
||||||
|
sub_label = (
|
||||||
|
"{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format(
|
||||||
|
model, num_experts, topk, per_act_token, per_out_ch, mkn
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Testing: {sub_label}")
|
||||||
|
|
||||||
|
(m, k, n) = mkn
|
||||||
|
|
||||||
|
dtype = torch.half
|
||||||
|
|
||||||
|
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w1 = torch.randn((num_experts, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||||
|
w2 = torch.randn((num_experts, k, n), device="cuda", dtype=dtype) / 10
|
||||||
|
|
||||||
|
_, a_scale = ops.scaled_fp8_quant(a)
|
||||||
|
|
||||||
|
w1_q = torch.empty(
|
||||||
|
(num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn
|
||||||
|
)
|
||||||
|
w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn)
|
||||||
|
w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
|
||||||
|
w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32)
|
||||||
|
|
||||||
|
for expert in range(num_experts):
|
||||||
|
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert])
|
||||||
|
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert])
|
||||||
|
|
||||||
|
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
|
||||||
|
|
||||||
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
|
a, score, topk, renormalize=False
|
||||||
|
)
|
||||||
|
|
||||||
|
ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||||
|
ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64)
|
||||||
|
c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64)
|
||||||
|
c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64)
|
||||||
|
|
||||||
|
def run_triton_moe(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a_scale: torch.Tensor,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_scale,
|
||||||
|
)
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
fused_experts(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_cutlass_moe(
|
||||||
|
a: torch.Tensor,
|
||||||
|
a_scale: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
ab_strides1: torch.Tensor,
|
||||||
|
ab_strides2: torch.Tensor,
|
||||||
|
c_strides1: torch.Tensor,
|
||||||
|
c_strides2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
per_act_token: bool,
|
||||||
|
num_repeats: int,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
cutlass_moe_fp8(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
ab_strides1,
|
||||||
|
ab_strides2,
|
||||||
|
c_strides1,
|
||||||
|
c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_cutlass_from_graph(
|
||||||
|
a: torch.Tensor,
|
||||||
|
a_scale: torch.Tensor,
|
||||||
|
w1_q: torch.Tensor,
|
||||||
|
w2_q: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
ab_strides1: torch.Tensor,
|
||||||
|
ab_strides2: torch.Tensor,
|
||||||
|
c_strides1: torch.Tensor,
|
||||||
|
c_strides2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
per_act_token_quant=per_act_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
with set_current_vllm_config(
|
||||||
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
|
):
|
||||||
|
return cutlass_moe_fp8(
|
||||||
|
a,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
ab_strides1,
|
||||||
|
ab_strides2,
|
||||||
|
c_strides1,
|
||||||
|
c_strides2,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_triton_from_graph(
|
||||||
|
a: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
w1_scale: torch.Tensor,
|
||||||
|
w2_scale: torch.Tensor,
|
||||||
|
a_scale: torch.Tensor,
|
||||||
|
):
|
||||||
|
quant_config = fp8_w8a8_moe_quant_config(
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a_scale,
|
||||||
|
)
|
||||||
|
with set_current_vllm_config(
|
||||||
|
VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1))
|
||||||
|
):
|
||||||
|
return fused_experts(
|
||||||
|
a,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
quant_config=quant_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def replay_graph(graph, num_repeats):
|
||||||
|
for _ in range(num_repeats):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
cutlass_stream = torch.cuda.Stream()
|
||||||
|
cutlass_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(cutlass_graph, stream=cutlass_stream):
|
||||||
|
run_cutlass_from_graph(
|
||||||
|
a,
|
||||||
|
a_scale,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
ab_strides1,
|
||||||
|
ab_strides2,
|
||||||
|
c_strides1,
|
||||||
|
c_strides2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
triton_stream = torch.cuda.Stream()
|
||||||
|
triton_graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(triton_graph, stream=triton_stream):
|
||||||
|
run_triton_from_graph(
|
||||||
|
a,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
a_scale,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
min_run_time = 5
|
||||||
|
num_warmup = 5
|
||||||
|
num_runs = 25
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
# Baseline params
|
||||||
|
"w1": w1,
|
||||||
|
"w2": w2,
|
||||||
|
"score": score,
|
||||||
|
"topk": topk,
|
||||||
|
# Cutlass params
|
||||||
|
"a_scale": a_scale,
|
||||||
|
"w1_q": w1_q,
|
||||||
|
"w2_q": w2_q,
|
||||||
|
"w1_scale": w1_scale,
|
||||||
|
"w2_scale": w2_scale,
|
||||||
|
"per_act_token": per_act_token,
|
||||||
|
"ab_strides1": ab_strides1,
|
||||||
|
"ab_strides2": ab_strides2,
|
||||||
|
"c_strides1": c_strides1,
|
||||||
|
"c_strides2": c_strides2,
|
||||||
|
# cuda graph params
|
||||||
|
"cutlass_graph": cutlass_graph,
|
||||||
|
"triton_graph": triton_graph,
|
||||||
|
# Gen params
|
||||||
|
"a": a,
|
||||||
|
"topk_weights": topk_weights,
|
||||||
|
"topk_ids": topk_ids,
|
||||||
|
"num_runs": num_runs,
|
||||||
|
# Kernels
|
||||||
|
"run_triton_moe": run_triton_moe,
|
||||||
|
"run_cutlass_moe": run_cutlass_moe,
|
||||||
|
"replay_graph": replay_graph,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
run_triton_moe(
|
||||||
|
a,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
a_scale,
|
||||||
|
num_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="triton_moe",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
replay_graph(triton_graph, num_warmup)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="replay_graph(triton_graph, num_runs)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="triton_moe_cuda_graphs",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
run_cutlass_moe(
|
||||||
|
a,
|
||||||
|
a_scale,
|
||||||
|
w1_q,
|
||||||
|
w2_q,
|
||||||
|
w1_scale,
|
||||||
|
w2_scale,
|
||||||
|
ab_strides1,
|
||||||
|
ab_strides2,
|
||||||
|
c_strides1,
|
||||||
|
c_strides2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
per_act_token,
|
||||||
|
num_warmup,
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, ab_strides1, ab_strides2, c_strides1, c_strides2, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="grouped_gemm_moe",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
replay_graph(cutlass_graph, num_warmup)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="replay_graph(cutlass_graph, num_runs)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="grouped_gemm_moe_cuda_graphs",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
results: list[benchmark.Measurement] = []
|
||||||
|
|
||||||
|
for model in args.models:
|
||||||
|
for tp in args.tp_sizes:
|
||||||
|
for layer in WEIGHT_SHAPES_MOE[model]:
|
||||||
|
num_experts = layer[0]
|
||||||
|
topk = layer[1]
|
||||||
|
size_k = layer[2]
|
||||||
|
size_n = layer[3] // tp
|
||||||
|
|
||||||
|
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for per_act_token in PER_ACT_TOKEN_OPTS:
|
||||||
|
for per_out_ch in PER_OUT_CH_OPTS:
|
||||||
|
for size_m in DEFAULT_BATCH_SIZES:
|
||||||
|
mkn = (size_m, size_k, size_n)
|
||||||
|
bench_run(
|
||||||
|
results,
|
||||||
|
model,
|
||||||
|
num_experts,
|
||||||
|
topk,
|
||||||
|
per_act_token,
|
||||||
|
per_out_ch,
|
||||||
|
mkn,
|
||||||
|
)
|
||||||
|
|
||||||
|
compare = benchmark.Compare(results)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark Marlin across specified models/shapes/batches"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES_MOE.keys(),
|
||||||
|
)
|
||||||
|
parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[])
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
94
benchmarks/kernels/benchmark_layernorm.py
Normal file
94
benchmarks/kernels/benchmark_layernorm.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main(
|
||||||
|
num_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
add_residual: bool,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int = 0,
|
||||||
|
do_profile: bool = False,
|
||||||
|
num_warmup_iters: int = 5,
|
||||||
|
num_iters: int = 100,
|
||||||
|
) -> None:
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
layer = RMSNorm(hidden_size).to(dtype=dtype)
|
||||||
|
layer.weight.data.normal_(mean=1.0, std=0.1)
|
||||||
|
scale = 1 / (2 * hidden_size)
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
x *= scale
|
||||||
|
residual = torch.randn_like(x) * scale if add_residual else None
|
||||||
|
|
||||||
|
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
for _ in range(num_iters):
|
||||||
|
layer(x, residual)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStop()
|
||||||
|
return (end_time - start_time) / num_iters
|
||||||
|
|
||||||
|
# Warmup.
|
||||||
|
print("Warming up...")
|
||||||
|
run_benchmark = run_cuda_benchmark
|
||||||
|
run_benchmark(num_iters=num_warmup_iters, profile=False)
|
||||||
|
|
||||||
|
# Benchmark.
|
||||||
|
if do_profile:
|
||||||
|
latency = run_benchmark(num_iters=1, profile=True)
|
||||||
|
else:
|
||||||
|
latency = run_benchmark(num_iters=num_iters, profile=False)
|
||||||
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
|
||||||
|
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||||
|
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||||
|
parser.add_argument("--add-residual", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-iters",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of benchmark iterations. "
|
||||||
|
"If --profile is set, this number is ignored",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
main(
|
||||||
|
num_tokens=args.num_tokens,
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
add_residual=args.add_residual,
|
||||||
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
|
seed=args.seed,
|
||||||
|
do_profile=args.profile,
|
||||||
|
num_warmup_iters=args.num_warmup_iters,
|
||||||
|
num_iters=args.num_iters,
|
||||||
|
)
|
||||||
1488
benchmarks/kernels/benchmark_lora.py
Normal file
1488
benchmarks/kernels/benchmark_lora.py
Normal file
File diff suppressed because it is too large
Load Diff
745
benchmarks/kernels/benchmark_machete.py
Normal file
745
benchmarks/kernels/benchmark_machete.py
Normal file
@@ -0,0 +1,745 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import copy
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import pickle as pkl
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from itertools import product
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
from weight_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
GPTQ_MARLIN_MAX_PARALLEL,
|
||||||
|
GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
|
marlin_permute_scales,
|
||||||
|
marlin_zero_points,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
|
MarlinWorkspace,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
pack_rows,
|
||||||
|
quantize_weights,
|
||||||
|
)
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
|
||||||
|
DEFAULT_TP_SIZES = [1]
|
||||||
|
|
||||||
|
NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False)
|
||||||
|
|
||||||
|
if NVTX_PROFILE:
|
||||||
|
import nvtx
|
||||||
|
|
||||||
|
|
||||||
|
def terse_type_name(dt):
|
||||||
|
return {
|
||||||
|
torch.bfloat16: "bf16",
|
||||||
|
torch.float16: "fp16",
|
||||||
|
torch.int8: "int8",
|
||||||
|
torch.float8_e4m3fn: "fp8",
|
||||||
|
torch.float: "float",
|
||||||
|
torch.int: "int",
|
||||||
|
}[dt]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BenchmarkTensors:
|
||||||
|
w_ref: torch.Tensor
|
||||||
|
a: torch.Tensor
|
||||||
|
|
||||||
|
w_q: torch.Tensor
|
||||||
|
group_size: int | None
|
||||||
|
wtype: ScalarType
|
||||||
|
w_g_s: torch.Tensor
|
||||||
|
w_g_zp: torch.Tensor | None
|
||||||
|
w_ch_s: torch.Tensor | None
|
||||||
|
w_tok_s: torch.Tensor | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TypeConfig:
|
||||||
|
act_type: torch.dtype
|
||||||
|
weight_type: ScalarType
|
||||||
|
output_type: torch.dtype | None
|
||||||
|
group_scale_type: torch.dtype | None
|
||||||
|
group_zero_type: torch.dtype | None
|
||||||
|
channel_scale_type: torch.dtype | None
|
||||||
|
token_scale_type: torch.dtype | None
|
||||||
|
|
||||||
|
|
||||||
|
def rand_data(shape, dtype=torch.float16, scale=1):
|
||||||
|
if dtype.is_floating_point:
|
||||||
|
return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype)
|
||||||
|
else:
|
||||||
|
return torch.randint(-15, 15, shape, dtype=dtype, device="cuda")
|
||||||
|
|
||||||
|
|
||||||
|
def quantize_and_pack(
|
||||||
|
atype: torch.dtype,
|
||||||
|
w: torch.Tensor,
|
||||||
|
wtype: ScalarType,
|
||||||
|
stype: torch.dtype | None,
|
||||||
|
group_size: int | None,
|
||||||
|
zero_points: bool = False,
|
||||||
|
):
|
||||||
|
assert wtype.is_integer(), "TODO: support floating point weights"
|
||||||
|
|
||||||
|
w_ref, w_q, w_s, w_zp = quantize_weights(
|
||||||
|
w,
|
||||||
|
wtype,
|
||||||
|
group_size=group_size,
|
||||||
|
zero_points=zero_points,
|
||||||
|
# to match how the kernel applies zps
|
||||||
|
ref_zero_points_after_scales=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
|
||||||
|
return w_ref, w_q, w_s, w_zp
|
||||||
|
|
||||||
|
|
||||||
|
def create_bench_tensors(
|
||||||
|
shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
|
||||||
|
) -> list[BenchmarkTensors]:
|
||||||
|
m, n, k = shape
|
||||||
|
|
||||||
|
# we want to make sure that weights don't fit into L2 cache between runs so
|
||||||
|
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
|
||||||
|
# so we target total weight size > 2*50mb
|
||||||
|
num_weights = math.ceil(
|
||||||
|
2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits)
|
||||||
|
)
|
||||||
|
|
||||||
|
a = rand_data((m, k), types.act_type, scale=5)
|
||||||
|
|
||||||
|
benchmark_tensors: list[BenchmarkTensors] = []
|
||||||
|
for _ in range(num_weights):
|
||||||
|
w = rand_data((k, n), types.act_type, scale=5)
|
||||||
|
|
||||||
|
if types.group_scale_type is not None:
|
||||||
|
w = w.to(types.group_scale_type)
|
||||||
|
if w.dtype.itemsize == 1:
|
||||||
|
w = w.to(torch.float16)
|
||||||
|
|
||||||
|
w_ref, w_q_packed, w_s, w_zp = quantize_and_pack(
|
||||||
|
a.dtype,
|
||||||
|
w,
|
||||||
|
types.weight_type,
|
||||||
|
types.group_scale_type,
|
||||||
|
group_size,
|
||||||
|
types.group_zero_type is not None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not a.dtype.is_floating_point:
|
||||||
|
aiinfo = torch.iinfo(a.dtype)
|
||||||
|
w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)
|
||||||
|
|
||||||
|
w_ref = w_ref.to(torch.float32)
|
||||||
|
|
||||||
|
w_ch_s = (
|
||||||
|
None
|
||||||
|
if types.channel_scale_type is None
|
||||||
|
else rand_data((n,), types.channel_scale_type)
|
||||||
|
)
|
||||||
|
w_tok_s = (
|
||||||
|
None
|
||||||
|
if types.token_scale_type is None
|
||||||
|
else rand_data((m,), types.token_scale_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
benchmark_tensors.append(
|
||||||
|
BenchmarkTensors(
|
||||||
|
w_ref=w_ref,
|
||||||
|
a=a,
|
||||||
|
w_q=w_q_packed,
|
||||||
|
wtype=types.weight_type,
|
||||||
|
w_g_s=w_s,
|
||||||
|
w_g_zp=w_zp,
|
||||||
|
group_size=group_size,
|
||||||
|
w_ch_s=w_ch_s,
|
||||||
|
w_tok_s=w_tok_s,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return benchmark_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||||
|
a = bt.a
|
||||||
|
w = bt.w_ref.to(bt.a.dtype) # use float reference tensor
|
||||||
|
if a.dtype not in [torch.float16, torch.bfloat16]:
|
||||||
|
a = a.to(torch.float16)
|
||||||
|
w = w.to(torch.float16)
|
||||||
|
return lambda: torch.matmul(a, w)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||||
|
if bt.w_ch_s is not None and bt.w_tok_s is not None:
|
||||||
|
scale_a = bt.w_tok_s.to(torch.float32)
|
||||||
|
scale_b = bt.w_ch_s.to(torch.float32)
|
||||||
|
else:
|
||||||
|
scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
||||||
|
scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device)
|
||||||
|
w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t()
|
||||||
|
return lambda: ops.cutlass_scaled_mm(
|
||||||
|
bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||||
|
device = bt.a.device
|
||||||
|
|
||||||
|
workspace = MarlinWorkspace(
|
||||||
|
bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||||
|
)
|
||||||
|
|
||||||
|
if bt.w_g_zp is None:
|
||||||
|
w_zp = torch.empty(0, dtype=torch.int, device=device)
|
||||||
|
else:
|
||||||
|
w_zp = marlin_zero_points(
|
||||||
|
bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
|
||||||
|
)
|
||||||
|
|
||||||
|
if bt.group_size is None:
|
||||||
|
w_s = torch.tensor([], device="cuda", dtype=torch.half)
|
||||||
|
else:
|
||||||
|
w_s = marlin_permute_scales(
|
||||||
|
bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size
|
||||||
|
)
|
||||||
|
|
||||||
|
sort_indices = torch.empty(0, dtype=torch.int, device=device)
|
||||||
|
g_idx = torch.empty(0, dtype=torch.int, device=device)
|
||||||
|
w_q = ops.gptq_marlin_repack(
|
||||||
|
bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits
|
||||||
|
)
|
||||||
|
|
||||||
|
if bt.a.dtype.is_floating_point:
|
||||||
|
assert bt.w_ch_s is None
|
||||||
|
assert bt.w_tok_s is None
|
||||||
|
assert bt.group_size is not None
|
||||||
|
|
||||||
|
fn = lambda: ops.gptq_marlin_gemm(
|
||||||
|
a=bt.a,
|
||||||
|
c=None,
|
||||||
|
b_q_weight=w_q,
|
||||||
|
b_bias=None,
|
||||||
|
b_scales=w_s,
|
||||||
|
a_scales=None,
|
||||||
|
global_scale=None,
|
||||||
|
b_zeros=w_zp,
|
||||||
|
g_idx=g_idx,
|
||||||
|
perm=sort_indices,
|
||||||
|
workspace=workspace.scratch,
|
||||||
|
b_q_type=bt.wtype,
|
||||||
|
size_m=bt.a.shape[0],
|
||||||
|
size_n=bt.w_ref.shape[1],
|
||||||
|
size_k=bt.w_ref.shape[0],
|
||||||
|
is_k_full=True,
|
||||||
|
is_zp_float=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert bt.a.dtype == torch.int8
|
||||||
|
assert bt.wtype == scalar_types.uint4b8
|
||||||
|
raise NotImplementedError("QQQ is not supported anymore")
|
||||||
|
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def machete_create_bench_fn(
|
||||||
|
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
||||||
|
) -> Callable:
|
||||||
|
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||||
|
w_q = ops.machete_prepack_B(
|
||||||
|
w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
w_g_zp = bt.w_g_zp
|
||||||
|
if w_g_zp is not None:
|
||||||
|
w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype))
|
||||||
|
|
||||||
|
return lambda: ops.machete_mm(
|
||||||
|
a=bt.a,
|
||||||
|
b_q=w_q,
|
||||||
|
b_type=bt.wtype,
|
||||||
|
b_group_scales=bt.w_g_s,
|
||||||
|
b_group_zeros=w_g_zp,
|
||||||
|
b_group_size=bt.group_size,
|
||||||
|
b_channel_scales=bt.w_ch_s,
|
||||||
|
a_token_scales=bt.w_tok_s,
|
||||||
|
out_type=out_type,
|
||||||
|
schedule=schedule,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cutlass_w4a8_create_bench_fn(
|
||||||
|
bt: BenchmarkTensors, out_type=torch.dtype, schedule=None
|
||||||
|
) -> Callable:
|
||||||
|
w_q = bt.w_q.t().contiguous().t() # make col major
|
||||||
|
w_q = ops.cutlass_encode_and_reorder_int4b(w_q)
|
||||||
|
# expects fp8 scales
|
||||||
|
w_s = ops.cutlass_pack_scale_fp8(bt.w_g_s.to(torch.float8_e4m3fn))
|
||||||
|
|
||||||
|
return lambda: ops.cutlass_w4a8_mm(
|
||||||
|
a=bt.a,
|
||||||
|
b_q=w_q,
|
||||||
|
b_group_scales=w_s,
|
||||||
|
b_group_size=bt.group_size,
|
||||||
|
b_channel_scales=bt.w_ch_s,
|
||||||
|
a_token_scales=bt.w_tok_s,
|
||||||
|
maybe_schedule=schedule,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# impl
|
||||||
|
|
||||||
|
# bench
|
||||||
|
|
||||||
|
|
||||||
|
def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]):
|
||||||
|
min_run_time = 1 if not NVTX_PROFILE else 0.1
|
||||||
|
res = TBenchmark.Timer(
|
||||||
|
stmt="""
|
||||||
|
for fn in fns:
|
||||||
|
fn()
|
||||||
|
""",
|
||||||
|
globals={"fns": fns},
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description=description,
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
|
||||||
|
if NVTX_PROFILE:
|
||||||
|
with (
|
||||||
|
nvtx.annotate("mm-bench"),
|
||||||
|
nvtx.annotate(f"{label}|{sub_label}|{description}"),
|
||||||
|
):
|
||||||
|
fns[0]()
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
_SWEEP_SCHEDULES_RESULTS: pd.DataFrame | None = None
|
||||||
|
_SWEEP_SCHEDULES_RESULTS_CSV: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def bench(
|
||||||
|
types: TypeConfig,
|
||||||
|
group_size: int,
|
||||||
|
m: int,
|
||||||
|
k: int,
|
||||||
|
n: int,
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
sweep_schedules: bool = True,
|
||||||
|
) -> list[TMeasurement]:
|
||||||
|
benchmark_tensors = create_bench_tensors((m, n, k), types, group_size)
|
||||||
|
sub_label += f", L={len(benchmark_tensors)}"
|
||||||
|
|
||||||
|
name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}"
|
||||||
|
if types.group_scale_type is not None:
|
||||||
|
name_type_string += f"-GS{terse_type_name(types.group_scale_type)}"
|
||||||
|
if types.group_zero_type is not None:
|
||||||
|
name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}"
|
||||||
|
if group_size is not None:
|
||||||
|
name_type_string += f"-G{group_size}"
|
||||||
|
if types.channel_scale_type is not None:
|
||||||
|
name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}"
|
||||||
|
if types.token_scale_type is not None:
|
||||||
|
name_type_string += f"-TS{terse_type_name(types.token_scale_type)}"
|
||||||
|
|
||||||
|
timers = []
|
||||||
|
# pytorch impl
|
||||||
|
timers.append(
|
||||||
|
bench_fns(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"torch.matmul (fp16)",
|
||||||
|
[torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn:
|
||||||
|
timers.append(
|
||||||
|
bench_fns(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
f"cutlass_scaled_mm ({terse_type_name(types.act_type)})",
|
||||||
|
[cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if types.act_type != torch.float8_e4m3fn:
|
||||||
|
timers.append(
|
||||||
|
bench_fns(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
f"marlin ({name_type_string})",
|
||||||
|
[marlin_create_bench_fn(bt) for bt in benchmark_tensors],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# machete
|
||||||
|
timers.append(
|
||||||
|
bench_fns(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
f"machete ({name_type_string})",
|
||||||
|
[
|
||||||
|
machete_create_bench_fn(bt, out_type=types.output_type)
|
||||||
|
for bt in benchmark_tensors
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# cutlass w4a8
|
||||||
|
if types.act_type == torch.float8_e4m3fn and group_size == 128:
|
||||||
|
timers.append(
|
||||||
|
bench_fns(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
f"cutlass w4a8 ({name_type_string})",
|
||||||
|
[
|
||||||
|
cutlass_w4a8_create_bench_fn(bt, out_type=types.output_type)
|
||||||
|
for bt in benchmark_tensors
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if sweep_schedules:
|
||||||
|
global _SWEEP_SCHEDULES_RESULTS
|
||||||
|
|
||||||
|
print("Finding best schedule for machete")
|
||||||
|
best = None
|
||||||
|
best_schedule = None
|
||||||
|
schedules = ops.machete_supported_schedules(
|
||||||
|
a_type=types.act_type,
|
||||||
|
b_type=types.weight_type,
|
||||||
|
group_scales_type=types.group_scale_type,
|
||||||
|
group_zeros_type=types.group_zero_type,
|
||||||
|
token_scales_type=types.token_scale_type,
|
||||||
|
channel_scales_type=types.channel_scale_type,
|
||||||
|
out_type=types.output_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if schedules is None or len(schedules) == 0:
|
||||||
|
raise ValueError("No schedules found to sweep")
|
||||||
|
|
||||||
|
for schedule in reversed(schedules):
|
||||||
|
schedule_M = int(schedule.split("_")[0].split("x")[1])
|
||||||
|
|
||||||
|
# Prune known bad schedules
|
||||||
|
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
res = bench_fns(
|
||||||
|
label,
|
||||||
|
sub_label,
|
||||||
|
"machete_best",
|
||||||
|
[
|
||||||
|
machete_create_bench_fn(
|
||||||
|
bt, out_type=types.output_type, schedule=schedule
|
||||||
|
)
|
||||||
|
for bt in benchmark_tensors
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
results_row = {
|
||||||
|
"M": m,
|
||||||
|
"K": k,
|
||||||
|
"N": n,
|
||||||
|
"group_size": group_size,
|
||||||
|
"schedule": schedule,
|
||||||
|
"median": res.median,
|
||||||
|
}
|
||||||
|
if _SWEEP_SCHEDULES_RESULTS is None:
|
||||||
|
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys())
|
||||||
|
_SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
|
||||||
|
|
||||||
|
print(f" {res.median:5.5} ", schedule)
|
||||||
|
if not best or res.median < best.median:
|
||||||
|
best = res
|
||||||
|
best_schedule = schedule
|
||||||
|
print("Best schedule:", best_schedule)
|
||||||
|
timers.append(best)
|
||||||
|
|
||||||
|
return timers
|
||||||
|
|
||||||
|
|
||||||
|
# runner
|
||||||
|
def print_timers(timers: list[TMeasurement]):
|
||||||
|
compare = TBenchmark.Compare(timers)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]:
|
||||||
|
types = TypeConfig(
|
||||||
|
act_type=args.act_type,
|
||||||
|
weight_type=scalar_types.uint4b8
|
||||||
|
if args.group_zero_type is None
|
||||||
|
else scalar_types.uint4,
|
||||||
|
output_type=args.out_type,
|
||||||
|
group_scale_type=args.group_scale_type,
|
||||||
|
group_zero_type=args.group_zero_type,
|
||||||
|
channel_scale_type=args.channel_scale_type,
|
||||||
|
token_scale_type=args.token_scale_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
results: list[TMeasurement] = []
|
||||||
|
for m, k, n in MKNs:
|
||||||
|
timers = bench(
|
||||||
|
types,
|
||||||
|
args.group_size,
|
||||||
|
m,
|
||||||
|
k,
|
||||||
|
n,
|
||||||
|
f"{args.act_type}-gemm",
|
||||||
|
f"MKN=({m}x{k}x{n})",
|
||||||
|
sweep_schedules=args.sweep_schedules,
|
||||||
|
)
|
||||||
|
print_timers(timers)
|
||||||
|
results.extend(timers)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
# output makers
|
||||||
|
def make_output(
|
||||||
|
data: list[TMeasurement],
|
||||||
|
MKNs: Iterable[tuple[int, int, int]],
|
||||||
|
base_description: str,
|
||||||
|
timestamp=None,
|
||||||
|
):
|
||||||
|
print(f"== All Results {base_description} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
# pickle all the results
|
||||||
|
timestamp = int(time.time()) if timestamp is None else timestamp
|
||||||
|
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
|
||||||
|
pkl.dump(data, f)
|
||||||
|
|
||||||
|
|
||||||
|
# argparse runners
|
||||||
|
|
||||||
|
|
||||||
|
def run_square_bench(args):
|
||||||
|
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
|
||||||
|
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
|
||||||
|
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"square_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_range_bench(args):
|
||||||
|
m_start, k_start, n_start = (int(x) for x in args.dim_start.split(","))
|
||||||
|
m_end, k_end, n_end = (int(x) for x in args.dim_end.split(","))
|
||||||
|
m_increment, k_increment, n_increment = (
|
||||||
|
int(x) for x in args.dim_increment.split(",")
|
||||||
|
)
|
||||||
|
Ms = list(range(m_start, m_end + 1, m_increment))
|
||||||
|
Ks = list(range(k_start, k_end + 1, k_increment))
|
||||||
|
Ns = list(range(n_start, n_end + 1, n_increment))
|
||||||
|
MKNs = list(product(Ms, Ks, Ns))
|
||||||
|
|
||||||
|
data = run(args.dtype, args.sweep_schedules, MKNs)
|
||||||
|
|
||||||
|
make_output(data, MKNs, f"range_bench-{args.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_bench(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
|
||||||
|
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
|
||||||
|
KNs = []
|
||||||
|
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
|
||||||
|
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
|
||||||
|
KNs.append(KN)
|
||||||
|
return KNs
|
||||||
|
|
||||||
|
model_bench_data = []
|
||||||
|
models_tps = list(itertools.product(args.models, args.tp_sizes))
|
||||||
|
for model, tp_size in models_tps:
|
||||||
|
Ms = args.batch_sizes
|
||||||
|
KNs = model_shapes(model, tp_size)
|
||||||
|
MKNs = []
|
||||||
|
for m in Ms:
|
||||||
|
for k, n in KNs:
|
||||||
|
MKNs.append((m, k, n))
|
||||||
|
|
||||||
|
data = run(args, MKNs)
|
||||||
|
model_bench_data.append(data)
|
||||||
|
|
||||||
|
type_string = f"{args.act_type}"
|
||||||
|
|
||||||
|
# Print all results
|
||||||
|
for data, model_tp in zip(model_bench_data, models_tps):
|
||||||
|
model, tp_size = model_tp
|
||||||
|
print(f"== Results {type_string} {model}-TP{tp_size} ====")
|
||||||
|
print_timers(data)
|
||||||
|
|
||||||
|
timestr = time.strftime("%Y%m%d-%H%M%S")
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for d in model_bench_data:
|
||||||
|
all_results.extend(d)
|
||||||
|
|
||||||
|
# pickle all data
|
||||||
|
with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f:
|
||||||
|
args_dict = vars(args)
|
||||||
|
args_dict.pop("func")
|
||||||
|
pkl.dump(
|
||||||
|
{
|
||||||
|
"args": args_dict,
|
||||||
|
"results": all_results,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
return {
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"int8": torch.int8,
|
||||||
|
"float8_e4m3fn": torch.float8_e4m3fn,
|
||||||
|
"int": torch.int,
|
||||||
|
"float": torch.float,
|
||||||
|
}[dt]
|
||||||
|
|
||||||
|
class ToTorchDtype(argparse.Action):
|
||||||
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
|
setattr(namespace, self.dest, to_torch_dtype(values))
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""
|
||||||
|
Benchmark Machete GEMM.
|
||||||
|
|
||||||
|
To run square GEMMs:
|
||||||
|
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
|
||||||
|
|
||||||
|
To run constant N and K and sweep M:
|
||||||
|
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
|
||||||
|
|
||||||
|
To run dimensions from a model:
|
||||||
|
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
|
||||||
|
""", # noqa: E501
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--act-type",
|
||||||
|
action=ToTorchDtype,
|
||||||
|
required=True,
|
||||||
|
choices=["bfloat16", "float16", "int8", "float8_e4m3fn"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--group-scale-type",
|
||||||
|
action=ToTorchDtype,
|
||||||
|
choices=["bfloat16", "float16"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--group-zero-type",
|
||||||
|
type=to_torch_dtype,
|
||||||
|
choices=["bfloat16", "float16"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--channel-scale-type",
|
||||||
|
action=ToTorchDtype,
|
||||||
|
choices=["float"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--token-scale-type",
|
||||||
|
action=ToTorchDtype,
|
||||||
|
choices=["float"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--out-type",
|
||||||
|
action=ToTorchDtype,
|
||||||
|
choices=["bfloat16", "float16"],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--group-size",
|
||||||
|
type=int,
|
||||||
|
help="Available options are ['None', '-1', '128'], default=128",
|
||||||
|
default=128,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sweep-schedules",
|
||||||
|
action="store_true",
|
||||||
|
help="Run a sweep over all supported schedules",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--sweep-csv-out",
|
||||||
|
help="CSV to store sweep results",
|
||||||
|
default="sch_sweep_results.csv",
|
||||||
|
)
|
||||||
|
subparsers = parser.add_subparsers(dest="cmd", required=True)
|
||||||
|
|
||||||
|
square_parser = subparsers.add_parser("square_bench")
|
||||||
|
square_parser.add_argument("--dim-start", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-end", type=int, required=True)
|
||||||
|
square_parser.add_argument("--dim-increment", type=int, required=True)
|
||||||
|
square_parser.set_defaults(func=run_square_bench)
|
||||||
|
|
||||||
|
range_parser = subparsers.add_parser("range_bench")
|
||||||
|
range_parser.add_argument(
|
||||||
|
"--dim-start",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Start value for M,K,N as common separated list",
|
||||||
|
)
|
||||||
|
range_parser.add_argument(
|
||||||
|
"--dim-end",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="End value (inclusive) for M,K,N as common separated list",
|
||||||
|
)
|
||||||
|
range_parser.add_argument(
|
||||||
|
"--dim-increment",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Increment value for M,K,N as common separated list",
|
||||||
|
)
|
||||||
|
range_parser.set_defaults(func=run_range_bench)
|
||||||
|
|
||||||
|
model_parser = subparsers.add_parser("model_bench")
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
|
)
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
|
||||||
|
)
|
||||||
|
model_parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
model_parser.set_defaults(func=run_model_bench)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
_SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
|
||||||
|
args.func(args)
|
||||||
|
|
||||||
|
if _SWEEP_SCHEDULES_RESULTS is not None:
|
||||||
|
_SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
|
||||||
413
benchmarks/kernels/benchmark_marlin.py
Normal file
413
benchmarks/kernels/benchmark_marlin.py
Normal file
@@ -0,0 +1,413 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as benchmark
|
||||||
|
from benchmark_shapes import WEIGHT_SHAPES
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||||
|
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||||
|
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||||
|
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES,
|
||||||
|
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
|
||||||
|
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
|
||||||
|
ALLSPARK_SUPPORTED_QUANT_TYPES,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
|
GPTQ_MARLIN_MAX_PARALLEL,
|
||||||
|
GPTQ_MARLIN_MIN_THREAD_N,
|
||||||
|
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||||
|
query_marlin_supported_quant_types,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||||
|
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||||
|
rand_marlin_weight_fp4_like,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||||
|
marlin_quant_fp8_torch,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||||
|
MarlinWorkspace,
|
||||||
|
awq_marlin_quantize,
|
||||||
|
marlin_quantize,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||||
|
marlin_24_quantize,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||||
|
gptq_pack,
|
||||||
|
gptq_quantize_weights,
|
||||||
|
quantize_weights,
|
||||||
|
sort_weights,
|
||||||
|
)
|
||||||
|
from vllm.scalar_type import ScalarType, scalar_types
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||||
|
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
|
||||||
|
|
||||||
|
ACT_ORDER_OPTS = [False, True]
|
||||||
|
K_FULL_OPTS = [False, True]
|
||||||
|
|
||||||
|
|
||||||
|
def bench_run(
|
||||||
|
results: list[benchmark.Measurement],
|
||||||
|
model: str,
|
||||||
|
act_order: bool,
|
||||||
|
is_k_full: bool,
|
||||||
|
quant_type: ScalarType,
|
||||||
|
group_size: int,
|
||||||
|
size_m: int,
|
||||||
|
size_k: int,
|
||||||
|
size_n: int,
|
||||||
|
):
|
||||||
|
label = "Quant Matmul"
|
||||||
|
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
|
||||||
|
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
|
||||||
|
)
|
||||||
|
print(f"Testing: {sub_label}")
|
||||||
|
|
||||||
|
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||||
|
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||||
|
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||||
|
if act_order and (group_size == -1 or group_size == size_k or has_zp):
|
||||||
|
return
|
||||||
|
if size_k % group_size != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
marlin_24_supported = (
|
||||||
|
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||||
|
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||||
|
)
|
||||||
|
repack_supported = (
|
||||||
|
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||||
|
and group_size in MARLIN_SUPPORTED_GROUP_SIZES
|
||||||
|
)
|
||||||
|
allspark_supported = (
|
||||||
|
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||||
|
and group_size == -1
|
||||||
|
and not act_order
|
||||||
|
and is_k_full
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen_marlin_params():
|
||||||
|
# Marlin quant
|
||||||
|
marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
|
||||||
|
if quant_type == scalar_types.float4_e2m1f:
|
||||||
|
if group_size != 16 or act_order:
|
||||||
|
return
|
||||||
|
marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
||||||
|
b.T, group_size
|
||||||
|
)
|
||||||
|
elif quant_type == scalar_types.float8_e4m3fn:
|
||||||
|
if group_size not in [-1, 128] or act_order:
|
||||||
|
return
|
||||||
|
marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
|
||||||
|
elif group_size == 16:
|
||||||
|
return
|
||||||
|
elif has_zp:
|
||||||
|
marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||||
|
b, quant_type, group_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
|
||||||
|
marlin_quantize(b, quant_type, group_size, act_order)
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
marlin_w_ref,
|
||||||
|
marlin_q_w,
|
||||||
|
marlin_s,
|
||||||
|
marlin_s2,
|
||||||
|
marlin_zp,
|
||||||
|
marlin_g_idx,
|
||||||
|
marlin_sort_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen_marlin_24_params():
|
||||||
|
marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
|
||||||
|
if marlin_24_supported:
|
||||||
|
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||||
|
marlin_24_quantize(b, quant_type, group_size)
|
||||||
|
)
|
||||||
|
return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
|
||||||
|
|
||||||
|
def gen_repack_params():
|
||||||
|
q_w_gptq = None
|
||||||
|
repack_sort_indices = None
|
||||||
|
if repack_supported:
|
||||||
|
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||||
|
b, quant_type, group_size, act_order
|
||||||
|
)
|
||||||
|
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||||
|
|
||||||
|
# For act_order, sort the "weights" and "g_idx"
|
||||||
|
# so that group ids are increasing
|
||||||
|
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||||
|
if act_order:
|
||||||
|
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||||
|
return q_w_gptq, repack_sort_indices
|
||||||
|
|
||||||
|
def gen_allspark_params():
|
||||||
|
qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
|
||||||
|
CUBLAS_M_THRESHOLD
|
||||||
|
) = None
|
||||||
|
nonlocal allspark_supported
|
||||||
|
if allspark_supported:
|
||||||
|
properties = torch.cuda.get_device_properties(b.device.index)
|
||||||
|
sm_count = properties.multi_processor_count
|
||||||
|
sm_version = properties.major * 10 + properties.minor
|
||||||
|
|
||||||
|
supported_arch = sm_version >= 80 and sm_version < 90
|
||||||
|
allspark_supported = allspark_supported and supported_arch
|
||||||
|
if supported_arch:
|
||||||
|
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||||
|
qw = qw.to(torch.uint8)
|
||||||
|
|
||||||
|
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||||
|
qw, s, zp, has_zp
|
||||||
|
)
|
||||||
|
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||||
|
return (
|
||||||
|
qw_reorder,
|
||||||
|
s_reorder,
|
||||||
|
zp_reorder,
|
||||||
|
sm_count,
|
||||||
|
sm_version,
|
||||||
|
CUBLAS_M_THRESHOLD,
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
marlin_w_ref,
|
||||||
|
marlin_q_w,
|
||||||
|
marlin_s,
|
||||||
|
marlin_s2,
|
||||||
|
marlin_zp,
|
||||||
|
marlin_g_idx,
|
||||||
|
marlin_sort_indices,
|
||||||
|
) = gen_marlin_params()
|
||||||
|
marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
|
||||||
|
gen_marlin_24_params()
|
||||||
|
)
|
||||||
|
q_w_gptq, repack_sort_indices = gen_repack_params()
|
||||||
|
qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
|
||||||
|
gen_allspark_params()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prepare
|
||||||
|
marlin_workspace = MarlinWorkspace(
|
||||||
|
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||||
|
)
|
||||||
|
marlin_24_workspace = MarlinWorkspace(
|
||||||
|
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||||
|
)
|
||||||
|
|
||||||
|
globals = {
|
||||||
|
# Gen params
|
||||||
|
"quant_type": quant_type,
|
||||||
|
"group_size": group_size,
|
||||||
|
"size_m": size_m,
|
||||||
|
"size_n": size_n,
|
||||||
|
"size_k": size_k,
|
||||||
|
"a": a,
|
||||||
|
# Marlin params
|
||||||
|
"marlin_w_ref": marlin_w_ref,
|
||||||
|
"marlin_q_w": marlin_q_w,
|
||||||
|
"marlin_s": marlin_s,
|
||||||
|
"marlin_s2": marlin_s2,
|
||||||
|
"marlin_zp": marlin_zp,
|
||||||
|
"marlin_g_idx": marlin_g_idx,
|
||||||
|
"marlin_sort_indices": marlin_sort_indices,
|
||||||
|
"marlin_workspace": marlin_workspace,
|
||||||
|
"is_k_full": is_k_full,
|
||||||
|
# Marlin_24 params
|
||||||
|
"marlin_24_w_ref": marlin_24_w_ref,
|
||||||
|
"marlin_24_q_w_comp": marlin_24_q_w_comp,
|
||||||
|
"marlin_24_meta": marlin_24_meta,
|
||||||
|
"marlin_24_s": marlin_24_s,
|
||||||
|
"marlin_24_workspace": marlin_24_workspace,
|
||||||
|
# GPTQ params
|
||||||
|
"q_w_gptq": q_w_gptq,
|
||||||
|
"repack_sort_indices": repack_sort_indices,
|
||||||
|
# AllSpark W8A16 params
|
||||||
|
"qw_reorder": qw_reorder,
|
||||||
|
"s_reorder": s_reorder,
|
||||||
|
"zp_reorder": zp_reorder,
|
||||||
|
"sm_count": sm_count,
|
||||||
|
"sm_version": sm_version,
|
||||||
|
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
|
||||||
|
# Kernels
|
||||||
|
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||||
|
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||||
|
"gptq_marlin_repack": ops.gptq_marlin_repack,
|
||||||
|
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
|
||||||
|
}
|
||||||
|
|
||||||
|
min_run_time = 1
|
||||||
|
|
||||||
|
# Warmup pytorch
|
||||||
|
for _ in range(5):
|
||||||
|
torch.matmul(a, marlin_w_ref)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="torch.matmul(a, marlin_w_ref)",
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="pytorch_gemm",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="gptq_marlin_gemm",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="gptq_marlin_gemm_fp32",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
if marlin_24_supported:
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="gptq_marlin_24_gemm",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
if repack_supported:
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="gptq_marlin_repack",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
if allspark_supported:
|
||||||
|
results.append(
|
||||||
|
benchmark.Timer(
|
||||||
|
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||||
|
globals=globals,
|
||||||
|
label=label,
|
||||||
|
sub_label=sub_label,
|
||||||
|
description="allspark_w8a16_gemm_fp32",
|
||||||
|
).blocked_autorange(min_run_time=min_run_time)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print("Benchmarking models:")
|
||||||
|
for i, model in enumerate(args.models):
|
||||||
|
print(f"[{i}] {model}")
|
||||||
|
results: list[benchmark.Measurement] = []
|
||||||
|
|
||||||
|
for model in args.models:
|
||||||
|
for layer in WEIGHT_SHAPES[model]:
|
||||||
|
size_k = layer[0]
|
||||||
|
size_n = layer[1]
|
||||||
|
|
||||||
|
if len(args.limit_k) > 0 and size_k not in args.limit_k:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(args.limit_n) > 0 and size_n not in args.limit_n:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for act_order in ACT_ORDER_OPTS:
|
||||||
|
if (
|
||||||
|
len(args.limit_act_order) > 0
|
||||||
|
and act_order not in args.limit_act_order
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for is_k_full in K_FULL_OPTS:
|
||||||
|
if (
|
||||||
|
len(args.limit_k_full) > 0
|
||||||
|
and is_k_full not in args.limit_k_full
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for quant_type in query_marlin_supported_quant_types():
|
||||||
|
if (
|
||||||
|
len(args.limit_num_bits) > 0
|
||||||
|
and quant_type.size_bits not in args.limit_num_bits
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for group_size in (
|
||||||
|
MARLIN_SUPPORTED_GROUP_SIZES
|
||||||
|
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
len(args.limit_group_size) > 0
|
||||||
|
and group_size not in args.limit_group_size
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# For act_order, the group_size must be less than
|
||||||
|
# size_k
|
||||||
|
if act_order and (group_size == size_k or group_size == -1):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for size_m in args.batch_sizes:
|
||||||
|
bench_run(
|
||||||
|
results,
|
||||||
|
model,
|
||||||
|
act_order,
|
||||||
|
is_k_full,
|
||||||
|
quant_type,
|
||||||
|
group_size,
|
||||||
|
size_m,
|
||||||
|
size_k,
|
||||||
|
size_n,
|
||||||
|
)
|
||||||
|
|
||||||
|
compare = benchmark.Compare(results)
|
||||||
|
compare.print()
|
||||||
|
|
||||||
|
|
||||||
|
# For quick benchmarking use:
|
||||||
|
# python benchmark_marlin.py --batch-sizes 1 16 32 --limit-k 4096 --limit-n 4096 --limit-group-size 128 --limit-num-bits 4 --limit-act-order 0 --limit-k-full 1 # noqa E501
|
||||||
|
#
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark Marlin across specified models/shapes/batches"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--models",
|
||||||
|
nargs="+",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_MODELS,
|
||||||
|
choices=WEIGHT_SHAPES.keys(),
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
|
||||||
|
)
|
||||||
|
parser.add_argument("--limit-k", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-n", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-group-size", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-num-bits", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-act-order", nargs="+", type=int, default=[])
|
||||||
|
parser.add_argument("--limit-k-full", nargs="+", type=int, default=[])
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
@@ -1,215 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import triton
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.fused_moe import (fused_moe,
|
|
||||||
get_config_file_name)
|
|
||||||
|
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
||||||
|
|
||||||
|
|
||||||
def main(dtype: str):
|
|
||||||
method = fused_moe
|
|
||||||
for bs in [
|
|
||||||
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
|
|
||||||
2048, 3072, 4096
|
|
||||||
]:
|
|
||||||
run_grid(bs, method=method, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def run_grid(bs, method, dtype: str):
|
|
||||||
d_model = 4096
|
|
||||||
num_total_experts = 8
|
|
||||||
top_k = 2
|
|
||||||
tp_size = 2
|
|
||||||
model_intermediate_size = 14336
|
|
||||||
num_layers = 32
|
|
||||||
num_calls = 100
|
|
||||||
|
|
||||||
num_warmup_trials = 1
|
|
||||||
num_trials = 1
|
|
||||||
|
|
||||||
configs = []
|
|
||||||
|
|
||||||
for block_size_n in [32, 64, 128, 256]:
|
|
||||||
for block_size_m in [16, 32, 64, 128, 256]:
|
|
||||||
for block_size_k in [64, 128, 256]:
|
|
||||||
for group_size_m in [1, 16, 32, 64]:
|
|
||||||
for num_warps in [4, 8]:
|
|
||||||
for num_stages in [2, 3, 4, 5]:
|
|
||||||
configs.append({
|
|
||||||
"BLOCK_SIZE_M": block_size_m,
|
|
||||||
"BLOCK_SIZE_N": block_size_n,
|
|
||||||
"BLOCK_SIZE_K": block_size_k,
|
|
||||||
"GROUP_SIZE_M": group_size_m,
|
|
||||||
"num_warps": num_warps,
|
|
||||||
"num_stages": num_stages,
|
|
||||||
})
|
|
||||||
|
|
||||||
best_config = None
|
|
||||||
best_time_us = 1e20
|
|
||||||
|
|
||||||
print(f'{tp_size=} {bs=}')
|
|
||||||
|
|
||||||
for config in tqdm(configs):
|
|
||||||
# warmup
|
|
||||||
try:
|
|
||||||
for _ in range(num_warmup_trials):
|
|
||||||
run_timing(
|
|
||||||
num_calls=num_calls,
|
|
||||||
bs=bs,
|
|
||||||
d_model=d_model,
|
|
||||||
num_total_experts=num_total_experts,
|
|
||||||
top_k=top_k,
|
|
||||||
tp_size=tp_size,
|
|
||||||
model_intermediate_size=model_intermediate_size,
|
|
||||||
method=method,
|
|
||||||
config=config,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
except triton.runtime.autotuner.OutOfResources:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# trial
|
|
||||||
for _ in range(num_trials):
|
|
||||||
kernel_dur_ms = run_timing(
|
|
||||||
num_calls=num_calls,
|
|
||||||
bs=bs,
|
|
||||||
d_model=d_model,
|
|
||||||
num_total_experts=num_total_experts,
|
|
||||||
top_k=top_k,
|
|
||||||
tp_size=tp_size,
|
|
||||||
model_intermediate_size=model_intermediate_size,
|
|
||||||
method=method,
|
|
||||||
config=config,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
kernel_dur_us = 1000 * kernel_dur_ms
|
|
||||||
model_dur_ms = kernel_dur_ms * num_layers
|
|
||||||
|
|
||||||
if kernel_dur_us < best_time_us:
|
|
||||||
best_config = config
|
|
||||||
best_time_us = kernel_dur_us
|
|
||||||
|
|
||||||
tqdm.write(
|
|
||||||
f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
|
|
||||||
f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
|
|
||||||
f'{d_model=} {model_intermediate_size=} {num_layers=}')
|
|
||||||
|
|
||||||
print("best_time_us", best_time_us)
|
|
||||||
print("best_config", best_config)
|
|
||||||
|
|
||||||
# holds Dict[str, Dict[str, int]]
|
|
||||||
filename = get_config_file_name(num_total_experts,
|
|
||||||
model_intermediate_size // tp_size,
|
|
||||||
"float8" if dtype == "float8" else None)
|
|
||||||
print(f"writing config to file {filename}")
|
|
||||||
existing_content = {}
|
|
||||||
if os.path.exists(filename):
|
|
||||||
with open(filename, "r") as f:
|
|
||||||
existing_content = json.load(f)
|
|
||||||
existing_content[str(bs)] = best_config
|
|
||||||
with open(filename, "w") as f:
|
|
||||||
json.dump(existing_content, f, indent=4)
|
|
||||||
f.write("\n")
|
|
||||||
|
|
||||||
|
|
||||||
def run_timing(num_calls: int, bs: int, d_model: int, num_total_experts: int,
|
|
||||||
top_k: int, tp_size: int, model_intermediate_size: int, method,
|
|
||||||
config, dtype: str) -> float:
|
|
||||||
shard_intermediate_size = model_intermediate_size // tp_size
|
|
||||||
|
|
||||||
hidden_states = torch.rand(
|
|
||||||
(bs, d_model),
|
|
||||||
device="cuda:0",
|
|
||||||
dtype=torch.float16,
|
|
||||||
)
|
|
||||||
|
|
||||||
w1 = torch.rand(
|
|
||||||
(num_total_experts, 2 * shard_intermediate_size, d_model),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
w2 = torch.rand(
|
|
||||||
(num_total_experts, d_model, shard_intermediate_size),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=hidden_states.dtype,
|
|
||||||
)
|
|
||||||
|
|
||||||
w1_scale = None
|
|
||||||
w2_scale = None
|
|
||||||
a1_scale = None
|
|
||||||
a2_scale = None
|
|
||||||
|
|
||||||
if dtype == "float8":
|
|
||||||
w1 = w1.to(torch.float8_e4m3fn)
|
|
||||||
w2 = w2.to(torch.float8_e4m3fn)
|
|
||||||
w1_scale = torch.ones(num_total_experts,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
w2_scale = torch.ones(num_total_experts,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
a1_scale = torch.ones(1,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
a2_scale = torch.ones(1,
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.float32)
|
|
||||||
|
|
||||||
gating_output = F.softmax(torch.rand(
|
|
||||||
(num_calls, bs, num_total_experts),
|
|
||||||
device=hidden_states.device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
),
|
|
||||||
dim=-1)
|
|
||||||
|
|
||||||
start_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
end_event = torch.cuda.Event(enable_timing=True)
|
|
||||||
|
|
||||||
start_event.record()
|
|
||||||
for i in range(num_calls):
|
|
||||||
hidden_states = method(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
w1=w1,
|
|
||||||
w2=w2,
|
|
||||||
w1_scale=w1_scale,
|
|
||||||
w2_scale=w2_scale,
|
|
||||||
a1_scale=a1_scale,
|
|
||||||
a2_scale=a2_scale,
|
|
||||||
gating_output=gating_output[i],
|
|
||||||
topk=2,
|
|
||||||
renormalize=True,
|
|
||||||
inplace=True,
|
|
||||||
override_config=config,
|
|
||||||
use_fp8=dtype == "float8",
|
|
||||||
)
|
|
||||||
end_event.record()
|
|
||||||
end_event.synchronize()
|
|
||||||
|
|
||||||
dur_ms = start_event.elapsed_time(end_event) / num_calls
|
|
||||||
return dur_ms
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
prog='benchmark_mixtral_moe',
|
|
||||||
description='Benchmark and tune the fused_moe kernel',
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
'--dtype',
|
|
||||||
type=str,
|
|
||||||
default='auto',
|
|
||||||
choices=['float8', 'float16'],
|
|
||||||
help='Data type used for fused_moe kernel computations',
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
sys.exit(main(args.dtype))
|
|
||||||
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
150
benchmarks/kernels/benchmark_mla_k_concat.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Benchmark script comparing torch.cat vs direct copy for k_nope/k_pe concatenation
|
||||||
|
in MLA (Multi-head Latent Attention) prefill.
|
||||||
|
|
||||||
|
This validates that the optimization from commit 8d4142bd is beneficial across
|
||||||
|
various batch sizes, not just the originally tested batch size of 32768.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# DeepSeek-V3 MLA dimensions
|
||||||
|
NUM_HEADS = 128
|
||||||
|
QK_NOPE_HEAD_DIM = 128
|
||||||
|
PE_DIM = 64
|
||||||
|
|
||||||
|
|
||||||
|
def cat_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Original torch.cat approach with expand."""
|
||||||
|
return torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def direct_copy_method(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Optimized direct copy approach (avoids expand + cat overhead)."""
|
||||||
|
k = torch.empty(
|
||||||
|
(*k_nope.shape[:-1], k_nope.shape[-1] + k_pe.shape[-1]),
|
||||||
|
dtype=k_nope.dtype,
|
||||||
|
device=k_nope.device,
|
||||||
|
)
|
||||||
|
k[..., : k_nope.shape[-1]] = k_nope
|
||||||
|
k[..., k_nope.shape[-1] :] = k_pe
|
||||||
|
return k
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_method(
|
||||||
|
method: Callable,
|
||||||
|
k_nope: torch.Tensor,
|
||||||
|
k_pe: torch.Tensor,
|
||||||
|
num_warmup: int = 10,
|
||||||
|
num_iters: int = 100,
|
||||||
|
) -> float:
|
||||||
|
"""Benchmark a concatenation method and return mean latency in ms."""
|
||||||
|
# Warmup
|
||||||
|
for _ in range(num_warmup):
|
||||||
|
_ = method(k_nope, k_pe)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Benchmark
|
||||||
|
start = time.perf_counter()
|
||||||
|
for _ in range(num_iters):
|
||||||
|
_ = method(k_nope, k_pe)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.perf_counter()
|
||||||
|
|
||||||
|
return (end - start) / num_iters * 1000 # Convert to ms
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_benchmark(dtype: torch.dtype, dtype_name: str):
|
||||||
|
"""Run benchmark for a specific dtype."""
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
# Batch sizes to test (powers of 2 from 32 to 65536)
|
||||||
|
batch_sizes = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("Benchmark: torch.cat vs direct copy for MLA k_nope/k_pe concatenation")
|
||||||
|
print("=" * 80)
|
||||||
|
print(
|
||||||
|
f"Tensor shapes: k_nope=[B, {NUM_HEADS}, {QK_NOPE_HEAD_DIM}], "
|
||||||
|
f"k_pe=[B, 1, {PE_DIM}]"
|
||||||
|
)
|
||||||
|
print(f"dtype: {dtype_name}")
|
||||||
|
print()
|
||||||
|
print(
|
||||||
|
f"{'Batch Size':>12} | {'cat (ms)':>10} | {'direct (ms)':>12} | "
|
||||||
|
f"{'Speedup':>8} | {'Reduction':>10}"
|
||||||
|
)
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for batch_size in batch_sizes:
|
||||||
|
# Create input tensors (generate in float32 then convert for FP8 compatibility)
|
||||||
|
k_nope = torch.randn(
|
||||||
|
batch_size, NUM_HEADS, QK_NOPE_HEAD_DIM, dtype=torch.float32, device="cuda"
|
||||||
|
).to(dtype)
|
||||||
|
k_pe = torch.randn(
|
||||||
|
batch_size, 1, PE_DIM, dtype=torch.float32, device="cuda"
|
||||||
|
).to(dtype)
|
||||||
|
|
||||||
|
# Benchmark both methods
|
||||||
|
cat_time = benchmark_method(cat_method, k_nope, k_pe)
|
||||||
|
direct_time = benchmark_method(direct_copy_method, k_nope, k_pe)
|
||||||
|
|
||||||
|
speedup = cat_time / direct_time
|
||||||
|
reduction = (1 - direct_time / cat_time) * 100
|
||||||
|
|
||||||
|
results.append((batch_size, cat_time, direct_time, speedup, reduction))
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{batch_size:>12} | {cat_time:>10.3f} | {direct_time:>12.3f} | "
|
||||||
|
f"{speedup:>7.2f}x | {reduction:>9.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# Summary statistics
|
||||||
|
speedups = [r[3] for r in results]
|
||||||
|
print("\nSpeedup summary:")
|
||||||
|
print(f" Min: {min(speedups):.2f}x")
|
||||||
|
print(f" Max: {max(speedups):.2f}x")
|
||||||
|
print(f" Mean: {sum(speedups) / len(speedups):.2f}x")
|
||||||
|
|
||||||
|
# Find crossover point
|
||||||
|
crossover_batch = None
|
||||||
|
for batch_size, _, _, speedup, _ in results:
|
||||||
|
if speedup >= 1.0:
|
||||||
|
crossover_batch = batch_size
|
||||||
|
break
|
||||||
|
|
||||||
|
print("\nConclusion:")
|
||||||
|
if crossover_batch:
|
||||||
|
print(f" - Direct copy becomes beneficial at batch size >= {crossover_batch}")
|
||||||
|
# Filter for large batches (>= 512 which is typical for prefill)
|
||||||
|
large_batch_speedups = [r[3] for r in results if r[0] >= 512]
|
||||||
|
if large_batch_speedups:
|
||||||
|
avg_large = sum(large_batch_speedups) / len(large_batch_speedups)
|
||||||
|
print(f" - For batch sizes >= 512: avg speedup = {avg_large:.2f}x")
|
||||||
|
print(" - MLA prefill typically uses large batches, so optimization is effective")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main():
|
||||||
|
# Test bfloat16
|
||||||
|
print("\n")
|
||||||
|
run_benchmark(torch.bfloat16, "bfloat16")
|
||||||
|
|
||||||
|
# Test float8_e4m3fn
|
||||||
|
print("\n")
|
||||||
|
run_benchmark(torch.float8_e4m3fn, "float8_e4m3fn")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
790
benchmarks/kernels/benchmark_moe.py
Normal file
790
benchmarks/kernels/benchmark_moe.py
Normal file
@@ -0,0 +1,790 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from datetime import datetime
|
||||||
|
from itertools import product
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import torch
|
||||||
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.config import (
|
||||||
|
FusedMoEQuantConfig,
|
||||||
|
_get_config_dtype_str,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.transformers_utils.config import get_config
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_divisibility(numerator, denominator, text):
|
||||||
|
"""Ensure that numerator is divisible by the denominator."""
|
||||||
|
assert numerator % denominator == 0, "{} {} is not divisible by tp {}.".format(
|
||||||
|
text, numerator, denominator
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig(TypedDict):
|
||||||
|
BLOCK_SIZE_M: int
|
||||||
|
BLOCK_SIZE_N: int
|
||||||
|
BLOCK_SIZE_K: int
|
||||||
|
GROUP_SIZE_M: int
|
||||||
|
num_warps: int
|
||||||
|
num_stages: int
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_config(
|
||||||
|
config: BenchmarkConfig,
|
||||||
|
num_tokens: int,
|
||||||
|
num_experts: int,
|
||||||
|
shard_intermediate_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
num_iters: int = 100,
|
||||||
|
block_quant_shape: list[int] = None,
|
||||||
|
use_deep_gemm: bool = False,
|
||||||
|
) -> float:
|
||||||
|
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
if use_int8_w8a16:
|
||||||
|
w1 = torch.randint(
|
||||||
|
-127,
|
||||||
|
127,
|
||||||
|
(
|
||||||
|
num_experts,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
),
|
||||||
|
dtype=torch.int8,
|
||||||
|
)
|
||||||
|
w2 = torch.randint(
|
||||||
|
-127,
|
||||||
|
127,
|
||||||
|
(
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
shard_intermediate_size // 2,
|
||||||
|
),
|
||||||
|
dtype=torch.int8,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
w1 = torch.randn(
|
||||||
|
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
|
||||||
|
)
|
||||||
|
w2 = torch.randn(
|
||||||
|
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
|
||||||
|
)
|
||||||
|
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
w1_scale = None
|
||||||
|
w2_scale = None
|
||||||
|
a1_scale = None
|
||||||
|
a2_scale = None
|
||||||
|
if use_int8_w8a16:
|
||||||
|
w1_scale = torch.randn(
|
||||||
|
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||||
|
)
|
||||||
|
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||||
|
if use_deep_gemm:
|
||||||
|
# we use the default block shape for deepgemm
|
||||||
|
block_quant_shape = [128, 128]
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
if block_quant_shape:
|
||||||
|
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||||
|
E = num_experts
|
||||||
|
N = shard_intermediate_size // 2
|
||||||
|
K = hidden_size
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
n_tiles_w1 = (2 * N + block_n - 1) // block_n
|
||||||
|
n_tiles_w2 = (K + block_n - 1) // block_n
|
||||||
|
k_tiles_w1 = (K + block_k - 1) // block_k
|
||||||
|
k_tiles_w2 = (N + block_k - 1) // block_k
|
||||||
|
w1_scale = (
|
||||||
|
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
|
||||||
|
* factor_for_scale
|
||||||
|
)
|
||||||
|
w2_scale = (
|
||||||
|
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
|
||||||
|
* factor_for_scale
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
w1_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
|
w2_scale = torch.randn(num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
a1_scale = torch.randn(1, dtype=torch.float32)
|
||||||
|
a2_scale = torch.randn(1, dtype=torch.float32)
|
||||||
|
|
||||||
|
w1 = w1.to(FP8_DTYPE)
|
||||||
|
w2 = w2.to(FP8_DTYPE)
|
||||||
|
|
||||||
|
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
def prepare(i: int):
|
||||||
|
input_gating.copy_(gating_output[i])
|
||||||
|
|
||||||
|
def run():
|
||||||
|
from vllm.model_executor.layers.fused_moe import override_config
|
||||||
|
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
quant_dtype = torch.float8_e4m3fn
|
||||||
|
elif use_int8_w8a16:
|
||||||
|
quant_dtype = torch.int8
|
||||||
|
else:
|
||||||
|
quant_dtype = None
|
||||||
|
|
||||||
|
quant_config = FusedMoEQuantConfig.make(
|
||||||
|
quant_dtype=quant_dtype,
|
||||||
|
w1_scale=w1_scale,
|
||||||
|
w2_scale=w2_scale,
|
||||||
|
a1_scale=a1_scale,
|
||||||
|
a2_scale=a2_scale,
|
||||||
|
block_shape=block_quant_shape,
|
||||||
|
)
|
||||||
|
|
||||||
|
with override_config(config):
|
||||||
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
|
x, input_gating, topk, renormalize=not use_deep_gemm
|
||||||
|
)
|
||||||
|
return fused_experts(
|
||||||
|
x,
|
||||||
|
w1,
|
||||||
|
w2,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
inplace=True,
|
||||||
|
quant_config=quant_config,
|
||||||
|
allow_deep_gemm=use_deep_gemm,
|
||||||
|
)
|
||||||
|
|
||||||
|
# JIT compilation & warmup
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture 10 invocations with CUDA graph
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph):
|
||||||
|
for _ in range(10):
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(5):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.Event(enable_timing=True)
|
||||||
|
end_event = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
for i in range(num_iters):
|
||||||
|
prepare(i)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
graph.replay()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||||
|
graph.reset()
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def get_rocm_tuning_space(use_fp16):
|
||||||
|
block_mn_range = [16, 32, 64, 128, 256]
|
||||||
|
block_k_range = [16, 32, 64, 128, 256]
|
||||||
|
if not use_fp16:
|
||||||
|
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
|
||||||
|
num_warps_range = [1, 2, 4, 8]
|
||||||
|
group_m_range = [1, 4, 8, 16, 32]
|
||||||
|
num_stage_range = [2]
|
||||||
|
waves_per_eu_range = [0, 1, 2, 4]
|
||||||
|
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
|
||||||
|
kpack_range = [1, 2] if use_fp16 else []
|
||||||
|
|
||||||
|
param_ranges = {
|
||||||
|
"BLOCK_SIZE_M": block_mn_range,
|
||||||
|
"BLOCK_SIZE_N": block_mn_range,
|
||||||
|
"BLOCK_SIZE_K": block_k_range,
|
||||||
|
"GROUP_SIZE_M": group_m_range,
|
||||||
|
"num_warps": num_warps_range,
|
||||||
|
"num_stages": num_stage_range,
|
||||||
|
"waves_per_eu": waves_per_eu_range,
|
||||||
|
}
|
||||||
|
if use_fp16:
|
||||||
|
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
|
||||||
|
param_ranges["kpack"] = kpack_range
|
||||||
|
|
||||||
|
return param_ranges
|
||||||
|
|
||||||
|
|
||||||
|
def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]:
|
||||||
|
configs: list[BenchmarkConfig] = []
|
||||||
|
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
param_ranges = get_rocm_tuning_space(use_fp16)
|
||||||
|
else:
|
||||||
|
# Reduced search space for faster tuning.
|
||||||
|
# TODO(woosuk): Increase the search space and use a performance model to
|
||||||
|
# prune the search space.
|
||||||
|
block_m_range = [16, 32, 64, 128, 256]
|
||||||
|
block_n_range = [32, 64, 128, 256]
|
||||||
|
block_k_range = [64, 128, 256]
|
||||||
|
num_warps_range = [4, 8]
|
||||||
|
group_m_range = [1, 16, 32, 64]
|
||||||
|
num_stage_range = [2, 3, 4, 5]
|
||||||
|
|
||||||
|
param_ranges = {
|
||||||
|
"BLOCK_SIZE_M": block_m_range,
|
||||||
|
"BLOCK_SIZE_N": block_n_range,
|
||||||
|
"BLOCK_SIZE_K": block_k_range,
|
||||||
|
"GROUP_SIZE_M": group_m_range,
|
||||||
|
"num_warps": num_warps_range,
|
||||||
|
"num_stages": num_stage_range,
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, values = zip(*param_ranges.items())
|
||||||
|
for config_values in product(*values):
|
||||||
|
config = dict(zip(keys, config_values))
|
||||||
|
configs.append(config)
|
||||||
|
|
||||||
|
# Remove configs that are not compatible with fp8 block quantization
|
||||||
|
# BLOCK_SIZE_K must be a multiple of block_k
|
||||||
|
# BLOCK_SIZE_N must be a multiple of block_n
|
||||||
|
if block_quant_shape is not None and not use_fp16:
|
||||||
|
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||||
|
for config in configs[:]:
|
||||||
|
if (
|
||||||
|
config["BLOCK_SIZE_K"] % block_k != 0
|
||||||
|
or config["BLOCK_SIZE_N"] % block_n != 0
|
||||||
|
):
|
||||||
|
configs.remove(config)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def prune_rocm_search_space(
|
||||||
|
num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
|
||||||
|
):
|
||||||
|
N1, K1 = shard_intermediate_size, hidden_size
|
||||||
|
N2, K2 = hidden_size, shard_intermediate_size // 2
|
||||||
|
pruned_space_1 = prune_rocm_configs(
|
||||||
|
num_tokens * topk, N1, K1, search_space, is_fp16
|
||||||
|
)
|
||||||
|
pruned_space_2 = prune_rocm_configs(
|
||||||
|
num_tokens * topk, N2, K2, search_space, is_fp16
|
||||||
|
)
|
||||||
|
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
|
||||||
|
return search_space
|
||||||
|
|
||||||
|
|
||||||
|
# The following code is inspired by ROCm/Triton GEMM tuning script:
|
||||||
|
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
|
||||||
|
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
|
||||||
|
pruned_configs = []
|
||||||
|
elemBytes_a = 2 if is_fp16 else 1
|
||||||
|
elemBytes_b = 2 if is_fp16 else 1
|
||||||
|
|
||||||
|
mfma = 16 if M < 32 or N < 32 else 32
|
||||||
|
|
||||||
|
# TODO (zhanglx): figure out the boundary between large and small gemms
|
||||||
|
large_gemm = False
|
||||||
|
if M >= 2048 and N >= 2048:
|
||||||
|
large_gemm = True
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
|
||||||
|
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
|
||||||
|
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
|
||||||
|
num_warps = config.get("num_warps")
|
||||||
|
|
||||||
|
if is_fp16:
|
||||||
|
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
|
||||||
|
if matrix_instr_nonkdim > mfma:
|
||||||
|
continue
|
||||||
|
if mfma == 4 and BLOCK_SIZE_K < 64:
|
||||||
|
continue
|
||||||
|
# some layouts could not work properly in case
|
||||||
|
# number elements per thread is less 1
|
||||||
|
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
|
||||||
|
continue
|
||||||
|
SPLIT_K = config.get("SPLIT_K", 1)
|
||||||
|
GROUP_M = config.get("GROUP_SIZE_M")
|
||||||
|
if is_fp16:
|
||||||
|
if (
|
||||||
|
matrix_instr_nonkdim > BLOCK_SIZE_M
|
||||||
|
or matrix_instr_nonkdim > BLOCK_SIZE_N
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
|
||||||
|
continue
|
||||||
|
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
|
||||||
|
continue
|
||||||
|
# Skip BLOCK_SIZE that is too large compare to M/N
|
||||||
|
# unless BLOCK_SIZE is already small enough
|
||||||
|
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
|
||||||
|
continue
|
||||||
|
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
|
||||||
|
continue
|
||||||
|
# skip large split_k when not necessary
|
||||||
|
if SPLIT_K != 1 and not need_split_k(M, N, K):
|
||||||
|
continue
|
||||||
|
# skip split_k that leads to EVEN_K = false
|
||||||
|
leap = SPLIT_K * BLOCK_SIZE_K
|
||||||
|
modv = K % leap
|
||||||
|
if modv != 0:
|
||||||
|
continue
|
||||||
|
# skip large GROUP_M
|
||||||
|
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
|
||||||
|
continue
|
||||||
|
# out of shared memory resource
|
||||||
|
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
|
||||||
|
LDS = (
|
||||||
|
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
|
||||||
|
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
|
||||||
|
)
|
||||||
|
if LDS > 65536:
|
||||||
|
continue
|
||||||
|
# Skip small block sizes and num_warps for large gemm
|
||||||
|
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
|
||||||
|
if large_gemm:
|
||||||
|
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
|
||||||
|
continue
|
||||||
|
if BLOCK_SIZE_K < 64:
|
||||||
|
continue
|
||||||
|
if num_warps < 4:
|
||||||
|
continue
|
||||||
|
|
||||||
|
pruned_configs.append(config)
|
||||||
|
|
||||||
|
return pruned_configs
|
||||||
|
|
||||||
|
|
||||||
|
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
|
||||||
|
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
|
||||||
|
|
||||||
|
|
||||||
|
def merge_unique_dicts(list1, list2):
|
||||||
|
result = []
|
||||||
|
combined_list = list1.copy()
|
||||||
|
combined_list.extend(list2)
|
||||||
|
for dictionary in combined_list:
|
||||||
|
if dictionary not in result:
|
||||||
|
result.append(dictionary)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
|
||||||
|
class BenchmarkWorker:
|
||||||
|
def __init__(self, seed: int) -> None:
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
self.seed = seed
|
||||||
|
# Get the device ID to allocate tensors and kernels
|
||||||
|
# on the respective GPU. This is required for Ray to work
|
||||||
|
# correctly with multi-GPU tuning on the ROCm platform.
|
||||||
|
self.device_id = int(ray.get_gpu_ids()[0])
|
||||||
|
|
||||||
|
def benchmark(
|
||||||
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
num_experts: int,
|
||||||
|
shard_intermediate_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
block_quant_shape: list[int] = None,
|
||||||
|
use_deep_gemm: bool = False,
|
||||||
|
) -> tuple[dict[str, int], float]:
|
||||||
|
current_platform.seed_everything(self.seed)
|
||||||
|
dtype_str = _get_config_dtype_str(
|
||||||
|
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||||
|
)
|
||||||
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
|
# is the intermediate size after silu_and_mul.
|
||||||
|
block_n = block_quant_shape[0] if block_quant_shape else None
|
||||||
|
block_k = block_quant_shape[1] if block_quant_shape else None
|
||||||
|
op_config = get_moe_configs(
|
||||||
|
num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k
|
||||||
|
)
|
||||||
|
if op_config is None:
|
||||||
|
config = get_default_config(
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype_str,
|
||||||
|
block_quant_shape,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
|
||||||
|
kernel_time = benchmark_config(
|
||||||
|
config,
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
num_iters=100,
|
||||||
|
block_quant_shape=block_quant_shape,
|
||||||
|
use_deep_gemm=use_deep_gemm,
|
||||||
|
)
|
||||||
|
return config, kernel_time
|
||||||
|
|
||||||
|
def tune(
|
||||||
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
num_experts: int,
|
||||||
|
shard_intermediate_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
search_space: list[dict[str, int]],
|
||||||
|
block_quant_shape: list[int],
|
||||||
|
use_deep_gemm: bool,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
best_config = None
|
||||||
|
best_time = float("inf")
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||||
|
search_space = prune_rocm_search_space(
|
||||||
|
num_tokens,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
search_space,
|
||||||
|
is_fp16,
|
||||||
|
topk,
|
||||||
|
)
|
||||||
|
|
||||||
|
need_device_guard = False
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
|
||||||
|
if visible_device != f"{self.device_id}":
|
||||||
|
need_device_guard = True
|
||||||
|
|
||||||
|
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
|
||||||
|
for config in tqdm(search_space):
|
||||||
|
try:
|
||||||
|
kernel_time = benchmark_config(
|
||||||
|
config,
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
num_iters=20,
|
||||||
|
block_quant_shape=block_quant_shape,
|
||||||
|
use_deep_gemm=use_deep_gemm,
|
||||||
|
)
|
||||||
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
|
# Some configurations may be invalid and fail to compile.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if kernel_time < best_time:
|
||||||
|
best_time = kernel_time
|
||||||
|
best_config = config
|
||||||
|
now = datetime.now()
|
||||||
|
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
|
||||||
|
assert best_config is not None
|
||||||
|
return best_config
|
||||||
|
|
||||||
|
|
||||||
|
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||||
|
return {
|
||||||
|
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
|
||||||
|
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
|
||||||
|
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
|
||||||
|
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
|
||||||
|
"num_warps": config["num_warps"],
|
||||||
|
"num_stages": config["num_stages"],
|
||||||
|
**(
|
||||||
|
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
|
||||||
|
),
|
||||||
|
**(
|
||||||
|
{"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
|
||||||
|
if "matrix_instr_nonkdim" in config
|
||||||
|
else {}
|
||||||
|
),
|
||||||
|
**({"kpack": config["kpack"]} if "kpack" in config else {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_configs(
|
||||||
|
configs: dict[int, BenchmarkConfig],
|
||||||
|
num_experts: int,
|
||||||
|
shard_intermediate_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
block_quant_shape: list[int],
|
||||||
|
save_dir: str,
|
||||||
|
) -> None:
|
||||||
|
dtype_str = _get_config_dtype_str(
|
||||||
|
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
|
||||||
|
# is the intermediate size after silu_and_mul.
|
||||||
|
filename = get_config_file_name(
|
||||||
|
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape
|
||||||
|
)
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
filename = os.path.join(save_dir, filename)
|
||||||
|
print(f"Writing best config to {filename}...")
|
||||||
|
with open(filename, "w") as f:
|
||||||
|
json.dump({"triton_version": triton.__version__, **configs}, f, indent=4)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_block_size_safety(config, default_value=None):
|
||||||
|
quantization_config = getattr(config, "quantization_config", {})
|
||||||
|
if isinstance(quantization_config, dict):
|
||||||
|
return quantization_config.get("weight_block_size", default_value)
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
|
||||||
|
if args.model_prefix:
|
||||||
|
config = getattr(config, args.model_prefix)
|
||||||
|
|
||||||
|
if config.architectures[0] == "DbrxForCausalLM":
|
||||||
|
E = config.ffn_config.moe_num_experts
|
||||||
|
topk = config.ffn_config.moe_top_k
|
||||||
|
intermediate_size = config.ffn_config.ffn_hidden_size
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
elif config.architectures[0] == "JambaForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
elif config.architectures[0] in (
|
||||||
|
"DeepseekV2ForCausalLM",
|
||||||
|
"DeepseekV3ForCausalLM",
|
||||||
|
"DeepseekV32ForCausalLM",
|
||||||
|
"Glm4MoeForCausalLM",
|
||||||
|
"NemotronHForCausalLM",
|
||||||
|
):
|
||||||
|
E = config.n_routed_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
elif config.architectures[0] in (
|
||||||
|
"Qwen2MoeForCausalLM",
|
||||||
|
"Qwen3MoeForCausalLM",
|
||||||
|
"Qwen3NextForCausalLM",
|
||||||
|
):
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.moe_intermediate_size
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
elif config.architectures[0] == "Qwen3VLMoeForConditionalGeneration":
|
||||||
|
text_config = config.get_text_config()
|
||||||
|
E = text_config.num_experts
|
||||||
|
topk = text_config.num_experts_per_tok
|
||||||
|
intermediate_size = text_config.moe_intermediate_size
|
||||||
|
hidden_size = text_config.hidden_size
|
||||||
|
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.moe_topk[0]
|
||||||
|
intermediate_size = config.moe_intermediate_size[0]
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
|
||||||
|
E = config.thinker_config.text_config.num_experts
|
||||||
|
topk = config.thinker_config.text_config.num_experts_per_tok
|
||||||
|
intermediate_size = config.thinker_config.text_config.moe_intermediate_size
|
||||||
|
hidden_size = config.thinker_config.text_config.hidden_size
|
||||||
|
else:
|
||||||
|
# Support for llama4
|
||||||
|
config = config.get_text_config()
|
||||||
|
# Default: Mixtral.
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
intermediate_size = config.intermediate_size
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
enable_ep = bool(args.enable_expert_parallel)
|
||||||
|
if enable_ep:
|
||||||
|
ensure_divisibility(E, args.tp_size, "Number of experts")
|
||||||
|
E = E // args.tp_size
|
||||||
|
shard_intermediate_size = 2 * intermediate_size
|
||||||
|
else:
|
||||||
|
ensure_divisibility(intermediate_size, args.tp_size, "intermediate_size")
|
||||||
|
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||||
|
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
||||||
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
|
block_quant_shape = get_weight_block_size_safety(config)
|
||||||
|
|
||||||
|
if args.batch_size is None:
|
||||||
|
batch_sizes = [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
32,
|
||||||
|
48,
|
||||||
|
64,
|
||||||
|
96,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1536,
|
||||||
|
2048,
|
||||||
|
3072,
|
||||||
|
4096,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch_sizes = args.batch_size
|
||||||
|
|
||||||
|
use_deep_gemm = bool(args.use_deep_gemm)
|
||||||
|
|
||||||
|
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
|
||||||
|
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
|
||||||
|
logger.warning(
|
||||||
|
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
|
||||||
|
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
|
||||||
|
)
|
||||||
|
val = os.environ["HIP_VISIBLE_DEVICES"]
|
||||||
|
os.environ["ROCR_VISIBLE_DEVICES"] = val
|
||||||
|
del os.environ["HIP_VISIBLE_DEVICES"]
|
||||||
|
|
||||||
|
ray.init()
|
||||||
|
num_gpus = int(ray.available_resources()["GPU"])
|
||||||
|
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||||
|
|
||||||
|
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||||
|
outputs = []
|
||||||
|
worker_idx = 0
|
||||||
|
for input_args in inputs:
|
||||||
|
worker = workers[worker_idx]
|
||||||
|
worker_method = getattr(worker, method)
|
||||||
|
output = worker_method.remote(*input_args)
|
||||||
|
outputs.append(output)
|
||||||
|
worker_idx = (worker_idx + 1) % num_gpus
|
||||||
|
return ray.get(outputs)
|
||||||
|
|
||||||
|
if args.tune:
|
||||||
|
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
|
||||||
|
search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
|
||||||
|
print(f"Start tuning over {len(search_space)} configurations...")
|
||||||
|
if use_deep_gemm:
|
||||||
|
raise ValueError(
|
||||||
|
"Tuning with --use-deep-gemm is not supported as it only tunes Triton "
|
||||||
|
"kernels. Please remove the flag."
|
||||||
|
)
|
||||||
|
start = time.time()
|
||||||
|
configs = _distribute(
|
||||||
|
"tune",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
E,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
search_space,
|
||||||
|
block_quant_shape,
|
||||||
|
use_deep_gemm,
|
||||||
|
)
|
||||||
|
for batch_size in batch_sizes
|
||||||
|
],
|
||||||
|
)
|
||||||
|
best_configs = {
|
||||||
|
M: sort_config(config) for M, config in zip(batch_sizes, configs)
|
||||||
|
}
|
||||||
|
save_configs(
|
||||||
|
best_configs,
|
||||||
|
E,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
block_quant_shape,
|
||||||
|
args.save_dir,
|
||||||
|
)
|
||||||
|
end = time.time()
|
||||||
|
print(f"Tuning took {end - start:.2f} seconds")
|
||||||
|
else:
|
||||||
|
outputs = _distribute(
|
||||||
|
"benchmark",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
E,
|
||||||
|
shard_intermediate_size,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
block_quant_shape,
|
||||||
|
use_deep_gemm,
|
||||||
|
)
|
||||||
|
for batch_size in batch_sizes
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
|
||||||
|
print(f"Batch size: {batch_size}, config: {config}")
|
||||||
|
print(f"Kernel time: {kernel_time:.2f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
|
||||||
|
)
|
||||||
|
parser.add_argument("--enable-expert-parallel", "-enable-ep", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||||
|
)
|
||||||
|
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-dir", type=str, default="./", help="Directory to save tuned results"
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||||
|
parser.add_argument("--tune", action="store_true")
|
||||||
|
parser.add_argument("--trust-remote-code", action="store_true")
|
||||||
|
parser.add_argument("--model-prefix", type=str, required=False)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
87
benchmarks/kernels/benchmark_moe_align_block_size.py
Normal file
87
benchmarks/kernels/benchmark_moe_align_block_size.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import argparse
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||||
|
moe_align_block_size,
|
||||||
|
)
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
|
||||||
|
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||||
|
return torch.stack(
|
||||||
|
[
|
||||||
|
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
|
||||||
|
for _ in range(num_tokens)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# test configurations
|
||||||
|
num_tokens_range = [1, 16, 256, 4096]
|
||||||
|
num_experts_range = [16, 64, 224, 256, 280, 512]
|
||||||
|
topk_range = [1, 2, 8]
|
||||||
|
ep_size_range = [1, 8]
|
||||||
|
configs = list(
|
||||||
|
itertools.product(num_tokens_range, num_experts_range, topk_range, ep_size_range)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["num_tokens", "num_experts", "topk", "ep_size"],
|
||||||
|
x_vals=configs,
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["vllm"],
|
||||||
|
line_names=["vLLM"],
|
||||||
|
plot_name="moe-align-block-size-performance",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(num_tokens, num_experts, topk, ep_size, provider):
|
||||||
|
"""Benchmark function for Triton."""
|
||||||
|
block_size = 256
|
||||||
|
torch.cuda.manual_seed_all(0)
|
||||||
|
topk_ids = get_topk_ids(num_tokens, num_experts, topk)
|
||||||
|
|
||||||
|
e_map = None
|
||||||
|
if ep_size != 1:
|
||||||
|
local_e = num_experts // ep_size
|
||||||
|
e_ids = torch.randperm(num_experts, device="cuda", dtype=torch.int32)[:local_e]
|
||||||
|
e_map = torch.full((num_experts,), -1, device="cuda", dtype=torch.int32)
|
||||||
|
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "vllm":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: moe_align_block_size(
|
||||||
|
topk_ids, block_size, num_experts, e_map, ignore_invalid_experts=True
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_experts",
|
||||||
|
type=int,
|
||||||
|
default=64,
|
||||||
|
choices=[8, 16, 32, 64, 128, 256],
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--topk",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
choices=[2, 4, 8],
|
||||||
|
help="Top-k value for correctness check.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
benchmark.run(print_data=True, show_plots=True)
|
||||||
428
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
428
benchmarks/kernels/benchmark_moe_permute_unpermute.py
Normal file
@@ -0,0 +1,428 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
|
import ray
|
||||||
|
import torch
|
||||||
|
from transformers import AutoConfig
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.fused_moe import *
|
||||||
|
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
|
||||||
|
_moe_permute,
|
||||||
|
_moe_unpermute_and_reduce,
|
||||||
|
moe_permute,
|
||||||
|
moe_unpermute,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
FP8_DTYPE = current_platform.fp8_dtype()
|
||||||
|
|
||||||
|
|
||||||
|
class BenchmarkConfig(TypedDict):
|
||||||
|
BLOCK_SIZE_M: int
|
||||||
|
BLOCK_SIZE_N: int
|
||||||
|
BLOCK_SIZE_K: int
|
||||||
|
GROUP_SIZE_M: int
|
||||||
|
num_warps: int
|
||||||
|
num_stages: int
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_permute(
|
||||||
|
num_tokens: int,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
num_iters: int = 100,
|
||||||
|
use_customized_permute: bool = False,
|
||||||
|
) -> float:
|
||||||
|
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
|
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
# output_hidden_states = torch.empty_like(hidden_states)
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||||
|
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||||
|
else:
|
||||||
|
align_block_size = None
|
||||||
|
qhidden_states = hidden_states
|
||||||
|
|
||||||
|
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
|
qhidden_states, input_gating, topk, False
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare(i: int):
|
||||||
|
input_gating.copy_(gating_output[i])
|
||||||
|
|
||||||
|
def run():
|
||||||
|
if use_customized_permute:
|
||||||
|
(
|
||||||
|
permuted_hidden_states,
|
||||||
|
a1q_scale,
|
||||||
|
first_token_off,
|
||||||
|
inv_perm_idx,
|
||||||
|
m_indices,
|
||||||
|
) = moe_permute(
|
||||||
|
qhidden_states,
|
||||||
|
a1q_scale=None,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
n_expert=num_experts,
|
||||||
|
expert_map=None,
|
||||||
|
align_block_size=align_block_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
permuted_hidden_states,
|
||||||
|
a1q_scale,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
inv_perm,
|
||||||
|
) = _moe_permute(
|
||||||
|
qhidden_states, None, topk_ids, num_experts, None, align_block_size
|
||||||
|
)
|
||||||
|
|
||||||
|
# JIT compilation & warmup
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture 10 invocations with CUDA graph
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph):
|
||||||
|
for _ in range(10):
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(5):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.Event(enable_timing=True)
|
||||||
|
end_event = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
for i in range(num_iters):
|
||||||
|
prepare(i)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
graph.replay()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||||
|
graph.reset()
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_unpermute(
|
||||||
|
num_tokens: int,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
num_iters: int = 100,
|
||||||
|
use_customized_permute: bool = False,
|
||||||
|
) -> float:
|
||||||
|
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
|
||||||
|
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
output_hidden_states = torch.empty_like(hidden_states)
|
||||||
|
if use_fp8_w8a8:
|
||||||
|
align_block_size = 128 # deepgemm needs 128 m aligned block
|
||||||
|
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
|
||||||
|
else:
|
||||||
|
align_block_size = None
|
||||||
|
qhidden_states = hidden_states
|
||||||
|
|
||||||
|
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
|
||||||
|
|
||||||
|
topk_weights, topk_ids, token_expert_indices = fused_topk(
|
||||||
|
qhidden_states, input_gating, topk, False
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare():
|
||||||
|
if use_customized_permute:
|
||||||
|
(
|
||||||
|
permuted_hidden_states,
|
||||||
|
a1q_scale,
|
||||||
|
first_token_off,
|
||||||
|
inv_perm_idx,
|
||||||
|
m_indices,
|
||||||
|
) = moe_permute(
|
||||||
|
qhidden_states,
|
||||||
|
a1q_scale=None,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
n_expert=num_experts,
|
||||||
|
expert_map=None,
|
||||||
|
align_block_size=align_block_size,
|
||||||
|
)
|
||||||
|
# convert to fp16/bf16 as gemm output
|
||||||
|
return (
|
||||||
|
permuted_hidden_states.to(dtype),
|
||||||
|
first_token_off,
|
||||||
|
inv_perm_idx,
|
||||||
|
m_indices,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
permuted_qhidden_states,
|
||||||
|
a1q_scale,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
inv_perm,
|
||||||
|
) = _moe_permute(
|
||||||
|
qhidden_states, None, topk_ids, num_experts, None, align_block_size
|
||||||
|
)
|
||||||
|
# convert to fp16/bf16 as gemm output
|
||||||
|
return (
|
||||||
|
permuted_qhidden_states.to(dtype),
|
||||||
|
a1q_scale,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
inv_perm,
|
||||||
|
)
|
||||||
|
|
||||||
|
def run(input: tuple):
|
||||||
|
if use_customized_permute:
|
||||||
|
(
|
||||||
|
permuted_hidden_states,
|
||||||
|
first_token_off,
|
||||||
|
inv_perm_idx,
|
||||||
|
m_indices,
|
||||||
|
) = input
|
||||||
|
output = torch.empty_like(hidden_states)
|
||||||
|
moe_unpermute(
|
||||||
|
output,
|
||||||
|
permuted_hidden_states,
|
||||||
|
topk_weights,
|
||||||
|
inv_perm_idx,
|
||||||
|
first_token_off,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
permuted_hidden_states,
|
||||||
|
a1q_scale,
|
||||||
|
sorted_token_ids,
|
||||||
|
expert_ids,
|
||||||
|
inv_perm,
|
||||||
|
) = input
|
||||||
|
_moe_unpermute_and_reduce(
|
||||||
|
output_hidden_states,
|
||||||
|
permuted_hidden_states,
|
||||||
|
inv_perm,
|
||||||
|
topk_weights,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# JIT compilation & warmup
|
||||||
|
input = prepare()
|
||||||
|
run(input)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Capture 10 invocations with CUDA graph
|
||||||
|
graph = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(graph):
|
||||||
|
for _ in range(10):
|
||||||
|
run(input)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
for _ in range(5):
|
||||||
|
graph.replay()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.Event(enable_timing=True)
|
||||||
|
end_event = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
for i in range(num_iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_event.record()
|
||||||
|
graph.replay()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||||
|
graph.reset()
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
@ray.remote(num_gpus=1)
|
||||||
|
class BenchmarkWorker:
|
||||||
|
def __init__(self, seed: int) -> None:
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
self.seed = seed
|
||||||
|
# Get the device ID to allocate tensors and kernels
|
||||||
|
# on the respective GPU. This is required for Ray to work
|
||||||
|
# correctly with multi-GPU tuning on the ROCm platform.
|
||||||
|
self.device_id = int(ray.get_gpu_ids()[0])
|
||||||
|
|
||||||
|
def benchmark(
|
||||||
|
self,
|
||||||
|
num_tokens: int,
|
||||||
|
num_experts: int,
|
||||||
|
hidden_size: int,
|
||||||
|
topk: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
use_fp8_w8a8: bool,
|
||||||
|
use_int8_w8a16: bool,
|
||||||
|
use_customized_permute: bool = False,
|
||||||
|
) -> tuple[dict[str, int], float]:
|
||||||
|
current_platform.seed_everything(self.seed)
|
||||||
|
|
||||||
|
permute_time = benchmark_permute(
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
num_iters=100,
|
||||||
|
use_customized_permute=use_customized_permute,
|
||||||
|
)
|
||||||
|
unpermute_time = benchmark_unpermute(
|
||||||
|
num_tokens,
|
||||||
|
num_experts,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
num_iters=100,
|
||||||
|
use_customized_permute=use_customized_permute,
|
||||||
|
)
|
||||||
|
return permute_time, unpermute_time
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_block_size_safety(config, default_value=None):
|
||||||
|
quantization_config = getattr(config, "quantization_config", {})
|
||||||
|
if isinstance(quantization_config, dict):
|
||||||
|
return quantization_config.get("weight_block_size", default_value)
|
||||||
|
return default_value
|
||||||
|
|
||||||
|
|
||||||
|
def main(args: argparse.Namespace):
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
args.model, trust_remote_code=args.trust_remote_code
|
||||||
|
)
|
||||||
|
if config.architectures[0] == "DbrxForCausalLM":
|
||||||
|
E = config.ffn_config.moe_num_experts
|
||||||
|
topk = config.ffn_config.moe_top_k
|
||||||
|
elif config.architectures[0] == "JambaForCausalLM":
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
elif (
|
||||||
|
config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||||
|
or config.architectures[0] == "DeepseekV2ForCausalLM"
|
||||||
|
or config.architectures[0] == "Glm4MoeForCausalLM"
|
||||||
|
):
|
||||||
|
E = config.n_routed_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]:
|
||||||
|
E = config.num_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Support for llama4
|
||||||
|
config = config.get_text_config()
|
||||||
|
# Default: Mixtral.
|
||||||
|
E = config.num_local_experts
|
||||||
|
topk = config.num_experts_per_tok
|
||||||
|
|
||||||
|
hidden_size = config.hidden_size
|
||||||
|
dtype = torch.float16 if current_platform.is_rocm() else config.dtype
|
||||||
|
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
|
||||||
|
use_int8_w8a16 = args.dtype == "int8_w8a16"
|
||||||
|
use_customized_permute = args.use_customized_permute
|
||||||
|
|
||||||
|
if args.batch_size is None:
|
||||||
|
batch_sizes = [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
32,
|
||||||
|
48,
|
||||||
|
64,
|
||||||
|
96,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1536,
|
||||||
|
2048,
|
||||||
|
3072,
|
||||||
|
4096,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch_sizes = [args.batch_size]
|
||||||
|
|
||||||
|
ray.init()
|
||||||
|
num_gpus = int(ray.available_resources()["GPU"])
|
||||||
|
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||||
|
|
||||||
|
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||||
|
outputs = []
|
||||||
|
worker_idx = 0
|
||||||
|
for input_args in inputs:
|
||||||
|
worker = workers[worker_idx]
|
||||||
|
worker_method = getattr(worker, method)
|
||||||
|
output = worker_method.remote(*input_args)
|
||||||
|
outputs.append(output)
|
||||||
|
worker_idx = (worker_idx + 1) % num_gpus
|
||||||
|
return ray.get(outputs)
|
||||||
|
|
||||||
|
outputs = _distribute(
|
||||||
|
"benchmark",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
batch_size,
|
||||||
|
E,
|
||||||
|
hidden_size,
|
||||||
|
topk,
|
||||||
|
dtype,
|
||||||
|
use_fp8_w8a8,
|
||||||
|
use_int8_w8a16,
|
||||||
|
use_customized_permute,
|
||||||
|
)
|
||||||
|
for batch_size in batch_sizes
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
|
||||||
|
print(f"Batch size: {batch_size}")
|
||||||
|
print(f"Permute time: {permute:.2f} us")
|
||||||
|
print(f"Unpermute time: {unpermute:.2f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
|
||||||
|
)
|
||||||
|
parser.add_argument("--use-customized-permute", action="store_true")
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
parser.add_argument("--trust-remote-code", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
322
benchmarks/kernels/benchmark_mrope.py
Normal file
322
benchmarks/kernels/benchmark_mrope.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# This script benchmarks the mrope kernel (mainly for Qwen2VL and Qwen2.5VL models).
|
||||||
|
# It generates test data, runs benchmarks, and saves results to a CSV file.
|
||||||
|
#
|
||||||
|
# The CSV file (named with current date/time) contains these columns:
|
||||||
|
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
|
||||||
|
# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
|
||||||
|
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
|
||||||
|
# speedup
|
||||||
|
#
|
||||||
|
# == Usage Examples ==
|
||||||
|
#
|
||||||
|
# Single model benchmark:
|
||||||
|
# python3 benchmark_mrope.py --model-name Qwen/Qwen2-VL-7B-Instruct --tp-size 1 \
|
||||||
|
# --warmup-iter 10 --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||||
|
#
|
||||||
|
# All models benchmark:
|
||||||
|
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
|
||||||
|
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||||
|
#
|
||||||
|
# All models with different TP sizes:
|
||||||
|
# python3 benchmark_mrope.py --model-name "" --tp-size 1 2 4 8 --warmup-iter 10 \
|
||||||
|
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024
|
||||||
|
#
|
||||||
|
# All models with different token counts:
|
||||||
|
# python3 benchmark_mrope.py --model-name "" --tp-size 1 --warmup-iter 10 \
|
||||||
|
# --benchmark-iter 100 --dtype bfloat16 --seed 0 --num-tokens 1024 4096 16384
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.transformers_utils.config import get_config
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_test_data(
|
||||||
|
num_tokens: int,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
max_position_embeddings: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
):
|
||||||
|
"""Generate test data for given configuration."""
|
||||||
|
# Create 2D positions (3, num_tokens) for multimodal case
|
||||||
|
positions = torch.randint(
|
||||||
|
0, max_position_embeddings // 4, (3, num_tokens), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create query and key tensors
|
||||||
|
query = torch.randn(num_tokens, num_q_heads * head_size, dtype=dtype, device=device)
|
||||||
|
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
return positions, query, key
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_stats(times: list[float]) -> dict[str, float]:
|
||||||
|
"""Calculate statistics from a list of times."""
|
||||||
|
times_array = np.array(times)
|
||||||
|
return {
|
||||||
|
"mean": np.mean(times_array),
|
||||||
|
"median": np.median(times_array),
|
||||||
|
"p99": np.percentile(times_array, 99),
|
||||||
|
"min": np.min(times_array),
|
||||||
|
"max": np.max(times_array),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_mrope(
|
||||||
|
model_name: str,
|
||||||
|
num_tokens: int,
|
||||||
|
head_dim: int,
|
||||||
|
tp_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
max_position: int = 8192,
|
||||||
|
is_neox_style: bool = True,
|
||||||
|
rope_parameters: dict[str, Any] | None = None,
|
||||||
|
dtype: torch.dtype = torch.bfloat16,
|
||||||
|
seed: int = 0,
|
||||||
|
warmup_iter: int = 10,
|
||||||
|
benchmark_iter: int = 100,
|
||||||
|
csv_writer=None,
|
||||||
|
):
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
# the parameters to compute the q k v size based on tp_size
|
||||||
|
mrope_helper_class = get_rope(
|
||||||
|
head_size=head_dim,
|
||||||
|
max_position=max_position,
|
||||||
|
is_neox_style=is_neox_style,
|
||||||
|
rope_parameters=rope_parameters,
|
||||||
|
dtype=dtype,
|
||||||
|
).to(device=device)
|
||||||
|
|
||||||
|
print(80 * "=")
|
||||||
|
print(
|
||||||
|
f"Evaluating model: {model_name} "
|
||||||
|
f"with tp_size: {tp_size} "
|
||||||
|
f"and num_tokens: {num_tokens}, "
|
||||||
|
f"dtype: {dtype}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# create q k v input tensors
|
||||||
|
# create rotary pos emb input tensors
|
||||||
|
positions, query, key = generate_test_data(
|
||||||
|
num_tokens, num_heads, num_kv_heads, head_dim, max_position, dtype, device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warm up
|
||||||
|
for _ in range(warmup_iter):
|
||||||
|
mrope_helper_class.forward_native(
|
||||||
|
positions,
|
||||||
|
query.clone(),
|
||||||
|
key.clone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
mrope_helper_class.forward_cuda(
|
||||||
|
positions,
|
||||||
|
query.clone(),
|
||||||
|
key.clone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Time reference implementation
|
||||||
|
torch_times = []
|
||||||
|
for _ in range(benchmark_iter):
|
||||||
|
query_clone = query.clone()
|
||||||
|
key_clone = key.clone()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
mrope_helper_class.forward_native(
|
||||||
|
positions,
|
||||||
|
query_clone,
|
||||||
|
key_clone,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Time triton kernel implementation
|
||||||
|
triton_times = []
|
||||||
|
for _ in range(benchmark_iter):
|
||||||
|
query_clone = query.clone()
|
||||||
|
key_clone = key.clone()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_time = time.time()
|
||||||
|
mrope_helper_class.forward_cuda(
|
||||||
|
positions,
|
||||||
|
query_clone,
|
||||||
|
key_clone,
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
triton_times.append(time.time() - start_time)
|
||||||
|
|
||||||
|
# Calculate statistics
|
||||||
|
torch_stats = calculate_stats(torch_times)
|
||||||
|
triton_stats = calculate_stats(triton_times)
|
||||||
|
print(f"\nPerformance for config ({num_tokens}, {num_heads}, {num_kv_heads}):")
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Torch implementation: "
|
||||||
|
f"mean={torch_stats['mean']:.8f}s, "
|
||||||
|
f"median={torch_stats['median']:.8f}s, "
|
||||||
|
f"p99={torch_stats['p99']:.8f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Triton implementation: "
|
||||||
|
f"mean={triton_stats['mean']:.8f}s, "
|
||||||
|
f"median={triton_stats['median']:.8f}s, "
|
||||||
|
f"p99={triton_stats['p99']:.8f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Triton Speedup over Torch: {torch_stats['mean'] / triton_stats['mean']:.8f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Write to CSV
|
||||||
|
if csv_writer:
|
||||||
|
row = [
|
||||||
|
model_name,
|
||||||
|
tp_size,
|
||||||
|
num_tokens,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
max_position,
|
||||||
|
is_neox_style,
|
||||||
|
str(rope_parameters),
|
||||||
|
str(dtype).split(".")[-1],
|
||||||
|
torch_stats["mean"],
|
||||||
|
torch_stats["median"],
|
||||||
|
torch_stats["p99"],
|
||||||
|
torch_stats["min"],
|
||||||
|
torch_stats["max"],
|
||||||
|
triton_stats["mean"],
|
||||||
|
triton_stats["median"],
|
||||||
|
triton_stats["p99"],
|
||||||
|
triton_stats["min"],
|
||||||
|
triton_stats["max"],
|
||||||
|
torch_stats["mean"] / triton_stats["mean"], # speedup
|
||||||
|
]
|
||||||
|
csv_writer.writerow(row)
|
||||||
|
|
||||||
|
return torch_stats, triton_stats
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the rotary embedding kernels."
|
||||||
|
)
|
||||||
|
parser.add_argument("--model-name", type=str, default="")
|
||||||
|
parser.add_argument("--tp-size", type=int, default=1)
|
||||||
|
parser.add_argument("--warmup-iter", type=int, default=10)
|
||||||
|
parser.add_argument("--benchmark-iter", type=int, default=100)
|
||||||
|
parser.add_argument("--dtype", type=str, choices=["bfloat16"], default="bfloat16")
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--num-tokens", type=int, nargs="+", required=False)
|
||||||
|
parser.add_argument("--trust-remote-code", action="store_true")
|
||||||
|
parser.add_argument("--output-csv", type=str, default="mrope_benchmark_results.csv")
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
# Create CSV file for results
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
csv_filename = f"{os.path.splitext(args.output_csv)[0]}_{timestamp}.csv"
|
||||||
|
|
||||||
|
with open(csv_filename, "w", newline="") as csvfile:
|
||||||
|
csv_writer = csv.writer(csvfile)
|
||||||
|
# Write header
|
||||||
|
header = [
|
||||||
|
"model_name",
|
||||||
|
"tp_size",
|
||||||
|
"num_tokens",
|
||||||
|
"num_heads",
|
||||||
|
"num_kv_heads",
|
||||||
|
"head_dim",
|
||||||
|
"max_position",
|
||||||
|
"is_neox_style",
|
||||||
|
"rope_parameters",
|
||||||
|
"dtype",
|
||||||
|
"torch_mean",
|
||||||
|
"torch_median",
|
||||||
|
"torch_p99",
|
||||||
|
"torch_min",
|
||||||
|
"torch_max",
|
||||||
|
"triton_mean",
|
||||||
|
"triton_median",
|
||||||
|
"triton_p99",
|
||||||
|
"triton_min",
|
||||||
|
"triton_max",
|
||||||
|
"speedup",
|
||||||
|
]
|
||||||
|
csv_writer.writerow(header)
|
||||||
|
|
||||||
|
model_tp_dict = {}
|
||||||
|
if args.model_name == "":
|
||||||
|
model_tp_dict = {
|
||||||
|
"Qwen/Qwen2-VL-2B-Instruct": [1],
|
||||||
|
"Qwen/Qwen2-VL-7B-Instruct": [1],
|
||||||
|
"Qwen/Qwen2-VL-72B-Instruct": [2, 4, 8],
|
||||||
|
"Qwen/Qwen2.5-VL-3B-Instruct": [1, 2, 4, 8],
|
||||||
|
"Qwen/Qwen2.5-VL-7B-Instruct": [1, 2, 4, 8],
|
||||||
|
"Qwen/Qwen2.5-VL-72B-Instruct": [2, 4, 8],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
model_tp_dict[args.model_name] = [args.tp_size]
|
||||||
|
|
||||||
|
if args.num_tokens is None:
|
||||||
|
num_tokens_list = [2**i for i in range(0, 18)]
|
||||||
|
else:
|
||||||
|
num_tokens_list = args.num_tokens
|
||||||
|
|
||||||
|
for model_name, tp_list in model_tp_dict.items():
|
||||||
|
config = get_config(model_name, trust_remote_code=args.trust_remote_code)
|
||||||
|
for tp_size in tp_list:
|
||||||
|
# get the model config
|
||||||
|
total_num_kv_heads = config.num_key_value_heads
|
||||||
|
total_num_heads = config.num_attention_heads
|
||||||
|
num_heads = total_num_heads // tp_size
|
||||||
|
num_kv_heads = max(1, total_num_kv_heads // tp_size)
|
||||||
|
head_dim = config.hidden_size // total_num_heads
|
||||||
|
q_size = num_heads * head_dim
|
||||||
|
kv_size = num_kv_heads * head_dim
|
||||||
|
is_neox_style = True
|
||||||
|
rope_parameters = config.rope_parameters
|
||||||
|
max_position = config.max_position_embeddings
|
||||||
|
|
||||||
|
for num_tokens in num_tokens_list:
|
||||||
|
benchmark_mrope(
|
||||||
|
model_name=model_name,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
head_dim=head_dim,
|
||||||
|
tp_size=tp_size,
|
||||||
|
num_heads=num_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
max_position=max_position,
|
||||||
|
is_neox_style=is_neox_style,
|
||||||
|
rope_parameters=rope_parameters,
|
||||||
|
dtype=getattr(torch, args.dtype),
|
||||||
|
seed=args.seed,
|
||||||
|
warmup_iter=args.warmup_iter,
|
||||||
|
benchmark_iter=args.benchmark_iter,
|
||||||
|
csv_writer=csv_writer,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Benchmark results saved to {csv_filename}")
|
||||||
@@ -1,15 +1,25 @@
|
|||||||
import argparse
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import (
|
||||||
|
STR_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
create_kv_caches_with_random,
|
||||||
|
)
|
||||||
|
|
||||||
NUM_BLOCKS = 1024
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
NUM_BLOCKS = 128 * 1024
|
||||||
PARTITION_SIZE = 512
|
PARTITION_SIZE = 512
|
||||||
|
PARTITION_SIZE_ROCM = 256
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
@@ -26,27 +36,20 @@ def main(
|
|||||||
seed: int,
|
seed: int,
|
||||||
do_profile: bool,
|
do_profile: bool,
|
||||||
device: str = "cuda",
|
device: str = "cuda",
|
||||||
kv_cache_dtype: Optional[str] = None,
|
kv_cache_dtype: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
random.seed(seed)
|
current_platform.seed_everything(seed)
|
||||||
torch.random.manual_seed(seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
|
|
||||||
scale = float(1.0 / (head_size**0.5))
|
scale = float(1.0 / (head_size**0.5))
|
||||||
query = torch.empty(num_seqs,
|
query = torch.empty(
|
||||||
num_query_heads,
|
num_seqs, num_query_heads, head_size, dtype=dtype, device=device
|
||||||
head_size,
|
)
|
||||||
dtype=dtype,
|
|
||||||
device=device)
|
|
||||||
query.uniform_(-scale, scale)
|
query.uniform_(-scale, scale)
|
||||||
|
|
||||||
assert num_query_heads % num_kv_heads == 0
|
assert num_query_heads % num_kv_heads == 0
|
||||||
alibi_slopes = None
|
alibi_slopes = None
|
||||||
if use_alibi:
|
if use_alibi:
|
||||||
alibi_slopes = torch.randn(num_query_heads,
|
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
|
||||||
dtype=torch.float,
|
|
||||||
device=device)
|
|
||||||
|
|
||||||
seq_lens = [seq_len for _ in range(num_seqs)]
|
seq_lens = [seq_len for _ in range(num_seqs)]
|
||||||
max_seq_len = max(seq_lens)
|
max_seq_len = max(seq_lens)
|
||||||
@@ -54,30 +57,38 @@ def main(
|
|||||||
|
|
||||||
# Create the block tables.
|
# Create the block tables.
|
||||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
block_tables = []
|
block_tables_lst: list[list[int]] = []
|
||||||
for _ in range(num_seqs):
|
for _ in range(num_seqs):
|
||||||
block_table = [
|
block_table = [
|
||||||
random.randint(0, NUM_BLOCKS - 1)
|
random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
|
||||||
for _ in range(max_num_blocks_per_seq)
|
|
||||||
]
|
]
|
||||||
block_tables.append(block_table)
|
block_tables_lst.append(block_table)
|
||||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
|
|
||||||
|
block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
|
||||||
|
|
||||||
# Create the KV cache.
|
# Create the KV cache.
|
||||||
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
key_caches, value_caches = create_kv_caches_with_random(
|
||||||
block_size,
|
NUM_BLOCKS,
|
||||||
1,
|
block_size,
|
||||||
num_kv_heads,
|
1,
|
||||||
head_size,
|
num_kv_heads,
|
||||||
kv_cache_dtype,
|
head_size,
|
||||||
dtype,
|
kv_cache_dtype,
|
||||||
device=device)
|
dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
|
||||||
# Prepare for the paged attention kernel.
|
# Prepare for the paged attention kernel.
|
||||||
output = torch.empty_like(query)
|
output = torch.empty_like(query)
|
||||||
if version == "v2":
|
if version == "v2":
|
||||||
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
|
if current_platform.is_rocm():
|
||||||
|
global PARTITION_SIZE
|
||||||
|
if not args.custom_paged_attn and not current_platform.is_navi():
|
||||||
|
PARTITION_SIZE = 1024
|
||||||
|
else:
|
||||||
|
PARTITION_SIZE = PARTITION_SIZE_ROCM
|
||||||
|
num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
|
||||||
tmp_output = torch.empty(
|
tmp_output = torch.empty(
|
||||||
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
size=(num_seqs, num_query_heads, num_partitions, head_size),
|
||||||
dtype=output.dtype,
|
dtype=output.dtype,
|
||||||
@@ -97,7 +108,7 @@ def main(
|
|||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
# Using default kv_scale
|
# Using default kv_scale
|
||||||
kv_scale = 1.0
|
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for _ in range(num_iters):
|
for _ in range(num_iters):
|
||||||
if version == "v1":
|
if version == "v1":
|
||||||
@@ -114,34 +125,58 @@ def main(
|
|||||||
max_seq_len,
|
max_seq_len,
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
kv_cache_dtype,
|
kv_cache_dtype,
|
||||||
kv_scale,
|
k_scale,
|
||||||
|
v_scale,
|
||||||
)
|
)
|
||||||
elif version == "v2":
|
elif version == "v2":
|
||||||
ops.paged_attention_v2(
|
if not args.custom_paged_attn:
|
||||||
output,
|
ops.paged_attention_v2(
|
||||||
exp_sums,
|
output,
|
||||||
max_logits,
|
exp_sums,
|
||||||
tmp_output,
|
max_logits,
|
||||||
query,
|
tmp_output,
|
||||||
key_cache,
|
query,
|
||||||
value_cache,
|
key_cache,
|
||||||
num_kv_heads,
|
value_cache,
|
||||||
scale,
|
num_kv_heads,
|
||||||
block_tables,
|
scale,
|
||||||
seq_lens,
|
block_tables,
|
||||||
block_size,
|
seq_lens,
|
||||||
max_seq_len,
|
block_size,
|
||||||
alibi_slopes,
|
max_seq_len,
|
||||||
kv_cache_dtype,
|
alibi_slopes,
|
||||||
kv_scale,
|
kv_cache_dtype,
|
||||||
)
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ops.paged_attention_rocm(
|
||||||
|
output,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
num_kv_heads,
|
||||||
|
scale,
|
||||||
|
block_tables,
|
||||||
|
seq_lens,
|
||||||
|
None,
|
||||||
|
block_size,
|
||||||
|
max_seq_len,
|
||||||
|
alibi_slopes,
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid version: {version}")
|
raise ValueError(f"Invalid version: {version}")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
if profile:
|
if profile:
|
||||||
torch.cuda.cudart().cudaProfilerStart()
|
torch.cuda.cudart().cudaProfilerStop()
|
||||||
return (end_time - start_time) / num_iters
|
return (end_time - start_time) / num_iters
|
||||||
|
|
||||||
# Warmup.
|
# Warmup.
|
||||||
@@ -157,39 +192,43 @@ def main(
|
|||||||
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
logger.warning(
|
||||||
description="Benchmark the paged attention kernel.")
|
"This script benchmarks the paged attention kernel. "
|
||||||
parser.add_argument("--version",
|
"By default this is no longer used in vLLM inference."
|
||||||
type=str,
|
)
|
||||||
choices=["v1", "v2"],
|
|
||||||
default="v2")
|
parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
|
||||||
|
parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
|
||||||
parser.add_argument("--batch-size", type=int, default=8)
|
parser.add_argument("--batch-size", type=int, default=8)
|
||||||
parser.add_argument("--seq_len", type=int, default=4096)
|
parser.add_argument("--seq-len", type=int, default=4096)
|
||||||
parser.add_argument("--num-query-heads", type=int, default=64)
|
parser.add_argument("--num-query-heads", type=int, default=64)
|
||||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||||
parser.add_argument("--head-size",
|
parser.add_argument(
|
||||||
type=int,
|
"--head-size",
|
||||||
choices=[64, 80, 96, 112, 128, 256],
|
type=int,
|
||||||
default=128)
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||||
|
default=128,
|
||||||
|
)
|
||||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||||
parser.add_argument("--use-alibi", action="store_true")
|
parser.add_argument("--use-alibi", action="store_true")
|
||||||
parser.add_argument("--dtype",
|
parser.add_argument(
|
||||||
type=str,
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||||
choices=["half", "bfloat16", "float"],
|
)
|
||||||
default="half")
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--profile", action="store_true")
|
parser.add_argument("--profile", action="store_true")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--kv-cache-dtype",
|
"--kv-cache-dtype",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["auto", "fp8"],
|
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
|
||||||
default="auto",
|
default="auto",
|
||||||
help=
|
help="Data type for kv cache storage. If 'auto', will use model "
|
||||||
'Data type for kv cache storage. If "auto", will use model data type. '
|
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
|
||||||
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
|
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
|
||||||
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
|
)
|
||||||
'common inference criteria.')
|
parser.add_argument(
|
||||||
|
"--custom-paged-attn", action="store_true", help="Use custom paged attention"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
print(args)
|
print(args)
|
||||||
|
|
||||||
|
|||||||
159
benchmarks/kernels/benchmark_per_token_group_quant.py
Normal file
159
benchmarks/kernels/benchmark_per_token_group_quant.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import math
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils import fp8_utils, int8_utils
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _triton_mode():
|
||||||
|
"""Temporarily force the Triton fallback path"""
|
||||||
|
with patch("vllm.platforms.current_platform.is_cuda", return_value=False):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def _time_cuda(
|
||||||
|
fn: Callable[[], tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
warmup_iters: int,
|
||||||
|
bench_iters: int,
|
||||||
|
) -> float:
|
||||||
|
# warmup
|
||||||
|
for _ in range(warmup_iters):
|
||||||
|
fn()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start = torch.Event(enable_timing=True)
|
||||||
|
end = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
start.record()
|
||||||
|
for _ in range(bench_iters):
|
||||||
|
fn()
|
||||||
|
end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return start.elapsed_time(end) / bench_iters # ms/iter
|
||||||
|
|
||||||
|
|
||||||
|
def _run_single(
|
||||||
|
shape: tuple[int, int],
|
||||||
|
group_size: int,
|
||||||
|
dtype: str,
|
||||||
|
*,
|
||||||
|
column_major: bool = False,
|
||||||
|
scale_ue8m0: bool = False,
|
||||||
|
warmup_iters: int,
|
||||||
|
bench_iters: int,
|
||||||
|
) -> None:
|
||||||
|
num_tokens, hidden_dim = shape
|
||||||
|
|
||||||
|
device = torch.device("cuda")
|
||||||
|
torch.manual_seed(42)
|
||||||
|
x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16) * 8
|
||||||
|
|
||||||
|
if dtype == "fp8":
|
||||||
|
|
||||||
|
def cuda_impl():
|
||||||
|
return fp8_utils.per_token_group_quant_fp8(
|
||||||
|
x,
|
||||||
|
group_size,
|
||||||
|
column_major_scales=column_major,
|
||||||
|
use_ue8m0=scale_ue8m0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def triton_impl():
|
||||||
|
with _triton_mode():
|
||||||
|
return fp8_utils.per_token_group_quant_fp8(
|
||||||
|
x,
|
||||||
|
group_size,
|
||||||
|
column_major_scales=column_major,
|
||||||
|
use_ue8m0=scale_ue8m0,
|
||||||
|
)
|
||||||
|
elif dtype == "int8":
|
||||||
|
|
||||||
|
def cuda_impl():
|
||||||
|
return int8_utils.per_token_group_quant_int8(x, group_size)
|
||||||
|
|
||||||
|
def triton_impl():
|
||||||
|
with _triton_mode():
|
||||||
|
return int8_utils.per_token_group_quant_int8(x, group_size)
|
||||||
|
else:
|
||||||
|
raise ValueError("dtype must be 'fp8' or 'int8'")
|
||||||
|
|
||||||
|
cuda_ms = _time_cuda(cuda_impl, warmup_iters, bench_iters)
|
||||||
|
triton_ms = _time_cuda(triton_impl, warmup_iters, bench_iters)
|
||||||
|
|
||||||
|
speedup = triton_ms / cuda_ms if cuda_ms else math.inf
|
||||||
|
|
||||||
|
cfg_desc = (
|
||||||
|
f"shape={shape} gs={group_size:<3} col_major={column_major:<5} "
|
||||||
|
f"ue8m0={scale_ue8m0:<5} dtype={dtype}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"{cfg_desc:55} | CUDA {cuda_ms:7.3f} ms | Triton {triton_ms:7.3f} ms | "
|
||||||
|
f"speed-up ×{speedup:5.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--warmup-iters", type=int, default=10)
|
||||||
|
parser.add_argument("--bench-iters", type=int, default=100)
|
||||||
|
parser.add_argument("--dtype", choices=["fp8", "int8", "both"], default="both")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if not current_platform.is_cuda():
|
||||||
|
raise RuntimeError("CUDA device is required to run this benchmark.")
|
||||||
|
|
||||||
|
args = parse_args()
|
||||||
|
warmup_iters, bench_iters = args.warmup_iters, args.bench_iters
|
||||||
|
|
||||||
|
shapes = [(32, 128), (64, 256), (16, 512)]
|
||||||
|
group_sizes = [64, 128]
|
||||||
|
|
||||||
|
dtypes = ["fp8", "int8"] if args.dtype == "both" else [args.dtype]
|
||||||
|
|
||||||
|
header = (
|
||||||
|
"Configuration".ljust(55)
|
||||||
|
+ " | "
|
||||||
|
+ "CUDA (ms)".center(12)
|
||||||
|
+ " | "
|
||||||
|
+ "Triton (ms)".center(13)
|
||||||
|
+ " | "
|
||||||
|
+ "Speed-up"
|
||||||
|
)
|
||||||
|
print(header)
|
||||||
|
print("-" * len(header))
|
||||||
|
|
||||||
|
for dtype in dtypes:
|
||||||
|
for shape in shapes:
|
||||||
|
for gs in group_sizes:
|
||||||
|
if dtype == "fp8":
|
||||||
|
for col_major in (False, True):
|
||||||
|
for ue8m0 in (False, True):
|
||||||
|
_run_single(
|
||||||
|
shape,
|
||||||
|
gs,
|
||||||
|
dtype,
|
||||||
|
column_major=col_major,
|
||||||
|
scale_ue8m0=ue8m0,
|
||||||
|
warmup_iters=warmup_iters,
|
||||||
|
bench_iters=bench_iters,
|
||||||
|
)
|
||||||
|
else: # INT8 has no col-major / ue8m0 switches
|
||||||
|
_run_single(
|
||||||
|
shape,
|
||||||
|
gs,
|
||||||
|
dtype,
|
||||||
|
warmup_iters=warmup_iters,
|
||||||
|
bench_iters=bench_iters,
|
||||||
|
)
|
||||||
109
benchmarks/kernels/benchmark_quant.py
Normal file
109
benchmarks/kernels/benchmark_quant.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def main(
|
||||||
|
num_tokens: int,
|
||||||
|
hidden_size: int,
|
||||||
|
static_scale: bool,
|
||||||
|
quant_dtype: torch.dtype,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seed: int = 0,
|
||||||
|
do_profile: bool = False,
|
||||||
|
num_warmup_iters: int = 5,
|
||||||
|
num_iters: int = 100,
|
||||||
|
) -> None:
|
||||||
|
current_platform.seed_everything(seed)
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
|
||||||
|
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
|
||||||
|
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None
|
||||||
|
|
||||||
|
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStart()
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
|
for _ in range(num_iters):
|
||||||
|
if quant_dtype == torch.int8:
|
||||||
|
ops.scaled_int8_quant(x, scale)
|
||||||
|
else:
|
||||||
|
ops.scaled_fp8_quant(x, scale)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
end_time = time.perf_counter()
|
||||||
|
if profile:
|
||||||
|
torch.cuda.cudart().cudaProfilerStop()
|
||||||
|
return (end_time - start_time) / num_iters
|
||||||
|
|
||||||
|
# Warmup.
|
||||||
|
print("Warming up...")
|
||||||
|
run_benchmark = run_cuda_benchmark
|
||||||
|
run_benchmark(num_iters=num_warmup_iters, profile=False)
|
||||||
|
|
||||||
|
# Benchmark.
|
||||||
|
if do_profile:
|
||||||
|
latency = run_benchmark(num_iters=1, profile=True)
|
||||||
|
else:
|
||||||
|
latency = run_benchmark(num_iters=num_iters, profile=False)
|
||||||
|
print(f"Kernel running time: {latency * 1000000:.3f} us")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
def to_torch_dtype(dt):
|
||||||
|
if dt == "int8":
|
||||||
|
return torch.int8
|
||||||
|
if dt == "fp8":
|
||||||
|
return torch.float8_e4m3fn
|
||||||
|
raise ValueError(f"Unsupported dtype: {dt}")
|
||||||
|
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the quantization (fp8 or int8) kernel."
|
||||||
|
)
|
||||||
|
parser.add_argument("--num-tokens", type=int, default=4096)
|
||||||
|
parser.add_argument("--hidden-size", type=int, default=8192)
|
||||||
|
parser.add_argument("--static-scale", action="store_true")
|
||||||
|
parser.add_argument(
|
||||||
|
"--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument("--profile", action="store_true")
|
||||||
|
parser.add_argument("--num-warmup-iters", type=int, default=5)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-iters",
|
||||||
|
type=int,
|
||||||
|
default=100,
|
||||||
|
help="Number of benchmark iterations. "
|
||||||
|
"If --profile is set, this number is ignored",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
print(args)
|
||||||
|
|
||||||
|
main(
|
||||||
|
num_tokens=args.num_tokens,
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
static_scale=args.static_scale,
|
||||||
|
quant_dtype=to_torch_dtype(args.quant_dtype),
|
||||||
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
|
seed=args.seed,
|
||||||
|
do_profile=args.profile,
|
||||||
|
num_warmup_iters=args.num_warmup_iters,
|
||||||
|
num_iters=args.num_iters,
|
||||||
|
)
|
||||||
172
benchmarks/kernels/benchmark_reshape_and_cache.py
Normal file
172
benchmarks/kernels/benchmark_reshape_and_cache.py
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import (
|
||||||
|
STR_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
create_kv_caches_with_random,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_benchmark(
|
||||||
|
num_tokens: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
block_size: int,
|
||||||
|
num_blocks: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
num_iters: int,
|
||||||
|
benchmark_mode: str,
|
||||||
|
device: str = "cuda",
|
||||||
|
) -> float:
|
||||||
|
"""Return latency (seconds) for given num_tokens."""
|
||||||
|
|
||||||
|
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||||
|
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
|
||||||
|
|
||||||
|
current_platform.seed_everything(42)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
|
||||||
|
# create random key / value tensors [T, H, D].
|
||||||
|
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
|
||||||
|
value = torch.randn_like(key)
|
||||||
|
|
||||||
|
# prepare the slot mapping.
|
||||||
|
# each token is assigned a unique slot in the KV-cache.
|
||||||
|
num_slots = block_size * num_blocks
|
||||||
|
if num_tokens > num_slots:
|
||||||
|
raise ValueError("num_tokens cannot exceed the total number of cache slots")
|
||||||
|
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||||
|
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
key_caches, value_caches = create_kv_caches_with_random(
|
||||||
|
num_blocks,
|
||||||
|
block_size,
|
||||||
|
1, # num_layers
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
kv_cache_dtype,
|
||||||
|
dtype,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
# to free unused memory
|
||||||
|
del key_caches, value_caches
|
||||||
|
|
||||||
|
# compute per-kernel scaling factors for fp8 conversion (if used).
|
||||||
|
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||||
|
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||||
|
|
||||||
|
function_under_test = lambda: ops.reshape_and_cache(
|
||||||
|
key, # noqa: F821
|
||||||
|
value, # noqa: F821
|
||||||
|
key_cache, # noqa: F821
|
||||||
|
value_cache, # noqa: F821
|
||||||
|
slot_mapping, # noqa: F821
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
if benchmark_mode == "cudagraph":
|
||||||
|
g = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(g):
|
||||||
|
function_under_test()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
function_under_test = lambda: g.replay()
|
||||||
|
|
||||||
|
def run_cuda_benchmark(n_iters: int) -> float:
|
||||||
|
nonlocal key, value, key_cache, value_cache, slot_mapping
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start = time.perf_counter()
|
||||||
|
for _ in range(n_iters):
|
||||||
|
function_under_test()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.perf_counter()
|
||||||
|
return (end - start) / n_iters
|
||||||
|
|
||||||
|
# warm-up
|
||||||
|
run_cuda_benchmark(3)
|
||||||
|
|
||||||
|
lat = run_cuda_benchmark(num_iters)
|
||||||
|
|
||||||
|
# free tensors to mitigate OOM when sweeping
|
||||||
|
del key, value, key_cache, value_cache, slot_mapping
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return lat
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
rows = []
|
||||||
|
for exp in range(1, 17):
|
||||||
|
n_tok = 2**exp
|
||||||
|
lat = run_benchmark(
|
||||||
|
num_tokens=n_tok,
|
||||||
|
num_heads=args.num_heads,
|
||||||
|
head_size=args.head_size,
|
||||||
|
block_size=args.block_size,
|
||||||
|
num_blocks=args.num_blocks,
|
||||||
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
|
kv_cache_dtype=args.kv_cache_dtype,
|
||||||
|
num_iters=args.iters,
|
||||||
|
benchmark_mode=args.mode,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
rows.append([n_tok, lat * 1e6]) # convert to microseconds
|
||||||
|
|
||||||
|
print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
|
||||||
|
print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--num-heads", type=int, default=128)
|
||||||
|
parser.add_argument(
|
||||||
|
"--head-size",
|
||||||
|
type=int,
|
||||||
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||||
|
default=128,
|
||||||
|
)
|
||||||
|
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||||
|
parser.add_argument("--num-blocks", type=int, default=128 * 128)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="bfloat16",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["auto", "fp8"],
|
||||||
|
default="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--iters", type=int, default=200)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
choices=["cudagraph", "no_graph"],
|
||||||
|
default="cudagraph",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
210
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
Normal file
210
benchmarks/kernels/benchmark_reshape_and_cache_flash.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||||
|
triton_reshape_and_cache_flash,
|
||||||
|
)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
from vllm.utils.torch_utils import (
|
||||||
|
STR_DTYPE_TO_TORCH_DTYPE,
|
||||||
|
create_kv_caches_with_random_flash,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def run_benchmark(
|
||||||
|
num_tokens: int,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
block_size: int,
|
||||||
|
num_blocks: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
kv_cache_layout: str,
|
||||||
|
num_iters: int,
|
||||||
|
implementation: str,
|
||||||
|
benchmark_mode: str,
|
||||||
|
device: str = "cuda",
|
||||||
|
) -> float:
|
||||||
|
"""Return latency (seconds) for given num_tokens."""
|
||||||
|
|
||||||
|
if kv_cache_dtype == "fp8" and head_size % 16:
|
||||||
|
raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")
|
||||||
|
|
||||||
|
if implementation not in ("cuda", "triton"):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported implementation: {implementation}. "
|
||||||
|
"Only 'cuda' and 'triton' are supported."
|
||||||
|
)
|
||||||
|
if implementation == "triton" and kv_cache_layout == "HND":
|
||||||
|
return float("nan") # Triton does not support HND layout yet.
|
||||||
|
|
||||||
|
current_platform.seed_everything(42)
|
||||||
|
torch.set_default_device(device)
|
||||||
|
|
||||||
|
# create random key / value tensors [T, H, D].
|
||||||
|
key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
|
||||||
|
value = torch.randn_like(key)
|
||||||
|
|
||||||
|
# prepare the slot mapping.
|
||||||
|
# each token is assigned a unique slot in the KV-cache.
|
||||||
|
num_slots = block_size * num_blocks
|
||||||
|
if num_tokens > num_slots:
|
||||||
|
raise ValueError("num_tokens cannot exceed the total number of cache slots")
|
||||||
|
slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||||
|
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
key_caches, value_caches = create_kv_caches_with_random_flash(
|
||||||
|
num_blocks,
|
||||||
|
block_size,
|
||||||
|
1, # num_layers
|
||||||
|
num_heads,
|
||||||
|
head_size,
|
||||||
|
kv_cache_dtype,
|
||||||
|
dtype,
|
||||||
|
device=device,
|
||||||
|
cache_layout=kv_cache_layout,
|
||||||
|
)
|
||||||
|
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||||
|
# to free unused memory
|
||||||
|
del key_caches, value_caches
|
||||||
|
|
||||||
|
# compute per-kernel scaling factors for fp8 conversion (if used).
|
||||||
|
k_scale = (key.amax() / 64.0).to(torch.float32)
|
||||||
|
v_scale = (value.amax() / 64.0).to(torch.float32)
|
||||||
|
|
||||||
|
if implementation == "cuda":
|
||||||
|
function_under_test = lambda: ops.reshape_and_cache_flash(
|
||||||
|
key, # noqa: F821
|
||||||
|
value, # noqa: F821
|
||||||
|
key_cache, # noqa: F821
|
||||||
|
value_cache, # noqa: F821
|
||||||
|
slot_mapping, # noqa: F821
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
function_under_test = lambda: triton_reshape_and_cache_flash(
|
||||||
|
key, # noqa: F821
|
||||||
|
value, # noqa: F821
|
||||||
|
key_cache, # noqa: F821
|
||||||
|
value_cache, # noqa: F821
|
||||||
|
slot_mapping, # noqa: F821
|
||||||
|
kv_cache_dtype,
|
||||||
|
k_scale,
|
||||||
|
v_scale,
|
||||||
|
)
|
||||||
|
if benchmark_mode == "cudagraph":
|
||||||
|
g = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(g):
|
||||||
|
function_under_test()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
function_under_test = lambda: g.replay()
|
||||||
|
|
||||||
|
def run_cuda_benchmark(n_iters: int) -> float:
|
||||||
|
nonlocal key, value, key_cache, value_cache, slot_mapping
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start = time.perf_counter()
|
||||||
|
for _ in range(n_iters):
|
||||||
|
function_under_test()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.perf_counter()
|
||||||
|
return (end - start) / n_iters
|
||||||
|
|
||||||
|
# warm-up
|
||||||
|
run_cuda_benchmark(3)
|
||||||
|
|
||||||
|
lat = run_cuda_benchmark(num_iters)
|
||||||
|
|
||||||
|
# free tensors to mitigate OOM when sweeping
|
||||||
|
del key, value, key_cache, value_cache, slot_mapping
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
return lat
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
rows = []
|
||||||
|
for layout in ["NHD", "HND"]:
|
||||||
|
for exp in range(1, 17):
|
||||||
|
n_tok = 2**exp
|
||||||
|
lat = run_benchmark(
|
||||||
|
num_tokens=n_tok,
|
||||||
|
num_heads=args.num_heads,
|
||||||
|
head_size=args.head_size,
|
||||||
|
block_size=args.block_size,
|
||||||
|
num_blocks=args.num_blocks,
|
||||||
|
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
|
||||||
|
kv_cache_dtype=args.kv_cache_dtype,
|
||||||
|
kv_cache_layout=layout,
|
||||||
|
num_iters=args.iters,
|
||||||
|
implementation=args.implementation,
|
||||||
|
benchmark_mode=args.mode,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Benchmark results for implementation {args.implementation}"
|
||||||
|
f" (measuring with {args.mode}):"
|
||||||
|
)
|
||||||
|
print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument("--num-heads", type=int, default=128)
|
||||||
|
parser.add_argument(
|
||||||
|
"--head-size",
|
||||||
|
type=int,
|
||||||
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||||
|
default=128,
|
||||||
|
)
|
||||||
|
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||||
|
parser.add_argument("--num-blocks", type=int, default=128 * 512)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["half", "bfloat16", "float"],
|
||||||
|
default="bfloat16",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--kv-cache-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["auto", "fp8"],
|
||||||
|
default="auto",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--iters", type=int, default=100)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--implementation",
|
||||||
|
type=str,
|
||||||
|
choices=["cuda", "triton"],
|
||||||
|
default="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--mode",
|
||||||
|
type=str,
|
||||||
|
choices=["cudagraph", "no_graph"],
|
||||||
|
default="cudagraph",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
255
benchmarks/kernels/benchmark_rmsnorm.py
Normal file
255
benchmarks/kernels/benchmark_rmsnorm.py
Normal file
@@ -0,0 +1,255 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm import _custom_ops as vllm_ops
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceRMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
orig_dtype = x.dtype
|
||||||
|
x = x.to(torch.float32)
|
||||||
|
if residual is not None:
|
||||||
|
x = x + residual.to(torch.float32)
|
||||||
|
residual = x.to(orig_dtype)
|
||||||
|
|
||||||
|
variance = x.pow(2).mean(dim=-1, keepdim=True)
|
||||||
|
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
x = x.to(orig_dtype) * self.weight
|
||||||
|
if residual is None:
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_naive(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
|
||||||
|
naive_norm.weight = nn.Parameter(weight)
|
||||||
|
naive_norm = naive_norm.to(x.device)
|
||||||
|
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
output = naive_norm(x, residual)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_flashinfer(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
fused_add_rmsnorm(x, residual, weight, eps)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
output = rmsnorm(x, weight, eps)
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def rmsnorm_vllm(
|
||||||
|
x: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None = None,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
):
|
||||||
|
orig_shape = x.shape
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
if residual is not None:
|
||||||
|
residual = residual.view(-1, residual.shape[-1])
|
||||||
|
|
||||||
|
if residual is not None:
|
||||||
|
vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
|
||||||
|
output = (x, residual)
|
||||||
|
else:
|
||||||
|
out = torch.empty_like(x)
|
||||||
|
vllm_ops.rms_norm(out, x, weight, eps)
|
||||||
|
output = out
|
||||||
|
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
output = (output[0].view(orig_shape), output[1].view(orig_shape))
|
||||||
|
else:
|
||||||
|
output = output.view(orig_shape)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
|
output_naive = rmsnorm_naive(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_flashinfer = rmsnorm_flashinfer(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
output_vllm = rmsnorm_vllm(
|
||||||
|
x.clone(), weight, residual.clone() if residual is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if use_residual:
|
||||||
|
output_naive = output_naive[0]
|
||||||
|
output_flashinfer = output_flashinfer[0]
|
||||||
|
output_vllm = output_vllm[0]
|
||||||
|
|
||||||
|
print(f"Naive output={output_naive}")
|
||||||
|
print(f"FlashInfer output={output_flashinfer}")
|
||||||
|
print(f"vLLM output={output_vllm}")
|
||||||
|
|
||||||
|
if torch.allclose(
|
||||||
|
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
|
||||||
|
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
|
||||||
|
print("✅ All implementations match")
|
||||||
|
else:
|
||||||
|
print("❌ Implementations differ")
|
||||||
|
|
||||||
|
|
||||||
|
batch_size_range = [2**i for i in range(0, 7, 2)]
|
||||||
|
seq_length_range = [2**i for i in range(6, 11, 1)]
|
||||||
|
head_num_range = [32, 48]
|
||||||
|
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
|
||||||
|
|
||||||
|
|
||||||
|
def get_benchmark(use_residual):
|
||||||
|
@triton.testing.perf_report(
|
||||||
|
triton.testing.Benchmark(
|
||||||
|
x_names=["head_num", "batch_size", "seq_len"],
|
||||||
|
x_vals=[list(_) for _ in configs],
|
||||||
|
line_arg="provider",
|
||||||
|
line_vals=["huggingface", "flashinfer", "vllm"],
|
||||||
|
line_names=["HuggingFace", "FlashInfer", "vLLM"],
|
||||||
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
|
ylabel="us",
|
||||||
|
plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
|
||||||
|
args={},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def benchmark(head_num, batch_size, seq_len, provider):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
hidden_size = head_num * 128 # assuming head_dim = 128
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
|
||||||
|
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
|
||||||
|
residual = torch.randn_like(x) if use_residual else None
|
||||||
|
|
||||||
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
|
|
||||||
|
if provider == "huggingface":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rmsnorm_naive(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
elif provider == "flashinfer":
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rmsnorm_flashinfer(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rmsnorm_vllm(
|
||||||
|
x.clone(),
|
||||||
|
weight,
|
||||||
|
residual.clone() if residual is not None else None,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Batch size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seq-len",
|
||||||
|
type=int,
|
||||||
|
default=128,
|
||||||
|
help="Sequence length",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hidden-size",
|
||||||
|
type=int,
|
||||||
|
default=4096,
|
||||||
|
help="Hidden size (2nd dimension) of the sequence",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-residual", action="store_true", help="Whether to use residual connection"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--save-path",
|
||||||
|
type=str,
|
||||||
|
default="./configs/rmsnorm/",
|
||||||
|
help="Path to save rmsnorm benchmark results",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Run correctness test
|
||||||
|
calculate_diff(
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
seq_len=args.seq_len,
|
||||||
|
hidden_size=args.hidden_size,
|
||||||
|
use_residual=args.use_residual,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get the benchmark function with proper use_residual setting
|
||||||
|
benchmark = get_benchmark(args.use_residual)
|
||||||
|
# Run performance benchmark
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
@@ -1,121 +1,106 @@
|
|||||||
import argparse
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
from itertools import accumulate
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import Optional
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
import nvtx
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
batch_size_range = [2**i for i in range(0, 8, 2)]
|
||||||
|
seq_len_range = [2**i for i in range(6, 10, 1)]
|
||||||
|
num_heads_range = [32, 48]
|
||||||
|
configs = list(itertools.product(batch_size_range, seq_len_range, num_heads_range))
|
||||||
|
|
||||||
|
|
||||||
def benchmark_rope_kernels_multi_lora(
|
def get_benchmark(head_size, rotary_dim, is_neox_style, device):
|
||||||
is_neox_style: bool,
|
@triton.testing.perf_report(
|
||||||
batch_size: int,
|
triton.testing.Benchmark(
|
||||||
seq_len: int,
|
x_names=["batch_size", "seq_len", "num_heads"],
|
||||||
num_heads: int,
|
x_vals=[list(_) for _ in configs],
|
||||||
head_size: int,
|
line_arg="provider",
|
||||||
rotary_dim: Optional[int],
|
line_vals=["torch", "flashinfer", "vllm"],
|
||||||
dtype: torch.dtype,
|
line_names=["PyTorch", "FlashInfer", "vLLM"],
|
||||||
seed: int,
|
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
|
||||||
device: str,
|
ylabel="us",
|
||||||
max_position: int = 8192,
|
plot_name=f"rope-perf{'-neox-style' if is_neox_style else ''}",
|
||||||
base: int = 10000,
|
args={},
|
||||||
) -> None:
|
)
|
||||||
torch.random.manual_seed(seed)
|
)
|
||||||
if torch.cuda.is_available():
|
def benchmark(batch_size, seq_len, num_heads, provider):
|
||||||
torch.cuda.manual_seed(seed)
|
dtype = torch.bfloat16
|
||||||
torch.set_default_device(device)
|
max_position = 8192
|
||||||
if rotary_dim is None:
|
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
|
||||||
rotary_dim = head_size
|
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||||
# silulating serving 4 LoRAs
|
rope = rope.to(dtype=dtype, device=device)
|
||||||
scaling_factors = [1, 2, 4, 8]
|
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)
|
||||||
# batched RoPE can take multiple scaling factors
|
|
||||||
batched_rope = get_rope(head_size, rotary_dim, max_position, base,
|
|
||||||
is_neox_style, {
|
|
||||||
"type": "linear",
|
|
||||||
"factor": tuple(scaling_factors)
|
|
||||||
})
|
|
||||||
# non-batched RoPE takes only one scaling factor, we create multiple
|
|
||||||
# instances to simulate the same behavior
|
|
||||||
non_batched_ropes = []
|
|
||||||
for scaling_factor in scaling_factors:
|
|
||||||
non_batched_ropes.append(
|
|
||||||
get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
|
|
||||||
{
|
|
||||||
"type": "linear",
|
|
||||||
"factor": (scaling_factor, )
|
|
||||||
}))
|
|
||||||
|
|
||||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
positions = torch.randint(0, max_position, (batch_size, seq_len), device=device)
|
||||||
query = torch.randn(batch_size,
|
query = torch.randn(
|
||||||
seq_len,
|
(batch_size, seq_len, num_heads * head_size), dtype=dtype, device=device
|
||||||
num_heads * head_size,
|
)
|
||||||
dtype=dtype)
|
key = torch.randn_like(query)
|
||||||
key = torch.randn_like(query)
|
|
||||||
|
|
||||||
# create query offsets for batched RoPE, we concat multiple kv cache
|
quantiles = [0.5, 0.2, 0.8]
|
||||||
# together and each query needs to find the right kv cache of its type
|
|
||||||
offset_map = torch.tensor(
|
|
||||||
list(
|
|
||||||
accumulate([0] + [
|
|
||||||
max_position * scaling_factor * 2
|
|
||||||
for scaling_factor in scaling_factors[:-1]
|
|
||||||
])))
|
|
||||||
query_types = torch.randint(0,
|
|
||||||
len(scaling_factors), (batch_size, seq_len),
|
|
||||||
device=device)
|
|
||||||
# map query types to offsets
|
|
||||||
query_offsets = offset_map[query_types]
|
|
||||||
# the kernel takes flattened offsets
|
|
||||||
flatten_offsets = query_offsets.flatten()
|
|
||||||
|
|
||||||
# batched queries of the same type together for non-batched RoPE
|
if provider == "torch":
|
||||||
queries = [query[query_types == i] for i in range(len(scaling_factors))]
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
keys = [key[query_types == i] for i in range(len(scaling_factors))]
|
lambda: rope.forward_native(positions, query.clone(), key.clone()),
|
||||||
packed_qkr = zip(queries, keys, non_batched_ropes)
|
quantiles=quantiles,
|
||||||
# synchronize before start timing
|
)
|
||||||
torch.cuda.synchronize()
|
elif provider == "flashinfer":
|
||||||
with nvtx.annotate("non-batched", color="yellow"):
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
for q, k, r in packed_qkr:
|
lambda: torch.ops.vllm.flashinfer_rotary_embedding(
|
||||||
r.forward(positions, q, k)
|
positions,
|
||||||
torch.cuda.synchronize()
|
query.clone(),
|
||||||
with nvtx.annotate("batched", color="green"):
|
key.clone(),
|
||||||
batched_rope.forward(positions, query, key, flatten_offsets)
|
head_size,
|
||||||
torch.cuda.synchronize()
|
cos_sin_cache,
|
||||||
|
is_neox_style,
|
||||||
|
),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
ms, min_ms, max_ms = triton.testing.do_bench(
|
||||||
|
lambda: rope.forward_cuda(positions, query.clone(), key.clone()),
|
||||||
|
quantiles=quantiles,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||||
|
|
||||||
|
return benchmark
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = FlexibleArgumentParser(
|
||||||
description="Benchmark the rotary embedding kernels.")
|
description="Benchmark the rotary embedding kernels."
|
||||||
|
)
|
||||||
parser.add_argument("--is-neox-style", type=bool, default=True)
|
parser.add_argument("--is-neox-style", type=bool, default=True)
|
||||||
parser.add_argument("--batch-size", type=int, default=16)
|
parser.add_argument("--batch-size", type=int, default=16)
|
||||||
parser.add_argument("--seq-len", type=int, default=512)
|
parser.add_argument("--seq-len", type=int, default=512)
|
||||||
parser.add_argument("--num-heads", type=int, default=8)
|
parser.add_argument("--num-heads", type=int, default=8)
|
||||||
parser.add_argument("--head-size",
|
parser.add_argument(
|
||||||
type=int,
|
"--head-size",
|
||||||
choices=[64, 80, 96, 112, 128, 256],
|
type=int,
|
||||||
default=128)
|
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||||
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
default=128,
|
||||||
parser.add_argument("--dtype",
|
|
||||||
type=str,
|
|
||||||
choices=["bfloat16", "float"],
|
|
||||||
default="float")
|
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
|
||||||
parser.add_argument("--device",
|
|
||||||
type=str,
|
|
||||||
choices=["cuda:0", "cuda:1"],
|
|
||||||
default="cuda:0")
|
|
||||||
args = parser.parse_args()
|
|
||||||
print(args)
|
|
||||||
|
|
||||||
benchmark_rope_kernels_multi_lora(
|
|
||||||
is_neox_style=args.is_neox_style,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
seq_len=args.seq_len,
|
|
||||||
num_heads=args.num_heads,
|
|
||||||
head_size=args.head_size,
|
|
||||||
rotary_dim=args.rotary_dim,
|
|
||||||
dtype=getattr(torch, args.dtype),
|
|
||||||
seed=args.seed,
|
|
||||||
device=args.device,
|
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dtype", type=str, choices=["bfloat16", "float"], default="float"
|
||||||
|
)
|
||||||
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
|
||||||
|
)
|
||||||
|
parser.add_argument("--save-path", type=str, default="./configs/rope/")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Get the benchmark function
|
||||||
|
benchmark = get_benchmark(
|
||||||
|
args.head_size, args.rotary_dim, args.is_neox_style, args.device
|
||||||
|
)
|
||||||
|
# Run performance benchmark
|
||||||
|
benchmark.run(print_data=True, save_path=args.save_path)
|
||||||
|
|||||||
94
benchmarks/kernels/benchmark_shapes.py
Normal file
94
benchmarks/kernels/benchmark_shapes.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
WEIGHT_SHAPES = {
|
||||||
|
"ideal": [[4 * 256 * 32, 256 * 32]],
|
||||||
|
"mistralai/Mistral-7B-v0.1/TP1": [
|
||||||
|
[4096, 6144],
|
||||||
|
[4096, 4096],
|
||||||
|
[4096, 28672],
|
||||||
|
[14336, 4096],
|
||||||
|
],
|
||||||
|
"mistralai/Mistral-7B-v0.1/TP2": [
|
||||||
|
[4096, 3072],
|
||||||
|
[2048, 4096],
|
||||||
|
[4096, 14336],
|
||||||
|
[7168, 4096],
|
||||||
|
],
|
||||||
|
"mistralai/Mistral-7B-v0.1/TP4": [
|
||||||
|
[4096, 1536],
|
||||||
|
[1024, 4096],
|
||||||
|
[4096, 7168],
|
||||||
|
[3584, 4096],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-7b-hf/TP1": [
|
||||||
|
[4096, 12288],
|
||||||
|
[4096, 4096],
|
||||||
|
[4096, 22016],
|
||||||
|
[11008, 4096],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-7b-hf/TP2": [
|
||||||
|
[4096, 6144],
|
||||||
|
[2048, 4096],
|
||||||
|
[4096, 11008],
|
||||||
|
[5504, 4096],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-7b-hf/TP4": [
|
||||||
|
[4096, 3072],
|
||||||
|
[1024, 4096],
|
||||||
|
[4096, 5504],
|
||||||
|
[2752, 4096],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-13b-hf/TP1": [
|
||||||
|
[5120, 15360],
|
||||||
|
[5120, 5120],
|
||||||
|
[5120, 27648],
|
||||||
|
[13824, 5120],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-13b-hf/TP2": [
|
||||||
|
[5120, 7680],
|
||||||
|
[2560, 5120],
|
||||||
|
[5120, 13824],
|
||||||
|
[6912, 5120],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-13b-hf/TP4": [
|
||||||
|
[5120, 3840],
|
||||||
|
[1280, 5120],
|
||||||
|
[5120, 6912],
|
||||||
|
[3456, 5120],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-70b-hf/TP1": [
|
||||||
|
[8192, 10240],
|
||||||
|
[8192, 8192],
|
||||||
|
[8192, 57344],
|
||||||
|
[28672, 8192],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-70b-hf/TP2": [
|
||||||
|
[8192, 5120],
|
||||||
|
[4096, 8192],
|
||||||
|
[8192, 28672],
|
||||||
|
[14336, 8192],
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-70b-hf/TP4": [
|
||||||
|
[8192, 2560],
|
||||||
|
[2048, 8192],
|
||||||
|
[8192, 14336],
|
||||||
|
[7168, 8192],
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
WEIGHT_SHAPES_MOE = {
|
||||||
|
"mistralai/Mixtral-8x7B-Instruct-v0.1": [
|
||||||
|
[8, 2, 4096, 28672],
|
||||||
|
[8, 2, 14336, 4096],
|
||||||
|
],
|
||||||
|
"deepseek-ai/DeepSeek-V2-Lite": [
|
||||||
|
[64, 6, 2048, 1408],
|
||||||
|
],
|
||||||
|
"ibm-granite/granite-3.0-1b-a400m": [
|
||||||
|
[32, 8, 1024, 1024],
|
||||||
|
],
|
||||||
|
"ibm-granite/granite-3.0-3b-a800m": [
|
||||||
|
[40, 8, 1024, 1536],
|
||||||
|
],
|
||||||
|
}
|
||||||
720
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
Normal file
720
benchmarks/kernels/benchmark_silu_mul_fp8_quant.py
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
"""
|
||||||
|
Comprehensive 3-way SiLU Benchmark Suite
|
||||||
|
|
||||||
|
This benchmark compares three SiLU implementations:
|
||||||
|
1. SiLU V2 (CUDA) - Optimized CUDA kernel implementation
|
||||||
|
2. Triton Kernel - Triton-based implementation
|
||||||
|
|
||||||
|
The suite generates detailed performance comparisons including:
|
||||||
|
- Memory bandwidth utilization
|
||||||
|
- Speedup ratios (baseline vs optimized implementations)
|
||||||
|
- Performance across different expert configurations and token distributions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||||
|
persistent_masked_m_silu_mul_quant,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import tl, triton
|
||||||
|
from vllm.utils.deep_gemm import is_deep_gemm_e8m0_used
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _silu_mul_fp8_quant_deep_gemm(
|
||||||
|
# Pointers ------------------------------------------------------------
|
||||||
|
input_ptr, # 16-bit activations (E, T, 2*H)
|
||||||
|
y_q_ptr, # fp8 quantized activations (E, T, H)
|
||||||
|
y_s_ptr, # 16-bit scales (E, T, G)
|
||||||
|
counts_ptr, # int32 num tokens per expert (E)
|
||||||
|
# Sizes ---------------------------------------------------------------
|
||||||
|
H: tl.constexpr, # hidden dimension (per output)
|
||||||
|
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
|
||||||
|
# Strides for input (elements) ---------------------------------------
|
||||||
|
stride_i_e,
|
||||||
|
stride_i_t,
|
||||||
|
stride_i_h,
|
||||||
|
# Strides for y_q (elements) -----------------------------------------
|
||||||
|
stride_yq_e,
|
||||||
|
stride_yq_t,
|
||||||
|
stride_yq_h,
|
||||||
|
# Strides for y_s (elements) -----------------------------------------
|
||||||
|
stride_ys_e,
|
||||||
|
stride_ys_t,
|
||||||
|
stride_ys_g,
|
||||||
|
# Stride for counts (elements)
|
||||||
|
stride_counts_e,
|
||||||
|
# Numeric params ------------------------------------------------------
|
||||||
|
eps: tl.constexpr,
|
||||||
|
fp8_min: tl.constexpr,
|
||||||
|
fp8_max: tl.constexpr,
|
||||||
|
use_ue8m0: tl.constexpr,
|
||||||
|
# Meta ---------------------------------------------------------------
|
||||||
|
BLOCK: tl.constexpr,
|
||||||
|
NUM_STAGES: tl.constexpr,
|
||||||
|
):
|
||||||
|
G = H // GROUP_SIZE
|
||||||
|
|
||||||
|
# map program id -> (e, g)
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
e = pid // G
|
||||||
|
g = pid % G
|
||||||
|
|
||||||
|
e = e.to(tl.int64)
|
||||||
|
g = g.to(tl.int64)
|
||||||
|
|
||||||
|
# number of valid tokens for this expert
|
||||||
|
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
|
||||||
|
|
||||||
|
cols = tl.arange(0, BLOCK).to(tl.int64)
|
||||||
|
mask = cols < BLOCK
|
||||||
|
|
||||||
|
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
|
||||||
|
base_gate_offset = base_input_offset + cols * stride_i_h
|
||||||
|
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
|
||||||
|
base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
|
||||||
|
base_ys_offset = e * stride_ys_e + g * stride_ys_g
|
||||||
|
|
||||||
|
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
|
||||||
|
gate = tl.load(
|
||||||
|
input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
|
||||||
|
).to(tl.float32)
|
||||||
|
up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)
|
||||||
|
|
||||||
|
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
|
||||||
|
y = gate * up
|
||||||
|
|
||||||
|
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
|
||||||
|
if use_ue8m0:
|
||||||
|
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
|
||||||
|
|
||||||
|
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||||
|
|
||||||
|
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
|
||||||
|
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
||||||
|
|
||||||
|
|
||||||
|
def silu_mul_fp8_quant_deep_gemm_triton(
|
||||||
|
y: torch.Tensor, # (E, T, 2*H)
|
||||||
|
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
||||||
|
num_parallel_tokens,
|
||||||
|
group_size: int = 128,
|
||||||
|
eps: float = 1e-10,
|
||||||
|
expert_offsets: torch.Tensor = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||||
|
|
||||||
|
y has shape (E, T, 2*H). The first half of the last dimension is
|
||||||
|
silu-activated, multiplied by the second half, then quantized into FP8.
|
||||||
|
|
||||||
|
Returns `(y_q, y_s)` where
|
||||||
|
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
||||||
|
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
||||||
|
"""
|
||||||
|
assert y.ndim == 3, "y must be (E, T, 2*H)"
|
||||||
|
E, T, H2 = y.shape
|
||||||
|
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
|
||||||
|
H = H2 // 2
|
||||||
|
G = (H + group_size - 1) // group_size
|
||||||
|
assert H % group_size == 0, "H must be divisible by group_size"
|
||||||
|
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E, (
|
||||||
|
"tokens_per_expert must be shape (E,)"
|
||||||
|
)
|
||||||
|
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
|
||||||
|
|
||||||
|
# allocate outputs
|
||||||
|
fp8_dtype = torch.float8_e4m3fn
|
||||||
|
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||||
|
|
||||||
|
# strides (elements)
|
||||||
|
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||||
|
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||||
|
|
||||||
|
# desired scale strides (elements): (T*G, 1, T)
|
||||||
|
stride_ys_e = T * G
|
||||||
|
stride_ys_t = 1
|
||||||
|
stride_ys_g = T
|
||||||
|
y_s = torch.empty_strided(
|
||||||
|
(E, T, G),
|
||||||
|
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=y.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||||
|
|
||||||
|
# Static grid over experts and H-groups.
|
||||||
|
# A loop inside the kernel handles the token dim
|
||||||
|
grid = (E * G,)
|
||||||
|
|
||||||
|
f_info = torch.finfo(fp8_dtype)
|
||||||
|
fp8_max = f_info.max
|
||||||
|
fp8_min = f_info.min
|
||||||
|
|
||||||
|
_silu_mul_fp8_quant_deep_gemm[grid](
|
||||||
|
y,
|
||||||
|
y_q,
|
||||||
|
y_s,
|
||||||
|
tokens_per_expert,
|
||||||
|
H,
|
||||||
|
group_size,
|
||||||
|
stride_i_e,
|
||||||
|
stride_i_t,
|
||||||
|
stride_i_h,
|
||||||
|
stride_yq_e,
|
||||||
|
stride_yq_t,
|
||||||
|
stride_yq_h,
|
||||||
|
stride_ys_e,
|
||||||
|
stride_ys_t,
|
||||||
|
stride_ys_g,
|
||||||
|
stride_cnt_e,
|
||||||
|
eps,
|
||||||
|
fp8_min,
|
||||||
|
fp8_max,
|
||||||
|
is_deep_gemm_e8m0_used(),
|
||||||
|
BLOCK=group_size,
|
||||||
|
NUM_STAGES=4,
|
||||||
|
num_warps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
return y_q, y_s
|
||||||
|
|
||||||
|
|
||||||
|
# Parse generation strategies
|
||||||
|
strategies = ["random_imbalanced", "uniform", "max_t"]
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark(
|
||||||
|
kernel: Callable,
|
||||||
|
E: int,
|
||||||
|
T: int,
|
||||||
|
H: int,
|
||||||
|
total_tokens: int,
|
||||||
|
num_parallel_tokens: int = 64,
|
||||||
|
G: int = 128,
|
||||||
|
runs: int = 200,
|
||||||
|
num_warmups: int = 20,
|
||||||
|
gen_strategy: str = "default",
|
||||||
|
iterations_per_run: int = 20,
|
||||||
|
):
|
||||||
|
def generate_data(seed_offset=0):
|
||||||
|
"""Generate input data with given seed offset"""
|
||||||
|
current_platform.seed_everything(42 + seed_offset)
|
||||||
|
y = torch.rand((E, T, 2 * H), dtype=torch.bfloat16, device="cuda").contiguous()
|
||||||
|
|
||||||
|
if gen_strategy == "random_imbalanced":
|
||||||
|
|
||||||
|
def generate_expert_loads(n_e, total_tokens, ratio, device="cuda"):
|
||||||
|
mean = total_tokens // n_e
|
||||||
|
min_max = mean // ratio
|
||||||
|
e = torch.ones(size=(E,), dtype=torch.int64, device=device) * mean
|
||||||
|
e[0] = min_max
|
||||||
|
r = torch.rand(size=(E - 1,))
|
||||||
|
r /= r.sum()
|
||||||
|
r *= total_tokens - min_max
|
||||||
|
r = r.round().long()
|
||||||
|
e[1:] = r.to(device=device)
|
||||||
|
return e
|
||||||
|
|
||||||
|
tokens_per_expert = generate_expert_loads(E, total_tokens, 0.7, "cuda")
|
||||||
|
elif gen_strategy == "uniform":
|
||||||
|
r = torch.rand(size=(E,))
|
||||||
|
r /= r.sum()
|
||||||
|
r *= total_tokens
|
||||||
|
r = r.round().long()
|
||||||
|
tokens_per_expert = r
|
||||||
|
elif gen_strategy == "max_t":
|
||||||
|
tokens_per_expert = torch.empty(size=(E,), dtype=torch.int32, device="cuda")
|
||||||
|
tokens_per_expert.fill_(total_tokens / E)
|
||||||
|
elif gen_strategy == "first_t":
|
||||||
|
tokens_per_expert = torch.zeros(size=(E,), dtype=torch.int32, device="cuda")
|
||||||
|
tokens_per_expert[0] = min(T, total_tokens)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown generation strategy: {gen_strategy}")
|
||||||
|
return y, tokens_per_expert
|
||||||
|
|
||||||
|
dataset_count = 4
|
||||||
|
# Pre-generate different input matrices for each iteration to avoid cache effects
|
||||||
|
data_sets = [generate_data(i) for i in range(dataset_count)]
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
y, tokens_per_expert = data_sets[0]
|
||||||
|
for _ in range(num_warmups):
|
||||||
|
kernel(
|
||||||
|
y, tokens_per_expert, num_parallel_tokens=num_parallel_tokens, group_size=G
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.Event(enable_timing=True)
|
||||||
|
end_event = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
# Benchmark
|
||||||
|
latencies: list[float] = []
|
||||||
|
for _ in range(runs):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event.record()
|
||||||
|
for i in range(iterations_per_run):
|
||||||
|
y, tokens_per_expert = data_sets[i % dataset_count]
|
||||||
|
kernel(
|
||||||
|
y,
|
||||||
|
tokens_per_expert,
|
||||||
|
num_parallel_tokens=num_parallel_tokens,
|
||||||
|
group_size=G,
|
||||||
|
)
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
|
||||||
|
total_time_ms = start_event.elapsed_time(end_event)
|
||||||
|
per_iter_time_ms = total_time_ms / iterations_per_run
|
||||||
|
latencies.append(per_iter_time_ms)
|
||||||
|
|
||||||
|
# Use median instead of average for better outlier handling
|
||||||
|
median_time_ms = np.median(latencies)
|
||||||
|
median_time_s = median_time_ms / 1000
|
||||||
|
|
||||||
|
# Calculate actual work done (using first dataset for consistency)
|
||||||
|
_, tokens_per_expert = data_sets[0]
|
||||||
|
actual_tokens = tokens_per_expert.sum().item()
|
||||||
|
actual_elements = actual_tokens * H
|
||||||
|
|
||||||
|
# GFLOPS: operations per element = exp + 3 muls + 1 div + quantization ops ≈ 8 ops
|
||||||
|
ops_per_element = 8
|
||||||
|
total_ops = actual_elements * ops_per_element
|
||||||
|
gflops = total_ops / median_time_s / 1e9
|
||||||
|
|
||||||
|
# Memory bandwidth: bfloat16 inputs (2 bytes), fp8 output (1 byte), scales (4 bytes)
|
||||||
|
input_bytes = actual_tokens * 2 * H * 2 # 2*H bfloat16 inputs
|
||||||
|
output_bytes = actual_tokens * H * 1 # H fp8 outputs
|
||||||
|
scale_bytes = actual_tokens * (H // G) * 4 # scales in float32
|
||||||
|
total_bytes = input_bytes + output_bytes + scale_bytes
|
||||||
|
memory_bw = total_bytes / median_time_s / 1e9
|
||||||
|
|
||||||
|
HOPPER_BANDWIDTH_TBPS = 3.35
|
||||||
|
return (
|
||||||
|
median_time_ms,
|
||||||
|
gflops,
|
||||||
|
memory_bw,
|
||||||
|
(memory_bw / (HOPPER_BANDWIDTH_TBPS * 1024)) * 100,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_comparison_plot(
|
||||||
|
ratios, silu_v2_times, triton_times, config_labels, strategy_name, id
|
||||||
|
):
|
||||||
|
fig, ax = plt.subplots(1, 1, figsize=(18, 6))
|
||||||
|
|
||||||
|
# Configure x-axis positions
|
||||||
|
x = np.arange(len(config_labels))
|
||||||
|
width = 0.25
|
||||||
|
|
||||||
|
# Execution Time plot (lower is better)
|
||||||
|
ax.bar(x, silu_v2_times, width, label="SiLU V2 (CUDA)", alpha=0.8, color="blue")
|
||||||
|
ax.bar(
|
||||||
|
x + width, triton_times, width, label="Triton Kernel", alpha=0.8, color="green"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add speedup labels over each bar trio
|
||||||
|
for i in range(len(x)):
|
||||||
|
triton_v2_speedup = ratios[i][1] # triton/v2
|
||||||
|
max_height = max(silu_v2_times[i], triton_times[i])
|
||||||
|
|
||||||
|
# Triton/V2 speedup
|
||||||
|
ax.text(
|
||||||
|
x[i] + width / 2,
|
||||||
|
max_height + max_height * 0.02,
|
||||||
|
f"{triton_v2_speedup:.2f}x",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontweight="bold",
|
||||||
|
fontsize=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlabel("Configuration")
|
||||||
|
ax.set_ylabel("% Utilization")
|
||||||
|
ax.set_title(
|
||||||
|
f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
|
||||||
|
)
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(config_labels, rotation=45, ha="right")
|
||||||
|
ax.legend()
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
return fig, ax
|
||||||
|
|
||||||
|
|
||||||
|
def create_combined_plot(all_results):
|
||||||
|
num_strategies = len(all_results)
|
||||||
|
fig, axes = plt.subplots(num_strategies, 1, figsize=(22, 7 * num_strategies))
|
||||||
|
|
||||||
|
if num_strategies == 1:
|
||||||
|
axes = [axes]
|
||||||
|
|
||||||
|
for idx, (
|
||||||
|
strategy_name,
|
||||||
|
all_ratios,
|
||||||
|
all_silu_v2_results,
|
||||||
|
all_triton_results,
|
||||||
|
config_labels,
|
||||||
|
config_x_axis,
|
||||||
|
) in enumerate(all_results):
|
||||||
|
ax = axes[idx]
|
||||||
|
|
||||||
|
# Flatten the nested results to get bandwidth percentages for plotting
|
||||||
|
silu_v2_bandwidths = []
|
||||||
|
triton_bandwidths = []
|
||||||
|
flat_ratios = []
|
||||||
|
|
||||||
|
for config_results in all_silu_v2_results:
|
||||||
|
for result in config_results:
|
||||||
|
silu_v2_bandwidths.append(result[3]) # bandwidth percentage
|
||||||
|
|
||||||
|
for config_results in all_triton_results:
|
||||||
|
for result in config_results:
|
||||||
|
triton_bandwidths.append(result[3]) # bandwidth percentage
|
||||||
|
|
||||||
|
for config_ratios in all_ratios:
|
||||||
|
for ratio in config_ratios:
|
||||||
|
flat_ratios.append(ratio)
|
||||||
|
|
||||||
|
# Configure x-axis positions
|
||||||
|
x = np.arange(len(config_labels))
|
||||||
|
width = 0.25
|
||||||
|
|
||||||
|
# Bandwidth utilization plot (higher is better)
|
||||||
|
ax.bar(
|
||||||
|
x,
|
||||||
|
silu_v2_bandwidths,
|
||||||
|
width,
|
||||||
|
label="SiLU V2 (CUDA)",
|
||||||
|
alpha=0.8,
|
||||||
|
color="blue",
|
||||||
|
)
|
||||||
|
ax.bar(
|
||||||
|
x + width,
|
||||||
|
triton_bandwidths,
|
||||||
|
width,
|
||||||
|
label="Triton Kernel",
|
||||||
|
alpha=0.8,
|
||||||
|
color="green",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add speedup labels over each bar trio
|
||||||
|
for i in range(len(x)):
|
||||||
|
triton_v2_speedup = flat_ratios[i] # triton/v2
|
||||||
|
max_height = max(silu_v2_bandwidths[i], triton_bandwidths[i])
|
||||||
|
|
||||||
|
# Triton/V2 speedup
|
||||||
|
ax.text(
|
||||||
|
x[i] + width / 2,
|
||||||
|
max_height + max_height * 0.02,
|
||||||
|
f"{triton_v2_speedup:.2f}x",
|
||||||
|
ha="center",
|
||||||
|
va="bottom",
|
||||||
|
fontweight="bold",
|
||||||
|
fontsize=8,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlabel("Configuration")
|
||||||
|
ax.set_ylabel("% Utilization")
|
||||||
|
ax.set_title(
|
||||||
|
f"Memory Bandwidth Utilization (%) - {strategy_name}\n(Higher is Better)"
|
||||||
|
)
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(config_labels, rotation=45, ha="right")
|
||||||
|
ax.legend()
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
filename = "silu_benchmark_combined_3way.png"
|
||||||
|
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
outer_dim = 7168
|
||||||
|
configs = [
|
||||||
|
# DeepSeekV3 Configs
|
||||||
|
# (1, 56, 7168),
|
||||||
|
(8, 1024, 7168),
|
||||||
|
# (32, 56, 7168),
|
||||||
|
# DeepSeekV3 Configs
|
||||||
|
(32, 1024, 7168),
|
||||||
|
# DeepSeekV3 Configs
|
||||||
|
(256, 1024, 7168),
|
||||||
|
]
|
||||||
|
|
||||||
|
runs = 100
|
||||||
|
num_warmups = 20
|
||||||
|
|
||||||
|
strategy_descriptions = {
|
||||||
|
"uniform": "Uniform Random",
|
||||||
|
"random_imbalanced": "Imbalanced Random",
|
||||||
|
"max_t": "Even Assignment",
|
||||||
|
"first_t": "experts[0] = T, experts[1:] = 0",
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name()}")
|
||||||
|
print(f"Testing strategies: {', '.join(strategies)}")
|
||||||
|
print(f"Configurations: {len(configs)} configs")
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
# Run benchmarks for each strategy
|
||||||
|
for id, strategy in enumerate(strategies):
|
||||||
|
print(f"\n{'=' * 60}")
|
||||||
|
print(f"Testing strategy: {strategy_descriptions[strategy]}")
|
||||||
|
print(f"{'=' * 60}")
|
||||||
|
|
||||||
|
# Collect benchmark data for all three algorithms
|
||||||
|
config_labels = []
|
||||||
|
config_x_axis = []
|
||||||
|
all_silu_v2_results = []
|
||||||
|
all_triton_results = []
|
||||||
|
all_ratios = []
|
||||||
|
|
||||||
|
for E, T, H in configs:
|
||||||
|
total_tokens_config = []
|
||||||
|
for i in [8, 16, 32, 64, 128, 256, 512]:
|
||||||
|
if i <= T:
|
||||||
|
total_tokens_config.append(i * E)
|
||||||
|
config_x_axis.append(total_tokens_config)
|
||||||
|
|
||||||
|
silu_v2_results = []
|
||||||
|
triton_results = []
|
||||||
|
ratios = []
|
||||||
|
|
||||||
|
for total_tokens in total_tokens_config:
|
||||||
|
config_label = f"E={E},T={T},H={H},TT={total_tokens}"
|
||||||
|
config_labels.append(config_label)
|
||||||
|
|
||||||
|
# SiLU V2 (CUDA kernel) results
|
||||||
|
time_ms_silu_v2, gflops, gbps, perc = benchmark(
|
||||||
|
persistent_masked_m_silu_mul_quant,
|
||||||
|
E,
|
||||||
|
T,
|
||||||
|
H,
|
||||||
|
total_tokens,
|
||||||
|
runs=runs,
|
||||||
|
num_warmups=num_warmups,
|
||||||
|
gen_strategy=strategy,
|
||||||
|
)
|
||||||
|
silu_v2_results.append((time_ms_silu_v2, gflops, gbps, perc))
|
||||||
|
|
||||||
|
# Triton kernel results
|
||||||
|
time_ms_triton, gflops, gbps, perc = benchmark(
|
||||||
|
silu_mul_fp8_quant_deep_gemm_triton,
|
||||||
|
E,
|
||||||
|
T,
|
||||||
|
H,
|
||||||
|
total_tokens,
|
||||||
|
runs=runs,
|
||||||
|
num_warmups=num_warmups,
|
||||||
|
gen_strategy=strategy,
|
||||||
|
)
|
||||||
|
triton_results.append((time_ms_triton, gflops, gbps, perc))
|
||||||
|
|
||||||
|
# Calculate speedup ratios (triton baseline / implementation)
|
||||||
|
triton_v2_ratio = time_ms_triton / time_ms_silu_v2
|
||||||
|
ratios.append(triton_v2_ratio)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Completed: {config_label}:"
|
||||||
|
f" V2: {time_ms_silu_v2:.3f}ms,"
|
||||||
|
f" Triton: {time_ms_triton:.3f}ms"
|
||||||
|
)
|
||||||
|
|
||||||
|
all_silu_v2_results.append(silu_v2_results)
|
||||||
|
all_triton_results.append(triton_results)
|
||||||
|
all_ratios.append(ratios)
|
||||||
|
|
||||||
|
# Store results for combined plotting
|
||||||
|
all_results.append(
|
||||||
|
(
|
||||||
|
strategy_descriptions[strategy],
|
||||||
|
all_ratios,
|
||||||
|
all_silu_v2_results,
|
||||||
|
all_triton_results,
|
||||||
|
config_labels,
|
||||||
|
config_x_axis,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Print summary table for this strategy
|
||||||
|
print(f"\nSummary Table - {strategy_descriptions[strategy]}:")
|
||||||
|
print(f" {'V2 Time(ms)':<12} {'Triton Time(ms)':<14} {'Triton/V2':<10}")
|
||||||
|
print("-" * 90)
|
||||||
|
|
||||||
|
for i, (E, T, H) in enumerate(configs):
|
||||||
|
# Get the first result for each config (simplifying for summary)
|
||||||
|
v2_time = silu_v2_results[i][0]
|
||||||
|
triton_time = triton_results[i][0]
|
||||||
|
triton_v2_speedup = triton_time / v2_time
|
||||||
|
config_label = f"E={E:3d},T={T:4d},H={H:4d}"
|
||||||
|
print(
|
||||||
|
f"{config_label:<20} {v2_time:8.5f} {triton_time:10.5f} "
|
||||||
|
f"{triton_v2_speedup:8.2f}x"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_total_tokens_plot(all_results):
|
||||||
|
num_strategies = len(all_results)
|
||||||
|
num_configs = len(configs)
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(
|
||||||
|
num_strategies, num_configs * 2, figsize=(32, 8 * num_strategies)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add main title to the entire figure
|
||||||
|
fig.suptitle(
|
||||||
|
"Performance Analysis: Speedup vs Bandwidth Utilization (SiLU V2, and Triton)",
|
||||||
|
fontsize=18,
|
||||||
|
fontweight="bold",
|
||||||
|
y=0.98,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle single strategy case
|
||||||
|
if num_strategies == 1:
|
||||||
|
axs = axs.reshape(1, -1)
|
||||||
|
|
||||||
|
# Handle single config case
|
||||||
|
if num_configs == 1:
|
||||||
|
axs = axs.reshape(-1, 2)
|
||||||
|
|
||||||
|
for strategy_idx, result in enumerate(all_results):
|
||||||
|
(
|
||||||
|
strategy_name,
|
||||||
|
all_ratios,
|
||||||
|
all_silu_v2_results,
|
||||||
|
all_triton_results,
|
||||||
|
config_labels,
|
||||||
|
config_x_axis,
|
||||||
|
) = result
|
||||||
|
|
||||||
|
for config_idx in range(num_configs):
|
||||||
|
# Speedup plot (left column)
|
||||||
|
ax_speedup = axs[strategy_idx, config_idx * 2]
|
||||||
|
# Bandwidth plot (right column)
|
||||||
|
ax_bandwidth = axs[strategy_idx, config_idx * 2 + 1]
|
||||||
|
|
||||||
|
E, T, H = configs[config_idx]
|
||||||
|
ratios = all_ratios[config_idx]
|
||||||
|
total_tokens_values = config_x_axis[config_idx]
|
||||||
|
|
||||||
|
# Extract speedup ratios
|
||||||
|
triton_v2_ratios = [ratio for ratio in ratios]
|
||||||
|
|
||||||
|
# Extract bandwidth percentages for all implementations
|
||||||
|
v2_bandwidth_percentages = [
|
||||||
|
result[3] for result in all_silu_v2_results[config_idx]
|
||||||
|
]
|
||||||
|
triton_bandwidth_percentages = [
|
||||||
|
result[3] for result in all_triton_results[config_idx]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Plot speedup ratios vs total tokens (left plot)
|
||||||
|
ax_speedup.plot(
|
||||||
|
total_tokens_values,
|
||||||
|
triton_v2_ratios,
|
||||||
|
"go-",
|
||||||
|
linewidth=3,
|
||||||
|
markersize=8,
|
||||||
|
label="Triton/V2 Speedup",
|
||||||
|
)
|
||||||
|
ax_speedup.set_title(
|
||||||
|
f"{strategy_name}\nSpeedup vs Baseline (Triton)\nE={E}, T={T}, H={H}",
|
||||||
|
fontsize=12,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_speedup.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||||
|
ax_speedup.set_ylabel("Speedup Ratio", fontweight="bold", fontsize=11)
|
||||||
|
ax_speedup.legend(prop={"weight": "bold"})
|
||||||
|
ax_speedup.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Plot bandwidth utilization (right plot)
|
||||||
|
ax_bandwidth.plot(
|
||||||
|
total_tokens_values,
|
||||||
|
v2_bandwidth_percentages,
|
||||||
|
"o-",
|
||||||
|
linewidth=3,
|
||||||
|
markersize=8,
|
||||||
|
label="SiLU V2",
|
||||||
|
color="blue",
|
||||||
|
)
|
||||||
|
ax_bandwidth.plot(
|
||||||
|
total_tokens_values,
|
||||||
|
triton_bandwidth_percentages,
|
||||||
|
"o-",
|
||||||
|
linewidth=3,
|
||||||
|
markersize=8,
|
||||||
|
label="Triton",
|
||||||
|
color="green",
|
||||||
|
)
|
||||||
|
ax_bandwidth.set_title(
|
||||||
|
f"{strategy_name}\nBandwidth Utilization (Hopper)\nE={E}, T={T}, H={H}",
|
||||||
|
fontsize=12,
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
ax_bandwidth.set_xlabel("Total Tokens", fontweight="bold", fontsize=11)
|
||||||
|
ax_bandwidth.set_ylabel(
|
||||||
|
"% of Peak Bandwidth", fontweight="bold", fontsize=11
|
||||||
|
)
|
||||||
|
ax_bandwidth.legend(prop={"weight": "bold"})
|
||||||
|
ax_bandwidth.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# Format x-axis labels for both plots
|
||||||
|
for ax in [ax_speedup, ax_bandwidth]:
|
||||||
|
ax.set_xticks(total_tokens_values)
|
||||||
|
ax.set_xticklabels(
|
||||||
|
[
|
||||||
|
f"{tt // 1000}K" if tt >= 1000 else str(tt)
|
||||||
|
for tt in total_tokens_values
|
||||||
|
],
|
||||||
|
fontweight="bold",
|
||||||
|
)
|
||||||
|
# Make tick labels bold
|
||||||
|
for label in ax.get_xticklabels() + ax.get_yticklabels():
|
||||||
|
label.set_fontweight("bold")
|
||||||
|
|
||||||
|
# Add value labels on Triton/V2 speedup points
|
||||||
|
for x, y in zip(total_tokens_values, triton_v2_ratios):
|
||||||
|
ax_speedup.annotate(
|
||||||
|
f"{y:.2f}x",
|
||||||
|
(x, y),
|
||||||
|
textcoords="offset points",
|
||||||
|
xytext=(0, -15),
|
||||||
|
ha="center",
|
||||||
|
fontsize=9,
|
||||||
|
fontweight="bold",
|
||||||
|
bbox=dict(boxstyle="round,pad=0.2", facecolor="green", alpha=0.3),
|
||||||
|
)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.subplots_adjust(top=0.93) # Make room for main title
|
||||||
|
filename = "silu_benchmark_total_tokens_3way.png"
|
||||||
|
plt.savefig(filename, dpi=300, bbox_inches="tight")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
|
||||||
|
# Create comprehensive 3-way comparison plots
|
||||||
|
combined_plot_filename = create_combined_plot(all_results)
|
||||||
|
total_tokens_plot_filename = create_total_tokens_plot(all_results)
|
||||||
|
|
||||||
|
print(f"\n{'=' * 80}")
|
||||||
|
print("3-Way Benchmark Suite Complete!")
|
||||||
|
print(f"Generated combined comparison plot: {combined_plot_filename}")
|
||||||
|
print(f"Generated total tokens analysis plot: {total_tokens_plot_filename}")
|
||||||
|
print("Compared: SiLU V2 (CUDA), and Triton implementations")
|
||||||
|
print(f"{'=' * 80}")
|
||||||
290
benchmarks/kernels/benchmark_trtllm_decode_attention.py
Normal file
290
benchmarks/kernels/benchmark_trtllm_decode_attention.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import flashinfer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils.math_utils import round_up
|
||||||
|
|
||||||
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
|
FP8_DTYPE = torch.float8_e4m3fn
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
min_val, max_val = x.aminmax()
|
||||||
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
|
scale = finfo.max / amax * 0.1
|
||||||
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def benchmark_decode(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
|
||||||
|
batch_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
num_heads: tuple[int, int] = (64, 8),
|
||||||
|
head_size: int = 128,
|
||||||
|
kv_layout: str = "HND",
|
||||||
|
block_size: int = 16,
|
||||||
|
warmup: int = 10,
|
||||||
|
trials: int = 20,
|
||||||
|
):
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||||
|
q_quant_dtype = q_quant_dtype or dtype
|
||||||
|
kv_quant_dtype = kv_quant_dtype or dtype
|
||||||
|
o_quant_dtype = o_quant_dtype or dtype
|
||||||
|
|
||||||
|
num_qo_heads, num_kv_heads = num_heads
|
||||||
|
assert num_qo_heads % num_kv_heads == 0
|
||||||
|
|
||||||
|
sm_scale = float(1.0 / (head_size**0.5))
|
||||||
|
|
||||||
|
# large number to reduce kv_cache reuse
|
||||||
|
NUM_BLOCKS = int(256000 / block_size)
|
||||||
|
|
||||||
|
kv_cache_shape = None
|
||||||
|
if kv_layout == "NHD":
|
||||||
|
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||||
|
elif kv_layout == "HND":
|
||||||
|
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||||
|
|
||||||
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
q_scale = 1.0
|
||||||
|
ref_query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
|
||||||
|
if q_quant_dtype == FP8_DTYPE:
|
||||||
|
query, _ = to_float8(ref_query)
|
||||||
|
else:
|
||||||
|
query = ref_query
|
||||||
|
|
||||||
|
kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
|
||||||
|
kv_lens[-1] = max_seq_len
|
||||||
|
|
||||||
|
seq_lens = kv_lens
|
||||||
|
max_seq_len = torch.max(seq_lens).item()
|
||||||
|
|
||||||
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
k_scale = v_scale = 1.0
|
||||||
|
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||||
|
if kv_quant_dtype == FP8_DTYPE:
|
||||||
|
kv_cache, _ = to_float8(ref_kv_cache)
|
||||||
|
else:
|
||||||
|
kv_cache = ref_kv_cache
|
||||||
|
|
||||||
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
|
block_tables = torch.randint(
|
||||||
|
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||||
|
)
|
||||||
|
kv_indptr = [0]
|
||||||
|
kv_indices = []
|
||||||
|
kv_last_page_lens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
assert seq_len > 0
|
||||||
|
num_blocks = (seq_len + block_size - 1) // block_size
|
||||||
|
kv_indices.extend(block_tables[i, :num_blocks])
|
||||||
|
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||||
|
kv_last_page_len = seq_len % block_size
|
||||||
|
if kv_last_page_len == 0:
|
||||||
|
kv_last_page_len = block_size
|
||||||
|
kv_last_page_lens.append(kv_last_page_len)
|
||||||
|
|
||||||
|
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||||
|
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||||
|
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||||
|
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
|
||||||
|
|
||||||
|
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer,
|
||||||
|
kv_layout,
|
||||||
|
use_tensor_cores=True,
|
||||||
|
)
|
||||||
|
wrapper.plan(
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_lens,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
block_size,
|
||||||
|
"NONE",
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
q_data_type=dtype,
|
||||||
|
kv_data_type=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def time_fn(fn, warmup=10, trials=20):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start = torch.Event(enable_timing=True)
|
||||||
|
end = torch.Event(enable_timing=True)
|
||||||
|
times = []
|
||||||
|
for i in range(warmup):
|
||||||
|
fn()
|
||||||
|
for i in range(trials):
|
||||||
|
start.record()
|
||||||
|
fn()
|
||||||
|
end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
times.append(start.elapsed_time(end)) # ms
|
||||||
|
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||||
|
|
||||||
|
o_scale = 1.0
|
||||||
|
o_sf_scale = None
|
||||||
|
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
|
if o_quant_dtype == FP4_DTYPE:
|
||||||
|
o_sf_scale = 500.0
|
||||||
|
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||||
|
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
round_up(query.shape[0], 128),
|
||||||
|
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||||
|
),
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||||
|
|
||||||
|
def baseline_decode():
|
||||||
|
return wrapper.run(
|
||||||
|
ref_query,
|
||||||
|
ref_kv_cache,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
out=output_baseline,
|
||||||
|
)
|
||||||
|
|
||||||
|
def trtllm_decode():
|
||||||
|
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||||
|
query=query,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
workspace_buffer=workspace_buffer,
|
||||||
|
block_tables=block_tables,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||||
|
bmm2_scale=v_scale / o_scale,
|
||||||
|
o_sf_scale=o_sf_scale,
|
||||||
|
out=output_trtllm,
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||||
|
trtllm_mean, trtllm_std = time_fn(trtllm_decode)
|
||||||
|
|
||||||
|
# Calculate percentage speedup (positive means TRT is faster)
|
||||||
|
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
|
||||||
|
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return results for CSV writing
|
||||||
|
return {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"trtllm_mean": trtllm_mean,
|
||||||
|
"trtllm_std": trtllm_std.item(),
|
||||||
|
"baseline_mean": baseline_mean,
|
||||||
|
"baseline_std": baseline_std.item(),
|
||||||
|
"speedup_percent": speedup_percent,
|
||||||
|
"q_dtype": str(q_quant_dtype),
|
||||||
|
"kv_cache_dtype": str(kv_quant_dtype),
|
||||||
|
"output_dtype": str(o_quant_dtype),
|
||||||
|
"block_size": block_size,
|
||||||
|
"num_kv_heads": num_kv_heads,
|
||||||
|
"head_size": head_size,
|
||||||
|
"max_seq_len": max_seq_len,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_csv(results, filename=None):
|
||||||
|
"""Write benchmark results to CSV file."""
|
||||||
|
if filename is None:
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||||
|
|
||||||
|
fieldnames = [
|
||||||
|
"batch_size",
|
||||||
|
"trtllm_mean",
|
||||||
|
"trtllm_std",
|
||||||
|
"baseline_mean",
|
||||||
|
"baseline_std",
|
||||||
|
"speedup_percent",
|
||||||
|
"q_dtype",
|
||||||
|
"kv_cache_dtype",
|
||||||
|
"output_dtype",
|
||||||
|
"block_size",
|
||||||
|
"num_kv_heads",
|
||||||
|
"head_size",
|
||||||
|
"max_seq_len",
|
||||||
|
]
|
||||||
|
|
||||||
|
file_exists = os.path.exists(filename)
|
||||||
|
|
||||||
|
with open(filename, "a", newline="") as csvfile:
|
||||||
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
|
|
||||||
|
if not file_exists:
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
writer.writerow(result)
|
||||||
|
|
||||||
|
print(f"Results written to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||||
|
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
quant_dtypes = [
|
||||||
|
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||||
|
(None, None, None),
|
||||||
|
(None, FP8_DTYPE, None),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, None),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||||
|
]
|
||||||
|
|
||||||
|
for quant_dtype in quant_dtypes:
|
||||||
|
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
|
||||||
|
q_quant_dtype = q_quant_dtype or dtype
|
||||||
|
kv_quant_dtype = kv_quant_dtype or dtype
|
||||||
|
o_quant_dtype = o_quant_dtype or dtype
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Running benchmark for q_dtype = {q_quant_dtype}, "
|
||||||
|
f"kv_cache_dtype: {kv_quant_dtype}, "
|
||||||
|
f"output_dtype: {o_quant_dtype}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
|
||||||
|
"baseline_std\tspeedup_percent"
|
||||||
|
)
|
||||||
|
for max_seq_len in max_seq_lens:
|
||||||
|
for bs in batch_sizes:
|
||||||
|
result = benchmark_decode(
|
||||||
|
dtype=dtype,
|
||||||
|
quant_dtypes=quant_dtype,
|
||||||
|
batch_size=bs,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
)
|
||||||
|
all_results.append(result)
|
||||||
|
|
||||||
|
# Write all results to CSV
|
||||||
|
write_results_to_csv(all_results)
|
||||||
305
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
Normal file
305
benchmarks/kernels/benchmark_trtllm_prefill_attention.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
import flashinfer
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils.math_utils import round_up
|
||||||
|
|
||||||
|
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||||
|
FP8_DTYPE = torch.float8_e4m3fn
|
||||||
|
FP4_DTYPE = torch.uint8
|
||||||
|
|
||||||
|
|
||||||
|
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||||
|
finfo = torch.finfo(dtype)
|
||||||
|
min_val, max_val = x.aminmax()
|
||||||
|
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||||
|
scale = finfo.max / amax * 0.1
|
||||||
|
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||||
|
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def benchmark_prefill(
|
||||||
|
dtype: torch.dtype,
|
||||||
|
quant_dtypes: tuple[torch.dtype | None, torch.dtype | None, torch.dtype | None],
|
||||||
|
batch_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
num_heads: tuple[int, int] = (64, 8),
|
||||||
|
head_size: int = 128,
|
||||||
|
kv_layout: str = "HND",
|
||||||
|
block_size: int = 16,
|
||||||
|
warmup: int = 10,
|
||||||
|
trials: int = 20,
|
||||||
|
):
|
||||||
|
torch.set_default_device("cuda")
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
|
||||||
|
q_quant_dtype = q_quant_dtype or dtype
|
||||||
|
kv_quant_dtype = kv_quant_dtype or dtype
|
||||||
|
o_quant_dtype = o_quant_dtype or dtype
|
||||||
|
|
||||||
|
max_q_len = max_kv_len = max_seq_len
|
||||||
|
|
||||||
|
num_qo_heads, num_kv_heads = num_heads
|
||||||
|
assert num_qo_heads % num_kv_heads == 0
|
||||||
|
|
||||||
|
sm_scale = float(1.0 / (head_size**0.5))
|
||||||
|
|
||||||
|
# large number to reduce kv_cache reuse
|
||||||
|
NUM_BLOCKS = int(256000 / block_size)
|
||||||
|
|
||||||
|
kv_cache_shape = None
|
||||||
|
if kv_layout == "NHD":
|
||||||
|
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||||
|
elif kv_layout == "HND":
|
||||||
|
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||||
|
|
||||||
|
q_lens = torch.randint(1, max_q_len, (batch_size,), dtype=torch.int32)
|
||||||
|
q_lens[-1] = max_q_len
|
||||||
|
q_indptr = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([0], dtype=torch.int32),
|
||||||
|
torch.cumsum(q_lens, dim=0, dtype=torch.int32),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
q_scale = 1.0
|
||||||
|
ref_query = torch.randn(
|
||||||
|
torch.sum(q_lens).item(), num_qo_heads, head_size, dtype=dtype
|
||||||
|
)
|
||||||
|
if q_quant_dtype == FP8_DTYPE:
|
||||||
|
query, _ = to_float8(ref_query)
|
||||||
|
else:
|
||||||
|
query = ref_query
|
||||||
|
|
||||||
|
kv_lens = torch.randint(0, max_kv_len, (batch_size,), dtype=torch.int32)
|
||||||
|
kv_lens[-1] = max_kv_len
|
||||||
|
|
||||||
|
seq_lens = kv_lens + q_lens
|
||||||
|
max_seq_len = torch.max(seq_lens).item()
|
||||||
|
|
||||||
|
# Always using 1.0 scale to reflect the real perf in benchmarking
|
||||||
|
k_scale = v_scale = 1.0
|
||||||
|
ref_kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||||
|
if kv_quant_dtype == FP8_DTYPE:
|
||||||
|
kv_cache, _ = to_float8(ref_kv_cache)
|
||||||
|
else:
|
||||||
|
kv_cache = ref_kv_cache
|
||||||
|
|
||||||
|
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||||
|
block_tables = torch.randint(
|
||||||
|
0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
|
||||||
|
)
|
||||||
|
kv_indptr = [0]
|
||||||
|
kv_indices = []
|
||||||
|
kv_last_page_lens = []
|
||||||
|
for i in range(batch_size):
|
||||||
|
seq_len = seq_lens[i]
|
||||||
|
assert seq_len > 0
|
||||||
|
num_blocks = (seq_len + block_size - 1) // block_size
|
||||||
|
kv_indices.extend(block_tables[i, :num_blocks])
|
||||||
|
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||||
|
kv_last_page_len = seq_len % block_size
|
||||||
|
if kv_last_page_len == 0:
|
||||||
|
kv_last_page_len = block_size
|
||||||
|
kv_last_page_lens.append(kv_last_page_len)
|
||||||
|
|
||||||
|
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||||
|
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||||
|
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||||
|
workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
|
||||||
|
|
||||||
|
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
|
||||||
|
workspace_buffer, kv_layout
|
||||||
|
)
|
||||||
|
wrapper.plan(
|
||||||
|
q_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
kv_last_page_lens,
|
||||||
|
num_qo_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_size,
|
||||||
|
block_size,
|
||||||
|
causal=True,
|
||||||
|
sm_scale=sm_scale,
|
||||||
|
q_data_type=dtype,
|
||||||
|
kv_data_type=dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
def time_fn(fn, warmup=10, trials=20):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start = torch.Event(enable_timing=True)
|
||||||
|
end = torch.Event(enable_timing=True)
|
||||||
|
times = []
|
||||||
|
for i in range(warmup):
|
||||||
|
fn()
|
||||||
|
for i in range(trials):
|
||||||
|
start.record()
|
||||||
|
fn()
|
||||||
|
end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
times.append(start.elapsed_time(end)) # ms
|
||||||
|
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||||
|
|
||||||
|
o_scale = 1.0
|
||||||
|
o_sf_scale = None
|
||||||
|
output_baseline = torch.empty(ref_query.shape, dtype=dtype)
|
||||||
|
if o_quant_dtype == FP4_DTYPE:
|
||||||
|
o_sf_scale = 500.0
|
||||||
|
output_trtllm = flashinfer.utils.FP4Tensor(
|
||||||
|
torch.empty(query.shape[:-1] + (query.shape[-1] // 2,), dtype=torch.uint8),
|
||||||
|
torch.empty(
|
||||||
|
(
|
||||||
|
round_up(query.shape[0], 128),
|
||||||
|
round_up(query.shape[1] * query.shape[2] // 16, 4),
|
||||||
|
),
|
||||||
|
dtype=torch.float8_e4m3fn,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
|
||||||
|
|
||||||
|
def baseline_prefill():
|
||||||
|
return wrapper.run(
|
||||||
|
ref_query,
|
||||||
|
ref_kv_cache,
|
||||||
|
k_scale=k_scale,
|
||||||
|
v_scale=v_scale,
|
||||||
|
out=output_baseline,
|
||||||
|
)
|
||||||
|
|
||||||
|
def trtllm_prefill():
|
||||||
|
return flashinfer.prefill.trtllm_batch_context_with_kv_cache(
|
||||||
|
query=query,
|
||||||
|
kv_cache=kv_cache,
|
||||||
|
workspace_buffer=workspace_buffer,
|
||||||
|
block_tables=block_tables,
|
||||||
|
seq_lens=seq_lens,
|
||||||
|
max_q_len=max_q_len,
|
||||||
|
max_kv_len=max_seq_len,
|
||||||
|
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||||
|
bmm2_scale=v_scale / o_scale,
|
||||||
|
batch_size=batch_size,
|
||||||
|
cum_seq_lens_q=q_indptr,
|
||||||
|
cum_seq_lens_kv=kv_indptr,
|
||||||
|
o_sf_scale=o_sf_scale,
|
||||||
|
out=output_trtllm,
|
||||||
|
)
|
||||||
|
|
||||||
|
baseline_mean, baseline_std = time_fn(baseline_prefill)
|
||||||
|
trtllm_mean, trtllm_std = time_fn(trtllm_prefill)
|
||||||
|
|
||||||
|
# Calculate percentage speedup (positive means TRT is faster)
|
||||||
|
speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:8.3f}\t{trtllm_std.item():8.3f}"
|
||||||
|
f"\t{baseline_mean:8.3f}\t{baseline_std.item():8.3f}\t{speedup_percent:8.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Return results for CSV writing
|
||||||
|
return {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"trtllm_mean": trtllm_mean,
|
||||||
|
"trtllm_std": trtllm_std.item(),
|
||||||
|
"baseline_mean": baseline_mean,
|
||||||
|
"baseline_std": baseline_std.item(),
|
||||||
|
"speedup_percent": speedup_percent,
|
||||||
|
"q_dtype": str(q_quant_dtype),
|
||||||
|
"kv_cache_dtype": str(kv_quant_dtype),
|
||||||
|
"output_dtype": str(o_quant_dtype),
|
||||||
|
"block_size": block_size,
|
||||||
|
"num_kv_heads": num_kv_heads,
|
||||||
|
"head_size": head_size,
|
||||||
|
"max_seq_len": max_seq_len,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_csv(results, filename=None):
|
||||||
|
"""Write benchmark results to CSV file."""
|
||||||
|
if filename is None:
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||||
|
|
||||||
|
fieldnames = [
|
||||||
|
"batch_size",
|
||||||
|
"trtllm_mean",
|
||||||
|
"trtllm_std",
|
||||||
|
"baseline_mean",
|
||||||
|
"baseline_std",
|
||||||
|
"speedup_percent",
|
||||||
|
"q_dtype",
|
||||||
|
"kv_cache_dtype",
|
||||||
|
"output_dtype",
|
||||||
|
"block_size",
|
||||||
|
"num_kv_heads",
|
||||||
|
"head_size",
|
||||||
|
"max_seq_len",
|
||||||
|
]
|
||||||
|
|
||||||
|
file_exists = os.path.exists(filename)
|
||||||
|
|
||||||
|
with open(filename, "a", newline="") as csvfile:
|
||||||
|
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||||
|
|
||||||
|
if not file_exists:
|
||||||
|
writer.writeheader()
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
writer.writerow(result)
|
||||||
|
|
||||||
|
print(f"Results written to {filename}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||||
|
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||||
|
all_results = []
|
||||||
|
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
quant_dtypes = [
|
||||||
|
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||||
|
(None, None, None),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, None),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||||
|
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||||
|
]
|
||||||
|
|
||||||
|
for quant_dtype in quant_dtypes:
|
||||||
|
q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
|
||||||
|
q_quant_dtype = q_quant_dtype or dtype
|
||||||
|
kv_quant_dtype = kv_quant_dtype or dtype
|
||||||
|
o_quant_dtype = o_quant_dtype or dtype
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Running benchmark for q_dtype = {q_quant_dtype}, "
|
||||||
|
f"kv_cache_dtype: {kv_quant_dtype}, "
|
||||||
|
f"output_dtype: {o_quant_dtype}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
|
||||||
|
"baseline_std\tspeedup_percent"
|
||||||
|
)
|
||||||
|
for max_seq_len in max_seq_lens:
|
||||||
|
for bs in batch_sizes:
|
||||||
|
result = benchmark_prefill(
|
||||||
|
dtype=dtype,
|
||||||
|
quant_dtypes=quant_dtype,
|
||||||
|
batch_size=bs,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
)
|
||||||
|
all_results.append(result)
|
||||||
|
|
||||||
|
# Write all results to CSV
|
||||||
|
write_results_to_csv(all_results)
|
||||||
415
benchmarks/kernels/benchmark_w8a8_block_fp8.py
Normal file
415
benchmarks/kernels/benchmark_w8a8_block_fp8.py
Normal file
@@ -0,0 +1,415 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# Adapted from sglang quantization/tuning_block_wise_kernel.py
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import multiprocessing as mp
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
_w8a8_triton_block_scaled_mm,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
|
||||||
|
assert current_platform.is_cuda(), (
|
||||||
|
"Only support tune w8a8 block fp8 kernel on CUDA device."
|
||||||
|
)
|
||||||
|
|
||||||
|
DTYPE_MAP = {
|
||||||
|
"float32": torch.float32,
|
||||||
|
"float16": torch.float16,
|
||||||
|
"half": torch.half,
|
||||||
|
"bfloat16": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def w8a8_block_matmul(
|
||||||
|
A: torch.Tensor,
|
||||||
|
B: torch.Tensor,
|
||||||
|
As: torch.Tensor,
|
||||||
|
Bs: torch.Tensor,
|
||||||
|
block_size: list[int],
|
||||||
|
config: dict[str, Any],
|
||||||
|
output_dtype: torch.dtype = torch.float16,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""This function performs matrix multiplication with
|
||||||
|
block-wise quantization.
|
||||||
|
|
||||||
|
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||||
|
The output is returned in the specified `output_dtype`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
A: The input tensor, e.g., activation.
|
||||||
|
B: The input tensor, e.g., weight.
|
||||||
|
As: The per-token-group quantization scale for `A`.
|
||||||
|
Bs: The per-block quantization scale for `B`.
|
||||||
|
block_size: The block size for per-block quantization.
|
||||||
|
It should be 2-dim, e.g., [128, 128].
|
||||||
|
output_dtype: The dtype of the returned tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The result of matmul.
|
||||||
|
"""
|
||||||
|
assert len(block_size) == 2
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
|
||||||
|
assert A.shape[-1] == B.shape[-1]
|
||||||
|
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||||
|
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||||
|
M = A.numel() // A.shape[-1]
|
||||||
|
|
||||||
|
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
|
||||||
|
N, K = B.shape
|
||||||
|
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||||
|
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||||
|
|
||||||
|
C_shape = A.shape[:-1] + (N,)
|
||||||
|
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||||
|
|
||||||
|
def grid(META):
|
||||||
|
return (
|
||||||
|
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
if A.dtype == torch.float8_e4m3fn:
|
||||||
|
kernel = _w8a8_triton_block_scaled_mm
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
|
||||||
|
|
||||||
|
kernel[grid](
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
A.stride(-2),
|
||||||
|
A.stride(-1),
|
||||||
|
B.stride(1),
|
||||||
|
B.stride(0),
|
||||||
|
C.stride(-2),
|
||||||
|
C.stride(-1),
|
||||||
|
As.stride(-2),
|
||||||
|
As.stride(-1),
|
||||||
|
Bs.stride(1),
|
||||||
|
Bs.stride(0),
|
||||||
|
**config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return C
|
||||||
|
|
||||||
|
|
||||||
|
def get_configs_compute_bound():
|
||||||
|
configs = []
|
||||||
|
for num_stages in [2, 3, 4, 5]:
|
||||||
|
for block_m in [16, 32, 64, 128, 256]:
|
||||||
|
for block_k in [64, 128]:
|
||||||
|
for block_n in [32, 64, 128, 256]:
|
||||||
|
for num_warps in [4, 8]:
|
||||||
|
for group_size in [1, 16, 32, 64]:
|
||||||
|
configs.append(
|
||||||
|
{
|
||||||
|
"BLOCK_SIZE_M": block_m,
|
||||||
|
"BLOCK_SIZE_N": block_n,
|
||||||
|
"BLOCK_SIZE_K": block_k,
|
||||||
|
"GROUP_SIZE_M": group_size,
|
||||||
|
"num_warps": num_warps,
|
||||||
|
"num_stages": num_stages,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return configs
|
||||||
|
|
||||||
|
|
||||||
|
def get_weight_shapes(tp_size):
|
||||||
|
# NOTE(HandH1998): The weight shapes only works for DeepSeek-V3.
|
||||||
|
# Modify them, if you tune for another different model.
|
||||||
|
# cannot TP
|
||||||
|
total = [
|
||||||
|
(512 + 64, 7168),
|
||||||
|
(2112, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(7168, 16384),
|
||||||
|
(7168, 18432),
|
||||||
|
]
|
||||||
|
# N can TP
|
||||||
|
n_tp = [
|
||||||
|
(18432 * 2, 7168),
|
||||||
|
((128 + 64) * 128, 7168),
|
||||||
|
(128 * (128 + 128), 512),
|
||||||
|
(24576, 1536),
|
||||||
|
(12288, 7168),
|
||||||
|
(4096, 7168),
|
||||||
|
]
|
||||||
|
# K can TP
|
||||||
|
k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)]
|
||||||
|
|
||||||
|
weight_shapes = []
|
||||||
|
for t in total:
|
||||||
|
weight_shapes.append(t)
|
||||||
|
for n_t in n_tp:
|
||||||
|
new_t = (n_t[0] // tp_size, n_t[1])
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
for k_t in k_tp:
|
||||||
|
new_t = (k_t[0], k_t[1] // tp_size)
|
||||||
|
weight_shapes.append(new_t)
|
||||||
|
return weight_shapes
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_config(
|
||||||
|
A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
|
||||||
|
):
|
||||||
|
def run():
|
||||||
|
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
# JIT complication & warmup
|
||||||
|
for _ in range(5):
|
||||||
|
run()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
start_event = torch.Event(enable_timing=True)
|
||||||
|
end_event = torch.Event(enable_timing=True)
|
||||||
|
|
||||||
|
latencies: list[float] = []
|
||||||
|
for i in range(num_iters):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start_event.record()
|
||||||
|
run()
|
||||||
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
latencies.append(start_event.elapsed_time(end_event))
|
||||||
|
avg = sum(latencies) / (num_iters * 10) * 1000 # us
|
||||||
|
return avg
|
||||||
|
|
||||||
|
|
||||||
|
def tune(M, N, K, block_size, out_dtype, search_space, input_type):
|
||||||
|
factor_for_scale = 1e-2
|
||||||
|
|
||||||
|
if input_type == "fp8":
|
||||||
|
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||||
|
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||||
|
|
||||||
|
A_fp32 = (
|
||||||
|
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||||
|
)
|
||||||
|
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
|
||||||
|
B_fp32 = (
|
||||||
|
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
|
||||||
|
)
|
||||||
|
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
|
||||||
|
|
||||||
|
block_n, block_k = block_size[0], block_size[1]
|
||||||
|
n_tiles = (N + block_n - 1) // block_n
|
||||||
|
k_tiles = (K + block_k - 1) // block_k
|
||||||
|
|
||||||
|
As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
|
||||||
|
Bs = (
|
||||||
|
torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
|
||||||
|
* factor_for_scale
|
||||||
|
)
|
||||||
|
|
||||||
|
best_config = None
|
||||||
|
best_time = float("inf")
|
||||||
|
for config in tqdm(search_space):
|
||||||
|
try:
|
||||||
|
kernel_time = benchmark_config(
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
As,
|
||||||
|
Bs,
|
||||||
|
block_size,
|
||||||
|
config,
|
||||||
|
out_dtype,
|
||||||
|
num_iters=10,
|
||||||
|
)
|
||||||
|
except triton.runtime.autotuner.OutOfResources:
|
||||||
|
# Some configurations may be invalid and fail to compile.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if kernel_time < best_time:
|
||||||
|
best_time = kernel_time
|
||||||
|
best_config = config
|
||||||
|
now = datetime.now()
|
||||||
|
print(f"{now.ctime()}] Completed tuning for batch_size={M}")
|
||||||
|
assert best_config is not None
|
||||||
|
return best_config
|
||||||
|
|
||||||
|
|
||||||
|
def save_configs(
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_n,
|
||||||
|
block_k,
|
||||||
|
configs,
|
||||||
|
save_path,
|
||||||
|
input_type="fp8",
|
||||||
|
) -> None:
|
||||||
|
os.makedirs(save_path, exist_ok=True)
|
||||||
|
device_name = current_platform.get_device_name().replace(" ", "_")
|
||||||
|
json_file_name = (
|
||||||
|
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
|
||||||
|
f"block_shape=[{block_n},{block_k}].json"
|
||||||
|
)
|
||||||
|
|
||||||
|
config_file_path = os.path.join(save_path, json_file_name)
|
||||||
|
print(f"Writing best config to {config_file_path}...")
|
||||||
|
|
||||||
|
with open(config_file_path, "w") as f:
|
||||||
|
json.dump(configs, f, indent=4)
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
|
||||||
|
def tune_on_gpu(args_dict):
|
||||||
|
"""Run tuning on a specific GPU."""
|
||||||
|
gpu_id = args_dict["gpu_id"]
|
||||||
|
batch_sizes = args_dict["batch_sizes"]
|
||||||
|
weight_shapes = args_dict["weight_shapes"]
|
||||||
|
args = args_dict["args"]
|
||||||
|
|
||||||
|
torch.cuda.set_device(gpu_id)
|
||||||
|
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
|
||||||
|
|
||||||
|
block_n = args.block_n
|
||||||
|
block_k = args.block_k
|
||||||
|
out_dtype = DTYPE_MAP[args.out_dtype]
|
||||||
|
save_path = args.save_path
|
||||||
|
input_type = args.input_type
|
||||||
|
|
||||||
|
search_space = get_configs_compute_bound()
|
||||||
|
search_space = [
|
||||||
|
config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
for shape in tqdm(weight_shapes, desc=f"GPU {gpu_id} - Shapes"):
|
||||||
|
N, K = shape[0], shape[1]
|
||||||
|
print(f"[GPU {gpu_id}] Tune for weight shape of `N: {N}, K: {K}`")
|
||||||
|
benchmark_results = [
|
||||||
|
tune(
|
||||||
|
batch_size,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
[block_n, block_k],
|
||||||
|
out_dtype,
|
||||||
|
search_space,
|
||||||
|
input_type,
|
||||||
|
)
|
||||||
|
for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
|
||||||
|
]
|
||||||
|
best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
|
||||||
|
save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
|
||||||
|
|
||||||
|
end = time.time()
|
||||||
|
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
def distribute_batch_sizes(batch_sizes, num_gpus):
|
||||||
|
"""Distribute batch sizes across available GPUs."""
|
||||||
|
batches_per_gpu = []
|
||||||
|
for i in range(num_gpus):
|
||||||
|
start_idx = i * len(batch_sizes) // num_gpus
|
||||||
|
end_idx = (i + 1) * len(batch_sizes) // num_gpus
|
||||||
|
batches_per_gpu.append(batch_sizes[start_idx:end_idx])
|
||||||
|
return batches_per_gpu
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args)
|
||||||
|
num_gpus = torch.cuda.device_count()
|
||||||
|
if num_gpus == 0:
|
||||||
|
raise RuntimeError("No GPU available for tuning")
|
||||||
|
print(f"Found {num_gpus} GPUs for parallel tuning")
|
||||||
|
|
||||||
|
torch.cuda.init()
|
||||||
|
|
||||||
|
if args.batch_size is None:
|
||||||
|
batch_sizes = [
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
8,
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
32,
|
||||||
|
48,
|
||||||
|
64,
|
||||||
|
96,
|
||||||
|
128,
|
||||||
|
256,
|
||||||
|
512,
|
||||||
|
1024,
|
||||||
|
1536,
|
||||||
|
2048,
|
||||||
|
3072,
|
||||||
|
4096,
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch_sizes = [args.batch_size]
|
||||||
|
num_gpus = 1 # If only one batch size, use only one GPU
|
||||||
|
|
||||||
|
weight_shapes = get_weight_shapes(args.tp_size)
|
||||||
|
|
||||||
|
batches_per_gpu = distribute_batch_sizes(batch_sizes, num_gpus)
|
||||||
|
|
||||||
|
process_args = []
|
||||||
|
for gpu_id in range(num_gpus):
|
||||||
|
process_args.append(
|
||||||
|
{
|
||||||
|
"gpu_id": gpu_id,
|
||||||
|
"batch_sizes": batches_per_gpu[gpu_id],
|
||||||
|
"weight_shapes": weight_shapes, # Each GPU processes all weight shapes
|
||||||
|
"args": args,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = mp.get_context("spawn")
|
||||||
|
with ctx.Pool(num_gpus) as pool:
|
||||||
|
pool.map(tune_on_gpu, process_args)
|
||||||
|
|
||||||
|
print("Multi-GPU tuning completed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="""
|
||||||
|
Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
|
||||||
|
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
|
||||||
|
Then copy to model_executor/layers/quantization/utils/configs
|
||||||
|
""",
|
||||||
|
formatter_class=argparse.RawTextHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--tp-size", "-tp", type=int, default=8)
|
||||||
|
parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
|
||||||
|
parser.add_argument(
|
||||||
|
"--out-dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["float32", "float16", "bfloat16", "half"],
|
||||||
|
default="float16",
|
||||||
|
)
|
||||||
|
parser.add_argument("--block-n", type=int, default=128)
|
||||||
|
parser.add_argument("--block-k", type=int, default=128)
|
||||||
|
parser.add_argument("--batch-size", type=int, required=False)
|
||||||
|
parser.add_argument("--save-path", type=str, default="./")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
||||||
129
benchmarks/kernels/deepgemm/README.md
Normal file
129
benchmarks/kernels/deepgemm/README.md
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# DeepSeek DeepGEMM Kernels Benchmark
|
||||||
|
|
||||||
|
This directory includes benchmarks between DeepSeek's DeepGEMM block fp8 kernels against vLLM's existing triton and CUTLASS-based kernels.
|
||||||
|
|
||||||
|
Currently, this just includes dense GEMMs and only works on Hopper GPUs.
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
You need to install vLLM in your usual fashion, then install DeepGEMM from source in its own directory:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone --recursive https://github.com/deepseek-ai/DeepGEMM
|
||||||
|
cd DeepGEMM
|
||||||
|
python setup.py install
|
||||||
|
uv pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```console
|
||||||
|
python benchmark_fp8_block_dense_gemm.py
|
||||||
|
INFO 02-26 21:55:13 [__init__.py:207] Automatically detected platform cuda.
|
||||||
|
===== STARTING FP8 GEMM BENCHMARK =====
|
||||||
|
PyTorch version: 2.5.1+cu124
|
||||||
|
CUDA version: 12.4
|
||||||
|
Triton version: 3.1.0
|
||||||
|
Using device: NVIDIA H100 80GB HBM3
|
||||||
|
WARNING 02-26 21:55:15 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||||
|
INFO 02-26 21:55:15 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||||
|
WARNING 02-26 21:55:16 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=18432,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||||
|
WARNING 02-26 21:55:17 [fp8_utils.py:458] Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! Config file not found at /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json
|
||||||
|
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||||
|
INFO 02-26 21:55:17 [fp8_utils.py:449] Using configuration from /home/mgoin/code/vllm/vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json for W8A8 Block FP8 kernel.
|
||||||
|
|
||||||
|
===== PERFORMANCE COMPARISON =====
|
||||||
|
|
||||||
|
DeepGEMM Implementation:
|
||||||
|
+------+-------+-------+-----------+--------+--------+
|
||||||
|
| m | n | k | Time (μs) | TFLOPS | GB/s |
|
||||||
|
+------+-------+-------+-----------+--------+--------+
|
||||||
|
| 8 | 4096 | 7168 | 102.9 | 4.6 | 286.4 |
|
||||||
|
| 8 | 7168 | 18432 | 70.8 | 29.8 | 1868.8 |
|
||||||
|
| 8 | 18432 | 7168 | 69.3 | 30.5 | 1911.8 |
|
||||||
|
| 64 | 4096 | 7168 | 69.1 | 54.4 | 439.0 |
|
||||||
|
| 64 | 7168 | 18432 | 69.4 | 243.6 | 1933.6 |
|
||||||
|
| 64 | 18432 | 7168 | 70.4 | 240.3 | 1917.2 |
|
||||||
|
| 64 | 24576 | 1536 | 70.1 | 68.9 | 584.6 |
|
||||||
|
| 64 | 32768 | 512 | 68.4 | 31.4 | 307.1 |
|
||||||
|
| 64 | 7168 | 16384 | 69.5 | 216.3 | 1718.5 |
|
||||||
|
| 128 | 4096 | 7168 | 141.1 | 53.3 | 222.1 |
|
||||||
|
| 128 | 7168 | 18432 | 71.9 | 470.5 | 1896.1 |
|
||||||
|
| 128 | 18432 | 7168 | 69.3 | 488.2 | 1988.2 |
|
||||||
|
| 1024 | 4096 | 7168 | 89.7 | 670.1 | 502.5 |
|
||||||
|
| 1024 | 18432 | 7168 | 279.0 | 969.8 | 635.2 |
|
||||||
|
| 2048 | 4096 | 7168 | 175.1 | 687.0 | 347.4 |
|
||||||
|
| 4096 | 4096 | 7168 | 335.4 | 717.0 | 275.1 |
|
||||||
|
+------+-------+-------+-----------+--------+--------+
|
||||||
|
|
||||||
|
vLLM Triton Implementation:
|
||||||
|
+------+-------+-------+-----------+--------+--------+--------------+
|
||||||
|
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM |
|
||||||
|
+------+-------+-------+-----------+--------+--------+--------------+
|
||||||
|
| 8 | 4096 | 7168 | 74.0 | 6.3 | 398.2 | 1.39x faster |
|
||||||
|
| 8 | 7168 | 18432 | 89.6 | 23.6 | 1478.1 | 0.79x slower |
|
||||||
|
| 8 | 18432 | 7168 | 113.2 | 18.7 | 1170.4 | 0.61x slower |
|
||||||
|
| 64 | 4096 | 7168 | 79.4 | 47.3 | 382.2 | 0.87x slower |
|
||||||
|
| 64 | 7168 | 18432 | 98.5 | 171.7 | 1363.0 | 0.70x slower |
|
||||||
|
| 64 | 18432 | 7168 | 119.5 | 141.5 | 1129.4 | 0.59x slower |
|
||||||
|
| 64 | 24576 | 1536 | 37.6 | 128.4 | 1089.7 | 1.86x faster |
|
||||||
|
| 64 | 32768 | 512 | 38.7 | 55.5 | 542.6 | 1.77x faster |
|
||||||
|
| 64 | 7168 | 16384 | 86.1 | 174.5 | 1386.4 | 0.81x slower |
|
||||||
|
| 128 | 4096 | 7168 | 90.7 | 82.9 | 345.4 | 1.56x faster |
|
||||||
|
| 128 | 7168 | 18432 | 144.0 | 234.9 | 946.9 | 0.50x slower |
|
||||||
|
| 128 | 18432 | 7168 | 229.5 | 147.4 | 600.1 | 0.30x slower |
|
||||||
|
| 1024 | 4096 | 7168 | 242.3 | 248.2 | 186.1 | 0.37x slower |
|
||||||
|
| 1024 | 18432 | 7168 | 897.8 | 301.4 | 197.4 | 0.31x slower |
|
||||||
|
| 2048 | 4096 | 7168 | 463.0 | 259.7 | 131.4 | 0.38x slower |
|
||||||
|
| 4096 | 4096 | 7168 | 901.8 | 266.7 | 102.3 | 0.37x slower |
|
||||||
|
+------+-------+-------+-----------+--------+--------+--------------+
|
||||||
|
|
||||||
|
vLLM CUTLASS Implementation:
|
||||||
|
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||||
|
| m | n | k | Time (μs) | TFLOPS | GB/s | vs DeepGEMM | vs Triton |
|
||||||
|
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||||
|
| 8 | 4096 | 7168 | 34.6 | 13.6 | 852.3 | 2.98x faster | 2.14x faster |
|
||||||
|
| 8 | 7168 | 18432 | 78.9 | 26.8 | 1677.3 | 0.90x slower | 1.13x faster |
|
||||||
|
| 8 | 18432 | 7168 | 81.2 | 26.0 | 1631.1 | 0.85x slower | 1.39x faster |
|
||||||
|
| 64 | 4096 | 7168 | 36.9 | 101.9 | 822.9 | 1.87x faster | 2.15x faster |
|
||||||
|
| 64 | 7168 | 18432 | 87.4 | 193.4 | 1535.2 | 0.79x slower | 1.13x faster |
|
||||||
|
| 64 | 18432 | 7168 | 85.0 | 199.0 | 1587.6 | 0.83x slower | 1.41x faster |
|
||||||
|
| 64 | 24576 | 1536 | 28.0 | 172.8 | 1465.8 | 2.51x faster | 1.35x faster |
|
||||||
|
| 64 | 32768 | 512 | 28.8 | 74.5 | 728.5 | 2.37x faster | 1.34x faster |
|
||||||
|
| 64 | 7168 | 16384 | 77.9 | 193.0 | 1532.8 | 0.89x slower | 1.11x faster |
|
||||||
|
| 128 | 4096 | 7168 | 39.1 | 192.4 | 802.0 | 3.61x faster | 2.32x faster |
|
||||||
|
| 128 | 7168 | 18432 | 93.7 | 360.8 | 1454.2 | 0.77x slower | 1.54x faster |
|
||||||
|
| 128 | 18432 | 7168 | 85.7 | 394.8 | 1608.0 | 0.81x slower | 2.68x faster |
|
||||||
|
| 1024 | 4096 | 7168 | 99.7 | 603.1 | 452.2 | 0.90x slower | 2.43x faster |
|
||||||
|
| 1024 | 18432 | 7168 | 331.3 | 816.7 | 534.9 | 0.84x slower | 2.71x faster |
|
||||||
|
| 2048 | 4096 | 7168 | 198.3 | 606.6 | 306.7 | 0.88x slower | 2.34x faster |
|
||||||
|
| 4096 | 4096 | 7168 | 392.2 | 613.2 | 235.3 | 0.86x slower | 2.30x faster |
|
||||||
|
+------+-------+-------+-----------+--------+--------+--------------+--------------+
|
||||||
|
|
||||||
|
===== AVERAGE PERFORMANCE =====
|
||||||
|
+----------------+------------+----------+---------------+
|
||||||
|
| Implementation | Avg TFLOPS | Avg GB/s | Avg Time (ms) |
|
||||||
|
+----------------+------------+----------+---------------+
|
||||||
|
| DeepGEMM | 310.98 | 1052.10 | 0.11 |
|
||||||
|
| vLLM Triton | 144.30 | 715.60 | 0.23 |
|
||||||
|
| vLLM CUTLASS | 286.78 | 1076.67 | 0.11 |
|
||||||
|
+----------------+------------+----------+---------------+
|
||||||
|
|
||||||
|
===== AVERAGE SPEEDUPS =====
|
||||||
|
+-----------------------------+--------------+
|
||||||
|
| Comparison | Speedup |
|
||||||
|
+-----------------------------+--------------+
|
||||||
|
| DeepGEMM vs vLLM Triton | 1.71x faster |
|
||||||
|
| DeepGEMM vs vLLM CUTLASS | 0.94x slower |
|
||||||
|
| vLLM CUTLASS vs vLLM Triton | 1.84x faster |
|
||||||
|
+-----------------------------+--------------+
|
||||||
|
|
||||||
|
===== ACCURACY COMPARISON =====
|
||||||
|
+----------------+-----------------------+
|
||||||
|
| Implementation | Avg Diff vs Reference |
|
||||||
|
+----------------+-----------------------+
|
||||||
|
| DeepGEMM | 0.000684 |
|
||||||
|
| vLLM Triton | 0.000684 |
|
||||||
|
| vLLM CUTLASS | 0.000684 |
|
||||||
|
+----------------+-----------------------+
|
||||||
|
```
|
||||||
435
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
435
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
# ruff: noqa: E501
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||||
|
per_token_group_quant_fp8,
|
||||||
|
w8a8_triton_block_scaled_mm,
|
||||||
|
)
|
||||||
|
from vllm.triton_utils import triton
|
||||||
|
from vllm.utils.deep_gemm import (
|
||||||
|
calc_diff,
|
||||||
|
fp8_gemm_nt,
|
||||||
|
get_col_major_tma_aligned_tensor,
|
||||||
|
per_block_cast_to_fp8,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_shape(
|
||||||
|
m: int,
|
||||||
|
n: int,
|
||||||
|
k: int,
|
||||||
|
warmup: int = 100,
|
||||||
|
repeat: int = 10000,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""Benchmark all implementations for a specific (m, n, k) shape."""
|
||||||
|
if verbose:
|
||||||
|
print(f"\n=== Benchmarking shape: m={m}, n={n}, k={k} ===")
|
||||||
|
|
||||||
|
# Create test tensors
|
||||||
|
A = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
# Reference result in BF16
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
C_ref = A @ B.t()
|
||||||
|
|
||||||
|
# Pre-quantize B for all implementations
|
||||||
|
# (weights can be pre-quantized offline)
|
||||||
|
B_deepgemm, B_scale_deepgemm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
|
||||||
|
B_vllm, B_scale_vllm = per_block_cast_to_fp8(B, [128, 128], use_ue8m0=True)
|
||||||
|
|
||||||
|
# Block size configuration
|
||||||
|
block_size = [128, 128]
|
||||||
|
|
||||||
|
# Pre-quantize A for all implementations
|
||||||
|
A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(A, block_size[1])
|
||||||
|
A_scale_deepgemm = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||||
|
C_deepgemm = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
|
||||||
|
A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||||
|
A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||||
|
A, block_size[1], column_major_scales=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# === DeepGEMM Implementation ===
|
||||||
|
def deepgemm_gemm():
|
||||||
|
fp8_gemm_nt(
|
||||||
|
(A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm
|
||||||
|
)
|
||||||
|
return C_deepgemm
|
||||||
|
|
||||||
|
# === vLLM Triton Implementation ===
|
||||||
|
def vllm_triton_gemm():
|
||||||
|
return w8a8_triton_block_scaled_mm(
|
||||||
|
A_vllm,
|
||||||
|
B_vllm,
|
||||||
|
A_scale_vllm,
|
||||||
|
B_scale_vllm,
|
||||||
|
block_size,
|
||||||
|
output_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
# === vLLM CUTLASS Implementation ===
|
||||||
|
def vllm_cutlass_gemm():
|
||||||
|
return ops.cutlass_scaled_mm(
|
||||||
|
A_vllm_cutlass,
|
||||||
|
B_vllm.T,
|
||||||
|
scale_a=A_scale_vllm_cutlass,
|
||||||
|
scale_b=B_scale_vllm.T,
|
||||||
|
out_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run correctness check first
|
||||||
|
if verbose:
|
||||||
|
print("Running correctness check...")
|
||||||
|
C_deepgemm = deepgemm_gemm()
|
||||||
|
C_vllm_triton = vllm_triton_gemm()
|
||||||
|
C_vllm_cutlass = vllm_cutlass_gemm()
|
||||||
|
|
||||||
|
deepgemm_diff = calc_diff(C_deepgemm, C_ref)
|
||||||
|
vllm_triton_diff = calc_diff(C_vllm_triton, C_ref)
|
||||||
|
vllm_cutlass_diff = calc_diff(C_vllm_cutlass, C_ref)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"DeepGEMM vs Reference difference: {deepgemm_diff:.6f}")
|
||||||
|
print(f"vLLM Triton vs Reference difference: {vllm_triton_diff:.6f}")
|
||||||
|
print(f"vLLM CUTLASS vs Reference difference: {vllm_cutlass_diff:.6f}")
|
||||||
|
print(
|
||||||
|
"vLLM Triton vs DeepGEMM difference: "
|
||||||
|
f"{calc_diff(C_vllm_triton, C_deepgemm):.6f}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"vLLM CUTLASS vs DeepGEMM difference: "
|
||||||
|
f"{calc_diff(C_vllm_cutlass, C_deepgemm):.6f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Benchmark implementations
|
||||||
|
implementations = {
|
||||||
|
"DeepGEMM": deepgemm_gemm,
|
||||||
|
"vLLM Triton": vllm_triton_gemm,
|
||||||
|
"vLLM CUTLASS": vllm_cutlass_gemm,
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmark_results = {"shape": {"m": m, "n": n, "k": k}, "implementations": {}}
|
||||||
|
|
||||||
|
for name, func in implementations.items():
|
||||||
|
# Warmup
|
||||||
|
for _ in range(warmup):
|
||||||
|
func()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
# Timing loop
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
start = time.time()
|
||||||
|
for _ in range(repeat):
|
||||||
|
func()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
end = time.time()
|
||||||
|
|
||||||
|
# Calculate timing and TFLOPS
|
||||||
|
avg_time_ms = (end - start) / repeat * 1000
|
||||||
|
avg_time_us = avg_time_ms * 1000
|
||||||
|
tflops = 2 * m * n * k / (avg_time_ms * 1e-3) / 1e12
|
||||||
|
gb_s = (m * k + k * n + m * n * 2) / 1e9 / (avg_time_ms * 1e-3)
|
||||||
|
|
||||||
|
benchmark_results["implementations"][name] = {
|
||||||
|
"time_ms": avg_time_ms,
|
||||||
|
"time_us": avg_time_us,
|
||||||
|
"tflops": tflops,
|
||||||
|
"gb_s": gb_s,
|
||||||
|
"diff": {
|
||||||
|
"DeepGEMM": 0.0
|
||||||
|
if name == "DeepGEMM"
|
||||||
|
else calc_diff(func(), C_deepgemm),
|
||||||
|
"Reference": deepgemm_diff
|
||||||
|
if name == "DeepGEMM"
|
||||||
|
else (vllm_triton_diff if name == "vLLM Triton" else vllm_cutlass_diff),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"{name}: {avg_time_ms:.3f} ms, {tflops:.2f} TFLOPS, {gb_s:.2f} GB/s")
|
||||||
|
|
||||||
|
# Calculate speedups
|
||||||
|
baseline = benchmark_results["implementations"]["DeepGEMM"]["time_ms"]
|
||||||
|
for name, data in benchmark_results["implementations"].items():
|
||||||
|
if name != "DeepGEMM":
|
||||||
|
speedup = baseline / data["time_ms"]
|
||||||
|
benchmark_results["implementations"][name]["speedup_vs_deepgemm"] = speedup
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
f"DeepGEMM is {1 / speedup:.2f}x "
|
||||||
|
f"{'faster' if 1 / speedup > 1 else 'slower'} than {name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_triton_time = benchmark_results["implementations"]["vLLM Triton"]["time_ms"]
|
||||||
|
vllm_cutlass_time = benchmark_results["implementations"]["vLLM CUTLASS"]["time_ms"]
|
||||||
|
cutlass_vs_triton = vllm_triton_time / vllm_cutlass_time
|
||||||
|
benchmark_results["implementations"]["vLLM CUTLASS"]["speedup_vs_triton"] = (
|
||||||
|
cutlass_vs_triton
|
||||||
|
)
|
||||||
|
if verbose:
|
||||||
|
print(
|
||||||
|
f"vLLM CUTLASS is {cutlass_vs_triton:.2f}x "
|
||||||
|
f"{'faster' if cutlass_vs_triton > 1 else 'slower'} than vLLM Triton"
|
||||||
|
)
|
||||||
|
|
||||||
|
return benchmark_results
|
||||||
|
|
||||||
|
|
||||||
|
def format_table_row(values, widths):
|
||||||
|
"""Format a row with specified column widths."""
|
||||||
|
return "| " + " | ".join(f"{val:{w}}" for val, w in zip(values, widths)) + " |"
|
||||||
|
|
||||||
|
|
||||||
|
def print_table(headers, rows, title=None):
|
||||||
|
"""Print a table with headers and rows."""
|
||||||
|
if title:
|
||||||
|
print(f"\n{title}")
|
||||||
|
|
||||||
|
# Calculate column widths based on headers and data
|
||||||
|
widths = [
|
||||||
|
max(len(str(h)), max(len(str(row[i])) for row in rows))
|
||||||
|
for i, h in enumerate(headers)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create separator line
|
||||||
|
separator = "+-" + "-+-".join("-" * w for w in widths) + "-+"
|
||||||
|
|
||||||
|
# Print table
|
||||||
|
print(separator)
|
||||||
|
print(format_table_row(headers, widths))
|
||||||
|
print(separator)
|
||||||
|
for row in rows:
|
||||||
|
print(format_table_row(row, widths))
|
||||||
|
print(separator)
|
||||||
|
|
||||||
|
|
||||||
|
def format_speedup(value):
|
||||||
|
"""Format speedup value with indicator if it's faster or slower."""
|
||||||
|
return f"{value:.2f}x {'faster' if value > 1.0 else 'slower'}"
|
||||||
|
|
||||||
|
|
||||||
|
def run_benchmarks(verbose: bool = False):
|
||||||
|
"""Run benchmarks for a set of common shapes."""
|
||||||
|
print("===== STARTING FP8 GEMM BENCHMARK =====")
|
||||||
|
|
||||||
|
# Make sure we're using the GPU
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available! Tests require GPU.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Print system information
|
||||||
|
print(f"PyTorch version: {torch.__version__}")
|
||||||
|
print(f"CUDA version: {torch.version.cuda}")
|
||||||
|
print(f"Triton version: {triton.__version__}")
|
||||||
|
print(f"Using device: {torch.cuda.get_device_name()}")
|
||||||
|
|
||||||
|
# Enable TF32 for better performance
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
# Set seeds for reproducibility
|
||||||
|
torch.manual_seed(42)
|
||||||
|
torch.cuda.manual_seed(42)
|
||||||
|
|
||||||
|
# Define benchmark shapes (m, n, k)
|
||||||
|
shapes = [
|
||||||
|
(8, 4096, 7168),
|
||||||
|
(8, 7168, 18432),
|
||||||
|
(8, 18432, 7168),
|
||||||
|
(64, 4096, 7168),
|
||||||
|
(64, 7168, 18432),
|
||||||
|
(64, 18432, 7168),
|
||||||
|
(64, 24576, 1536),
|
||||||
|
(64, 32768, 512),
|
||||||
|
(64, 7168, 16384),
|
||||||
|
(128, 4096, 7168),
|
||||||
|
(128, 7168, 18432),
|
||||||
|
(128, 18432, 7168),
|
||||||
|
(1024, 4096, 7168),
|
||||||
|
(1024, 18432, 7168),
|
||||||
|
(2048, 4096, 7168),
|
||||||
|
(4096, 4096, 7168),
|
||||||
|
]
|
||||||
|
shapes = [
|
||||||
|
# (64, 2112, 7168),
|
||||||
|
(64, 24576, 1536),
|
||||||
|
(64, 32768, 512),
|
||||||
|
(64, 7168, 16384),
|
||||||
|
(64, 4096, 7168),
|
||||||
|
(64, 7168, 2048),
|
||||||
|
# (128, 2112, 7168),
|
||||||
|
(128, 24576, 1536),
|
||||||
|
(128, 32768, 512),
|
||||||
|
(128, 7168, 16384),
|
||||||
|
(128, 4096, 7168),
|
||||||
|
(128, 7168, 2048),
|
||||||
|
# (4096, 2112, 7168),
|
||||||
|
(4096, 24576, 1536),
|
||||||
|
(4096, 32768, 512),
|
||||||
|
(4096, 7168, 16384),
|
||||||
|
(4096, 4096, 7168),
|
||||||
|
(4096, 7168, 2048),
|
||||||
|
]
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for m, n, k in shapes:
|
||||||
|
result = benchmark_shape(m, n, k, verbose=verbose)
|
||||||
|
all_results.append(result)
|
||||||
|
|
||||||
|
# Print results in a nicely formatted table
|
||||||
|
print("\n===== PERFORMANCE COMPARISON =====")
|
||||||
|
|
||||||
|
# Print DeepGEMM table
|
||||||
|
deepgemm_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s"]
|
||||||
|
deepgemm_rows = []
|
||||||
|
for result in all_results:
|
||||||
|
shape = result["shape"]
|
||||||
|
impl_data = result["implementations"]["DeepGEMM"]
|
||||||
|
deepgemm_rows.append(
|
||||||
|
[
|
||||||
|
shape["m"],
|
||||||
|
shape["n"],
|
||||||
|
shape["k"],
|
||||||
|
f"{impl_data['time_us']:.1f}",
|
||||||
|
f"{impl_data['tflops']:.1f}",
|
||||||
|
f"{impl_data['gb_s']:.1f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print_table(deepgemm_headers, deepgemm_rows, title="DeepGEMM Implementation:")
|
||||||
|
|
||||||
|
# Print vLLM Triton table
|
||||||
|
triton_headers = ["m", "n", "k", "Time (μs)", "TFLOPS", "GB/s", "vs DeepGEMM"]
|
||||||
|
triton_rows = []
|
||||||
|
for result in all_results:
|
||||||
|
shape = result["shape"]
|
||||||
|
impl_data = result["implementations"]["vLLM Triton"]
|
||||||
|
speedup = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||||
|
triton_rows.append(
|
||||||
|
[
|
||||||
|
shape["m"],
|
||||||
|
shape["n"],
|
||||||
|
shape["k"],
|
||||||
|
f"{impl_data['time_us']:.1f}",
|
||||||
|
f"{impl_data['tflops']:.1f}",
|
||||||
|
f"{impl_data['gb_s']:.1f}",
|
||||||
|
format_speedup(speedup),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print_table(triton_headers, triton_rows, title="vLLM Triton Implementation:")
|
||||||
|
|
||||||
|
# Print vLLM CUTLASS table
|
||||||
|
cutlass_headers = [
|
||||||
|
"m",
|
||||||
|
"n",
|
||||||
|
"k",
|
||||||
|
"Time (μs)",
|
||||||
|
"TFLOPS",
|
||||||
|
"GB/s",
|
||||||
|
"vs DeepGEMM",
|
||||||
|
"vs Triton",
|
||||||
|
]
|
||||||
|
cutlass_rows = []
|
||||||
|
for result in all_results:
|
||||||
|
shape = result["shape"]
|
||||||
|
impl_data = result["implementations"]["vLLM CUTLASS"]
|
||||||
|
vs_deepgemm = impl_data.get("speedup_vs_deepgemm", 1.0)
|
||||||
|
vs_triton = impl_data.get("speedup_vs_triton", 1.0)
|
||||||
|
cutlass_rows.append(
|
||||||
|
[
|
||||||
|
shape["m"],
|
||||||
|
shape["n"],
|
||||||
|
shape["k"],
|
||||||
|
f"{impl_data['time_us']:.1f}",
|
||||||
|
f"{impl_data['tflops']:.1f}",
|
||||||
|
f"{impl_data['gb_s']:.1f}",
|
||||||
|
format_speedup(vs_deepgemm),
|
||||||
|
format_speedup(vs_triton),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
print_table(cutlass_headers, cutlass_rows, title="vLLM CUTLASS Implementation:")
|
||||||
|
|
||||||
|
# Calculate and print averages
|
||||||
|
print("\n===== AVERAGE PERFORMANCE =====")
|
||||||
|
|
||||||
|
implementations = ["DeepGEMM", "vLLM Triton", "vLLM CUTLASS"]
|
||||||
|
avg_metrics = {
|
||||||
|
impl: {"tflops": 0, "gb_s": 0, "time_ms": 0} for impl in implementations
|
||||||
|
}
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
for impl in implementations:
|
||||||
|
impl_data = result["implementations"][impl]
|
||||||
|
avg_metrics[impl]["tflops"] += impl_data["tflops"]
|
||||||
|
avg_metrics[impl]["gb_s"] += impl_data["gb_s"]
|
||||||
|
avg_metrics[impl]["time_ms"] += impl_data["time_ms"]
|
||||||
|
|
||||||
|
num_shapes = len(all_results)
|
||||||
|
avg_headers = ["Implementation", "Avg TFLOPS", "Avg GB/s", "Avg Time (ms)"]
|
||||||
|
avg_rows = []
|
||||||
|
|
||||||
|
for impl in implementations:
|
||||||
|
avg_tflops = avg_metrics[impl]["tflops"] / num_shapes
|
||||||
|
avg_mem_bw = avg_metrics[impl]["gb_s"] / num_shapes
|
||||||
|
avg_time = avg_metrics[impl]["time_ms"] / num_shapes
|
||||||
|
avg_rows.append(
|
||||||
|
[impl, f"{avg_tflops:.2f}", f"{avg_mem_bw:.2f}", f"{avg_time:.2f}"]
|
||||||
|
)
|
||||||
|
|
||||||
|
print_table(avg_headers, avg_rows)
|
||||||
|
|
||||||
|
# Calculate average speedups
|
||||||
|
avg_speedups = {
|
||||||
|
"DeepGEMM vs vLLM Triton": 0,
|
||||||
|
"DeepGEMM vs vLLM CUTLASS": 0,
|
||||||
|
"vLLM CUTLASS vs vLLM Triton": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
deepgemm_time = result["implementations"]["DeepGEMM"]["time_ms"]
|
||||||
|
vllm_triton_time = result["implementations"]["vLLM Triton"]["time_ms"]
|
||||||
|
vllm_cutlass_time = result["implementations"]["vLLM CUTLASS"]["time_ms"]
|
||||||
|
|
||||||
|
avg_speedups["DeepGEMM vs vLLM Triton"] += vllm_triton_time / deepgemm_time
|
||||||
|
avg_speedups["DeepGEMM vs vLLM CUTLASS"] += vllm_cutlass_time / deepgemm_time
|
||||||
|
avg_speedups["vLLM CUTLASS vs vLLM Triton"] += (
|
||||||
|
vllm_triton_time / vllm_cutlass_time
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n===== AVERAGE SPEEDUPS =====")
|
||||||
|
speedup_headers = ["Comparison", "Speedup"]
|
||||||
|
speedup_rows = []
|
||||||
|
for comparison, total in avg_speedups.items():
|
||||||
|
avg_speedup = total / num_shapes
|
||||||
|
status = "faster" if avg_speedup > 1 else "slower"
|
||||||
|
speedup_rows.append([comparison, f"{avg_speedup:.2f}x {status}"])
|
||||||
|
|
||||||
|
print_table(speedup_headers, speedup_rows)
|
||||||
|
|
||||||
|
# Average accuracy comparison
|
||||||
|
print("\n===== ACCURACY COMPARISON =====")
|
||||||
|
avg_diff = {impl: 0 for impl in implementations}
|
||||||
|
|
||||||
|
for result in all_results:
|
||||||
|
for impl in implementations:
|
||||||
|
avg_diff[impl] += result["implementations"][impl]["diff"]["Reference"]
|
||||||
|
|
||||||
|
diff_headers = ["Implementation", "Avg Diff vs Reference"]
|
||||||
|
diff_rows = []
|
||||||
|
for impl in implementations:
|
||||||
|
diff_rows.append([impl, f"{avg_diff[impl] / num_shapes:.6f}"])
|
||||||
|
|
||||||
|
print_table(diff_headers, diff_rows)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_benchmarks(verbose=False)
|
||||||
64
benchmarks/kernels/graph_machete_bench.py
Normal file
64
benchmarks/kernels/graph_machete_bench.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import math
|
||||||
|
import pickle
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import regex as re
|
||||||
|
import seaborn as sns
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the latency of processing a single batch of "
|
||||||
|
"requests till completion."
|
||||||
|
)
|
||||||
|
parser.add_argument("filename", type=str)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with open(args.filename, "rb") as f:
|
||||||
|
data = pickle.load(f)
|
||||||
|
raw_results: list[TMeasurement] = data["results"]
|
||||||
|
|
||||||
|
results = defaultdict(lambda: list())
|
||||||
|
for v in raw_results:
|
||||||
|
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
|
||||||
|
if result is not None:
|
||||||
|
KN = result.group(1)
|
||||||
|
else:
|
||||||
|
raise Exception("MKN not found")
|
||||||
|
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
|
||||||
|
if result is not None:
|
||||||
|
M = result.group(1)
|
||||||
|
else:
|
||||||
|
raise Exception("MKN not found")
|
||||||
|
|
||||||
|
kernel = v.task_spec.description
|
||||||
|
results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
|
||||||
|
|
||||||
|
rows = int(math.ceil(len(results) / 2))
|
||||||
|
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
|
||||||
|
axs = axs.flatten()
|
||||||
|
for axs_idx, (shape, data) in enumerate(results.items()):
|
||||||
|
plt.sca(axs[axs_idx])
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
sns.lineplot(
|
||||||
|
data=df,
|
||||||
|
x="batch_size",
|
||||||
|
y="median",
|
||||||
|
hue="kernel",
|
||||||
|
style="kernel",
|
||||||
|
markers=True,
|
||||||
|
dashes=False,
|
||||||
|
palette="Dark2",
|
||||||
|
)
|
||||||
|
plt.title(f"Shape: {shape}")
|
||||||
|
plt.ylabel("time (median, s)")
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig("graph_machete_bench.pdf")
|
||||||
1
benchmarks/kernels/requirements.txt
Normal file
1
benchmarks/kernels/requirements.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
pandas
|
||||||
214
benchmarks/kernels/utils.py
Normal file
214
benchmarks/kernels/utils.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from collections.abc import Callable, Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.benchmark as TBenchmark
|
||||||
|
from torch.utils.benchmark import Measurement as TMeasurement
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class CudaGraphBenchParams:
|
||||||
|
num_ops_in_cuda_graph: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class ArgPool:
|
||||||
|
"""
|
||||||
|
When some argument of the benchmarking function is annotated with this type,
|
||||||
|
the benchmarking class (BenchMM) will collapse the argument to a pick a
|
||||||
|
single value from the given list of values, during function invocation.
|
||||||
|
For every invocation during a benchmarking run, it will choose a
|
||||||
|
different value from the list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
values: Iterable[Any]
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return self.values[index]
|
||||||
|
|
||||||
|
|
||||||
|
class Bench:
|
||||||
|
class ArgsIterator:
|
||||||
|
def __init__(self, args_list, kwargs_list):
|
||||||
|
assert len(args_list) == len(kwargs_list)
|
||||||
|
self.args_list = args_list
|
||||||
|
self.kwargs_list = kwargs_list
|
||||||
|
self.n = len(self.args_list)
|
||||||
|
self.idx = 0
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
while True:
|
||||||
|
yield (self.args_list[self.idx], self.kwargs_list[self.idx])
|
||||||
|
self.idx += 1
|
||||||
|
self.idx = self.idx % self.n
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.idx = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def n_args(self):
|
||||||
|
return self.n
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cuda_graph_params: CudaGraphBenchParams | None,
|
||||||
|
label: str,
|
||||||
|
sub_label: str,
|
||||||
|
description: str,
|
||||||
|
fn: Callable,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.cuda_graph_params = cuda_graph_params
|
||||||
|
self.use_cuda_graph = self.cuda_graph_params is not None
|
||||||
|
self.label = label
|
||||||
|
self.sub_label = sub_label
|
||||||
|
self.description = description
|
||||||
|
self.fn = fn
|
||||||
|
|
||||||
|
# Process args
|
||||||
|
self._args = args
|
||||||
|
self._kwargs = kwargs
|
||||||
|
self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
|
||||||
|
self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
|
||||||
|
|
||||||
|
# Cudagraph runner
|
||||||
|
self.g = None
|
||||||
|
if self.use_cuda_graph:
|
||||||
|
self.g = self.get_cuda_graph_runner()
|
||||||
|
|
||||||
|
# benchmark run params
|
||||||
|
self.min_run_time = 1
|
||||||
|
|
||||||
|
def collapse_argpool(self, *args, **kwargs):
|
||||||
|
argpool_args = [arg for arg in args if isinstance(arg, ArgPool)] + [
|
||||||
|
arg for arg in kwargs.values() if isinstance(arg, ArgPool)
|
||||||
|
]
|
||||||
|
if len(argpool_args) == 0:
|
||||||
|
return [args], [kwargs]
|
||||||
|
|
||||||
|
# Make sure all argpools are of the same size
|
||||||
|
argpool_size = len(argpool_args[0].values)
|
||||||
|
assert all([argpool_size == len(arg.values) for arg in argpool_args])
|
||||||
|
|
||||||
|
# create copies of the args
|
||||||
|
args_list = []
|
||||||
|
kwargs_list = []
|
||||||
|
for _ in range(argpool_size):
|
||||||
|
args_list.append(args)
|
||||||
|
kwargs_list.append(kwargs.copy())
|
||||||
|
|
||||||
|
for i in range(argpool_size):
|
||||||
|
# collapse args; Just pick the ith value
|
||||||
|
args_list[i] = tuple(
|
||||||
|
[arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# collapse kwargs
|
||||||
|
kwargs_i = kwargs_list[i]
|
||||||
|
arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
|
||||||
|
for k in arg_pool_keys:
|
||||||
|
# again just pick the ith value
|
||||||
|
kwargs_i[k] = kwargs_i[k][i]
|
||||||
|
kwargs_list[i] = kwargs_i
|
||||||
|
|
||||||
|
return args_list, kwargs_list
|
||||||
|
|
||||||
|
def get_cuda_graph_runner(self):
|
||||||
|
assert self.use_cuda_graph
|
||||||
|
assert self.args_iterator is not None
|
||||||
|
|
||||||
|
num_graph_ops = self.cuda_graph_params.num_ops_in_cuda_graph
|
||||||
|
|
||||||
|
# warmup
|
||||||
|
args_it = self.args_iterator.__next__()
|
||||||
|
for _ in range(2):
|
||||||
|
args, kwargs = next(args_it)
|
||||||
|
self.fn(*args, **kwargs)
|
||||||
|
|
||||||
|
self.args_iterator.reset()
|
||||||
|
args_it = self.args_iterator.__next__()
|
||||||
|
stream = torch.cuda.Stream()
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
g = torch.cuda.CUDAGraph()
|
||||||
|
with torch.cuda.graph(g):
|
||||||
|
for _ in range(num_graph_ops):
|
||||||
|
args, kwargs = next(args_it)
|
||||||
|
self.fn(*args, **kwargs)
|
||||||
|
return g
|
||||||
|
|
||||||
|
def run_cudagrah(self) -> TMeasurement:
|
||||||
|
assert self.use_cuda_graph
|
||||||
|
globals = {"g": self.g}
|
||||||
|
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt="g.replay()",
|
||||||
|
globals=globals,
|
||||||
|
label=(
|
||||||
|
f"{self.label}"
|
||||||
|
f" | cugraph {self.cuda_graph_params.num_ops_in_cuda_graph} ops"
|
||||||
|
),
|
||||||
|
sub_label=self.sub_label,
|
||||||
|
description=self.description,
|
||||||
|
).blocked_autorange(min_run_time=self.min_run_time)
|
||||||
|
|
||||||
|
def run_eager(self) -> TMeasurement:
|
||||||
|
setup = None
|
||||||
|
stmt = None
|
||||||
|
globals = None
|
||||||
|
|
||||||
|
has_arg_pool = self.args_iterator.n_args > 1
|
||||||
|
if has_arg_pool:
|
||||||
|
setup = """
|
||||||
|
args_iterator.reset()
|
||||||
|
args_it = args_iterator.__next__()
|
||||||
|
"""
|
||||||
|
stmt = """
|
||||||
|
args, kwargs = next(args_it)
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
"""
|
||||||
|
globals = {"fn": self.fn, "args_iterator": self.args_iterator}
|
||||||
|
else:
|
||||||
|
# no arg pool. Just use the args and kwargs directly
|
||||||
|
self.args_iterator.reset()
|
||||||
|
args_it = self.args_iterator.__next__()
|
||||||
|
args, kwargs = next(args_it)
|
||||||
|
|
||||||
|
setup = ""
|
||||||
|
stmt = """
|
||||||
|
fn(*args, **kwargs)
|
||||||
|
"""
|
||||||
|
globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
|
||||||
|
|
||||||
|
return TBenchmark.Timer(
|
||||||
|
stmt=stmt,
|
||||||
|
setup=setup,
|
||||||
|
globals=globals,
|
||||||
|
label=self.label,
|
||||||
|
sub_label=self.sub_label,
|
||||||
|
description=self.description,
|
||||||
|
).blocked_autorange(min_run_time=self.min_run_time)
|
||||||
|
|
||||||
|
def run(self) -> TMeasurement:
|
||||||
|
timer = None
|
||||||
|
if self.use_cuda_graph: # noqa SIM108
|
||||||
|
timer = self.run_cudagrah()
|
||||||
|
else:
|
||||||
|
timer = self.run_eager()
|
||||||
|
if not timer.meets_confidence() or timer.has_warnings:
|
||||||
|
print("Doesn't meet confidence - re-running bench ...")
|
||||||
|
return self.run()
|
||||||
|
return timer
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
if exc_type:
|
||||||
|
print(f"exc type {exc_type}")
|
||||||
|
print(f"exc value {exc_value}")
|
||||||
|
print(f"exc traceback {traceback}")
|
||||||
104
benchmarks/kernels/weight_shapes.py
Normal file
104
benchmarks/kernels/weight_shapes.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Weight Shapes are in the format
|
||||||
|
# ([K, N], TP_SPLIT_DIM)
|
||||||
|
# Example:
|
||||||
|
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 14336, N = 4096
|
||||||
|
# - TP2 : K = 7168, N = 4096
|
||||||
|
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
|
||||||
|
# - TP1 : K = 4096, N = 6144
|
||||||
|
# - TP4 : K = 4096, N = 1536
|
||||||
|
|
||||||
|
# TP1 shapes
|
||||||
|
WEIGHT_SHAPES = {
|
||||||
|
"mistralai/Mistral-7B-v0.1": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-7b-hf": [
|
||||||
|
([4096, 12288], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 22016], 1),
|
||||||
|
([11008, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3-8b": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-13b-hf": [
|
||||||
|
([5120, 15360], 1),
|
||||||
|
([5120, 5120], 0),
|
||||||
|
([5120, 27648], 1),
|
||||||
|
([13824, 5120], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-2-70b-hf": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 57344], 1),
|
||||||
|
([28672, 8192], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3.1-405b-hf": [
|
||||||
|
([16384, 18432], 1),
|
||||||
|
([16384, 16384], 0),
|
||||||
|
([16384, 106496], 1),
|
||||||
|
([53248, 16384], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct": [
|
||||||
|
([4096, 6144], 1),
|
||||||
|
([4096, 4096], 0),
|
||||||
|
([4096, 28672], 1),
|
||||||
|
([14336, 4096], 0),
|
||||||
|
],
|
||||||
|
"meta-llama/Llama-3.3-70B-Instruct": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 57344], 1),
|
||||||
|
([28672, 8192], 0),
|
||||||
|
],
|
||||||
|
"mistralai/Mistral-Large-Instruct-2407": [
|
||||||
|
([12288, 14336], 1),
|
||||||
|
([12288, 12288], 0),
|
||||||
|
([12288, 57344], 1),
|
||||||
|
([28672, 12288], 0),
|
||||||
|
],
|
||||||
|
"Qwen/Qwen2.5-7B-Instruct": [
|
||||||
|
([3584, 4608], 1),
|
||||||
|
([3584, 3584], 0),
|
||||||
|
([3584, 37888], 1),
|
||||||
|
([18944, 3584], 0),
|
||||||
|
],
|
||||||
|
"Qwen/Qwen2.5-32B-Instruct": [
|
||||||
|
([5120, 7168], 1),
|
||||||
|
([5120, 5120], 0),
|
||||||
|
([5120, 55296], 1),
|
||||||
|
([27648, 5120], 0),
|
||||||
|
],
|
||||||
|
"Qwen/Qwen2.5-72B-Instruct": [
|
||||||
|
([8192, 10240], 1),
|
||||||
|
([8192, 8192], 0),
|
||||||
|
([8192, 59136], 1),
|
||||||
|
([29568, 8192], 0),
|
||||||
|
],
|
||||||
|
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [
|
||||||
|
([2048, 3072], 1),
|
||||||
|
([2048, 4096], 1),
|
||||||
|
([2048, 2048], 0),
|
||||||
|
([2048, 576], 0),
|
||||||
|
([2048, 21888], 1),
|
||||||
|
([10944, 2048], 0),
|
||||||
|
([2048, 2816], 1),
|
||||||
|
([1408, 2048], 0),
|
||||||
|
],
|
||||||
|
"CohereLabs/c4ai-command-a-03-2025": [
|
||||||
|
([12288, 14336], 1),
|
||||||
|
([12288, 12288], 0),
|
||||||
|
([12288, 73728], 1),
|
||||||
|
([36864, 12288], 0),
|
||||||
|
],
|
||||||
|
}
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
PORT=8000
|
|
||||||
MODEL=$1
|
|
||||||
TOKENS=$2
|
|
||||||
|
|
||||||
docker run --gpus all --shm-size 1g -p $PORT:80 \
|
|
||||||
-v $PWD/data:/data \
|
|
||||||
ghcr.io/huggingface/text-generation-inference:1.4.0 \
|
|
||||||
--model-id $MODEL \
|
|
||||||
--sharded false \
|
|
||||||
--max-input-length 1024 \
|
|
||||||
--max-total-tokens 2048 \
|
|
||||||
--max-best-of 5 \
|
|
||||||
--max-concurrent-requests 5000 \
|
|
||||||
--max-batch-total-tokens $TOKENS
|
|
||||||
178
benchmarks/multi_turn/README.md
Normal file
178
benchmarks/multi_turn/README.md
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
# Benchmark KV Cache Offloading with Multi-Turn Conversations
|
||||||
|
|
||||||
|
The requirements (pip) for `benchmark_serving_multi_turn.py` can be found in `requirements.txt`
|
||||||
|
|
||||||
|
First start serving your model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||||
|
|
||||||
|
vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
|
||||||
|
```
|
||||||
|
|
||||||
|
The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
|
||||||
|
|
||||||
|
## Synthetic Multi-Turn Conversations
|
||||||
|
|
||||||
|
Download the following text file (used for generation of synthetic conversations)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
wget https://www.gutenberg.org/ebooks/1184.txt.utf-8
|
||||||
|
mv 1184.txt.utf-8 pg1184.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
The filename `pg1184.txt` is used in `generate_multi_turn.json` (see `"text_files"`).
|
||||||
|
|
||||||
|
But you may use other text files if you prefer (using this specific file is not required).
|
||||||
|
|
||||||
|
Then run the benchmarking script
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
|
||||||
|
|
||||||
|
python benchmark_serving_multi_turn.py --model $MODEL_PATH --served-model-name Llama \
|
||||||
|
--input-file generate_multi_turn.json --num-clients 2 --max-active-conversations 6
|
||||||
|
```
|
||||||
|
|
||||||
|
You can edit the file `generate_multi_turn.json` to change the conversation parameters (number of turns, etc.).
|
||||||
|
|
||||||
|
If successful, you will see the following output
|
||||||
|
|
||||||
|
```bash
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
Statistics summary:
|
||||||
|
runtime_sec = 215.810
|
||||||
|
requests_per_sec = 0.769
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
count mean std min 25% 50% 75% 90% 99% max
|
||||||
|
ttft_ms 166.0 78.22 67.63 45.91 59.94 62.26 64.43 69.66 353.18 567.54
|
||||||
|
tpot_ms 166.0 25.37 0.57 24.40 25.07 25.31 25.50 25.84 27.50 28.05
|
||||||
|
latency_ms 166.0 2591.07 326.90 1998.53 2341.62 2573.01 2860.10 3003.50 3268.46 3862.94
|
||||||
|
input_num_turns 166.0 7.43 4.57 1.00 3.00 7.00 11.00 13.00 17.00 17.00
|
||||||
|
input_num_tokens 166.0 2006.20 893.56 522.00 1247.75 2019.00 2718.00 3233.00 3736.45 3899.00
|
||||||
|
output_num_tokens 166.0 100.01 11.80 80.00 91.00 99.00 109.75 116.00 120.00 120.00
|
||||||
|
output_num_chunks 166.0 99.01 11.80 79.00 90.00 98.00 108.75 115.00 119.00 119.00
|
||||||
|
----------------------------------------------------------------------------------------------------
|
||||||
|
```
|
||||||
|
|
||||||
|
If you run with `--warmup-step`, the summary will also include `warmup_runtime_sec`
|
||||||
|
and `total_runtime_incl_warmup_sec` (while `runtime_sec` continues to reflect the
|
||||||
|
benchmark-only runtime so the reported throughput stays comparable).
|
||||||
|
|
||||||
|
### JSON configuration file for synthetic conversations generation
|
||||||
|
|
||||||
|
The input flag `--input-file` is used to determine the input conversations for the benchmark.<br/>
|
||||||
|
When the input is a JSON file with the field `"filetype": "generate_conversations"` the tool will generate synthetic multi-turn (questions and answers) conversations.
|
||||||
|
|
||||||
|
The file `generate_multi_turn.json` is an example file.
|
||||||
|
|
||||||
|
The file must contain the sections `prompt_input` and `prompt_output`.
|
||||||
|
|
||||||
|
The `prompt_input` section must contain `num_turns`, `prefix_num_tokens` and `num_tokens`:
|
||||||
|
|
||||||
|
* `num_turns` - Number of total turns in the conversation (both user & assistant).<br/>
|
||||||
|
The final value will always be rounded to an even number so each user turn has a reply.
|
||||||
|
* `prefix_num_tokens` - Tokens added at the start of only the **first user turn** in a conversation (unique per conversation).
|
||||||
|
* `num_tokens` - Total token length of each **user** message (one turn).
|
||||||
|
|
||||||
|
The `prompt_output` section must contain `num_tokens`:
|
||||||
|
|
||||||
|
* `num_tokens` - Total token length of each **assistant** message (one turn).
|
||||||
|
|
||||||
|
### Random distributions for synthetic conversations generation
|
||||||
|
|
||||||
|
When creating an input JSON file (such as `generate_multi_turn.json`),<br/>
|
||||||
|
every numeric field (such as `num_turns` or `num_tokens`) requires a distribution.<br/>
|
||||||
|
The distribution determines how to randomly sample values for the field.
|
||||||
|
|
||||||
|
The available distributions are listed below.
|
||||||
|
|
||||||
|
**Note:** The optional `max` field (for lognormal, zipf, and poisson) can be used to cap sampled values at an upper bound.</br>
|
||||||
|
Can be used to make sure that the total number of tokens in every request does not exceed `--max-model-len`.
|
||||||
|
|
||||||
|
#### constant
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"distribution": "constant",
|
||||||
|
"value": 500
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
* `value` - the fixed integer value (always returns the same number).
|
||||||
|
|
||||||
|
#### uniform
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"distribution": "uniform",
|
||||||
|
"min": 12,
|
||||||
|
"max": 18
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
* `min` - minimum value (inclusive).
|
||||||
|
* `max` - maximum value (inclusive), should be equal or larger than min.
|
||||||
|
|
||||||
|
#### lognormal
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"distribution": "lognormal",
|
||||||
|
"average": 1000,
|
||||||
|
"max": 5000
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can parameterize the lognormal distribution in one of two ways:
|
||||||
|
|
||||||
|
Using the average and optional median ratio:
|
||||||
|
|
||||||
|
* `average` - target average value of the distribution.
|
||||||
|
* `median_ratio` - the ratio of the median to the average; controls the skewness. Must be in the range (0, 1).
|
||||||
|
|
||||||
|
Using the parameters of the underlying normal distribution:
|
||||||
|
|
||||||
|
* `mean` - mean of the underlying normal distribution.
|
||||||
|
* `sigma` - standard deviation of the underlying normal distribution.
|
||||||
|
|
||||||
|
#### zipf
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"distribution": "zipf",
|
||||||
|
"alpha": 1.2,
|
||||||
|
"max": 100
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
* `alpha` - skew parameter (> 1). Larger values produce stronger skew toward smaller integers.
|
||||||
|
|
||||||
|
#### poisson
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"distribution": "poisson",
|
||||||
|
"alpha": 10,
|
||||||
|
"max": 50
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
* `alpha` - expected value (λ). Also the variance of the distribution.
|
||||||
|
|
||||||
|
## ShareGPT Conversations
|
||||||
|
|
||||||
|
To run with the ShareGPT data, download the following ShareGPT dataset:
|
||||||
|
`https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json`
|
||||||
|
|
||||||
|
Use the `convert_sharegpt_to_openai.py` script to convert the dataset to a format supported by `benchmark_serving_multi_turn.py`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python convert_sharegpt_to_openai.py sharegpt_20230401_clean_lang_split.json sharegpt_conv_128.json --seed=99 --max-items=128
|
||||||
|
```
|
||||||
|
|
||||||
|
The script will convert the ShareGPT dataset to a dataset with the standard user/assistant roles.
|
||||||
|
|
||||||
|
The flag `--max-items=128` is used to sample 128 conversations from the original dataset (change as needed).
|
||||||
|
|
||||||
|
Use the output JSON file `sharegpt_conv_128.json` as the `--input-file` for `benchmark_serving_multi_turn.py`.
|
||||||
600
benchmarks/multi_turn/bench_dataset.py
Normal file
600
benchmarks/multi_turn/bench_dataset.py
Normal file
@@ -0,0 +1,600 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from statistics import mean
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
import numpy as np # type: ignore
|
||||||
|
import pandas as pd # type: ignore
|
||||||
|
from bench_utils import (
|
||||||
|
TEXT_SEPARATOR,
|
||||||
|
Color,
|
||||||
|
logger,
|
||||||
|
)
|
||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import AutoTokenizer # type: ignore
|
||||||
|
|
||||||
|
# Conversation ID is a string (e.g: "UzTK34D")
|
||||||
|
ConvId = str
|
||||||
|
|
||||||
|
# A list of dicts (dicts with keys "id" and "messages")
|
||||||
|
ShareGptConversations = list[dict[str, Any]]
|
||||||
|
|
||||||
|
# A list of dicts (dicts with keys "role" and "content")
|
||||||
|
MessagesList = list[dict[str, str]]
|
||||||
|
|
||||||
|
# Map conversation ID to conversation messages
|
||||||
|
ConversationsMap = list[ConvId, MessagesList]
|
||||||
|
|
||||||
|
|
||||||
|
class Distribution(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def sample(self, size: int = 1) -> np.ndarray:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UniformDistribution(Distribution):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_val: int | float,
|
||||||
|
max_val: int | float,
|
||||||
|
is_integer: bool = True,
|
||||||
|
) -> None:
|
||||||
|
self.min_val = min_val
|
||||||
|
self.max_val = max_val
|
||||||
|
self.is_integer = is_integer
|
||||||
|
|
||||||
|
def sample(self, size: int = 1) -> np.ndarray:
|
||||||
|
if self.is_integer:
|
||||||
|
return np.random.randint(
|
||||||
|
int(self.min_val), int(self.max_val + 1), size=size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return np.random.uniform(self.min_val, self.max_val, size=size)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"UniformDistribution[{self.min_val}, {self.max_val}]"
|
||||||
|
|
||||||
|
|
||||||
|
class ConstantDistribution(Distribution):
|
||||||
|
def __init__(self, value: int | float) -> None:
|
||||||
|
self.value = value
|
||||||
|
self.max_val = value
|
||||||
|
|
||||||
|
def sample(self, size: int = 1) -> np.ndarray:
|
||||||
|
return np.full(shape=size, fill_value=self.value)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"Constant[{self.value}]"
|
||||||
|
|
||||||
|
|
||||||
|
class ZipfDistribution(Distribution):
|
||||||
|
def __init__(self, alpha: float, max_val: int | None = None) -> None:
|
||||||
|
self.alpha = alpha
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
def sample(self, size: int = 1) -> np.ndarray:
|
||||||
|
samples = np.random.zipf(self.alpha, size=size)
|
||||||
|
if self.max_val:
|
||||||
|
samples = np.minimum(samples, self.max_val)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"ZipfDistribution[{self.alpha}]"
|
||||||
|
|
||||||
|
|
||||||
|
class PoissonDistribution(Distribution):
|
||||||
|
def __init__(self, alpha: float, max_val: int | None = None) -> None:
|
||||||
|
self.alpha = alpha
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
def sample(self, size: int = 1) -> np.ndarray:
|
||||||
|
samples = np.random.poisson(self.alpha, size=size)
|
||||||
|
if self.max_val:
|
||||||
|
samples = np.minimum(samples, self.max_val)
|
||||||
|
return samples
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"PoissonDistribution[{self.alpha}]"
|
||||||
|
|
||||||
|
|
||||||
|
class LognormalDistribution(Distribution):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
mean: float | None = None,
|
||||||
|
sigma: float | None = None,
|
||||||
|
average: int | None = None,
|
||||||
|
median_ratio: float | None = None,
|
||||||
|
max_val: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.average = average
|
||||||
|
self.median_ratio = median_ratio
|
||||||
|
self.max_val = max_val
|
||||||
|
|
||||||
|
if average is not None:
|
||||||
|
if average < 1:
|
||||||
|
raise ValueError("Lognormal average must be positive")
|
||||||
|
|
||||||
|
if mean or sigma:
|
||||||
|
raise ValueError(
|
||||||
|
"When using lognormal average, you can't provide mean/sigma"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.median_ratio is None:
|
||||||
|
# Default value that provides relatively wide range of values
|
||||||
|
self.median_ratio = 0.85
|
||||||
|
|
||||||
|
# Calculate mean/sigma of np.random.lognormal based on the average
|
||||||
|
mean, sigma = self._generate_lognormal_by_median(
|
||||||
|
target_average=self.average, median_ratio=self.median_ratio
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if mean is None or sigma is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Must provide both mean and sigma if average is not used"
|
||||||
|
)
|
||||||
|
|
||||||
|
if mean <= 0 or sigma < 0:
|
||||||
|
raise ValueError(
|
||||||
|
"Lognormal mean must be positive and sigma must be non-negative"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mean and standard deviation of the underlying normal distribution
|
||||||
|
# Based on numpy.random.lognormal
|
||||||
|
self.mean = mean
|
||||||
|
self.sigma = sigma
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_lognormal_by_median(
|
||||||
|
target_average: int, median_ratio: float
|
||||||
|
) -> tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Compute (mu, sigma) for a lognormal distribution given:
|
||||||
|
- a target average (mean of the distribution)
|
||||||
|
- a ratio of median / mean (controls skewness), assume mean > median
|
||||||
|
|
||||||
|
Background:
|
||||||
|
If Z ~ Normal(mu, sigma^2), then X = exp(Z) ~ LogNormal(mu, sigma).
|
||||||
|
* mean(X) = exp(mu + sigma^2 / 2)
|
||||||
|
* median(X) = exp(mu)
|
||||||
|
|
||||||
|
So:
|
||||||
|
median / mean = exp(mu) / exp(mu + sigma^2 / 2)
|
||||||
|
= exp(-sigma^2 / 2)
|
||||||
|
|
||||||
|
Rearranging:
|
||||||
|
sigma^2 = 2 * ln(mean / median)
|
||||||
|
mu = ln(median)
|
||||||
|
|
||||||
|
This gives a unique (mu, sigma) for any valid mean and median.
|
||||||
|
"""
|
||||||
|
# Check input validity: median must be smaller than mean
|
||||||
|
if median_ratio <= 0 or median_ratio >= 1:
|
||||||
|
raise ValueError("median_ratio must be in range (0, 1)")
|
||||||
|
|
||||||
|
target_median = target_average * median_ratio
|
||||||
|
|
||||||
|
# Solve sigma^2 = 2 * ln(mean / median)
|
||||||
|
sigma = np.sqrt(2 * np.log(target_average / target_median))
|
||||||
|
mu = np.log(target_median)
|
||||||
|
|
||||||
|
return mu, sigma
|
||||||
|
|
||||||
|
def sample(self, size: int = 1) -> np.ndarray:
|
||||||
|
samples = np.random.lognormal(mean=self.mean, sigma=self.sigma, size=size)
|
||||||
|
|
||||||
|
if self.average is not None:
|
||||||
|
# Scale to average
|
||||||
|
samples *= self.average / samples.mean()
|
||||||
|
|
||||||
|
if self.max_val:
|
||||||
|
samples = np.minimum(samples, self.max_val)
|
||||||
|
|
||||||
|
return np.round(samples).astype(int)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
if self.average:
|
||||||
|
return (
|
||||||
|
f"LognormalDistribution[{self.average}, "
|
||||||
|
f"{self.median_ratio}, {self.max_val}]"
|
||||||
|
)
|
||||||
|
return f"LognormalDistribution[{self.mean}, {self.sigma}, {self.max_val}]"
|
||||||
|
|
||||||
|
|
||||||
|
class GenConvArgs(NamedTuple):
|
||||||
|
num_conversations: int
|
||||||
|
text_files: list[str]
|
||||||
|
input_num_turns: Distribution
|
||||||
|
input_common_prefix_num_tokens: Distribution
|
||||||
|
input_prefix_num_tokens: Distribution
|
||||||
|
input_num_tokens: Distribution
|
||||||
|
output_num_tokens: Distribution
|
||||||
|
print_stats: bool
|
||||||
|
|
||||||
|
|
||||||
|
def verify_field_exists(
|
||||||
|
conf: dict, field_name: str, section: str, subsection: str
|
||||||
|
) -> None:
|
||||||
|
if field_name not in conf:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing field '{field_name}' in {section=} and {subsection=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_distribution(
|
||||||
|
conf: dict, section: str, subsection: str, optional: bool = False
|
||||||
|
) -> Distribution:
|
||||||
|
# section can be "prompt_input" or "prompt_output" (both required)
|
||||||
|
conf = conf[section]
|
||||||
|
|
||||||
|
if optional and subsection not in conf:
|
||||||
|
# Optional subsection, if not found assume the value is always 0
|
||||||
|
return ConstantDistribution(0)
|
||||||
|
|
||||||
|
# subsection can be "num_turns", "num_tokens" or "prefix_num_tokens"
|
||||||
|
if subsection not in conf:
|
||||||
|
raise ValueError(f"Missing subsection {subsection} in section {section}")
|
||||||
|
|
||||||
|
conf = conf[subsection]
|
||||||
|
|
||||||
|
distribution = conf.get("distribution")
|
||||||
|
if distribution is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Missing field 'distribution' in {section=} and {subsection=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if distribution == "constant":
|
||||||
|
verify_field_exists(conf, "value", section, subsection)
|
||||||
|
return ConstantDistribution(conf["value"])
|
||||||
|
|
||||||
|
elif distribution == "zipf":
|
||||||
|
verify_field_exists(conf, "alpha", section, subsection)
|
||||||
|
max_val = conf.get("max", None)
|
||||||
|
return ZipfDistribution(conf["alpha"], max_val=max_val)
|
||||||
|
|
||||||
|
elif distribution == "poisson":
|
||||||
|
verify_field_exists(conf, "alpha", section, subsection)
|
||||||
|
max_val = conf.get("max", None)
|
||||||
|
return PoissonDistribution(conf["alpha"], max_val=max_val)
|
||||||
|
|
||||||
|
elif distribution == "lognormal":
|
||||||
|
max_val = conf.get("max", None)
|
||||||
|
|
||||||
|
if "average" in conf:
|
||||||
|
# Infer lognormal mean/sigma (numpy) from input average
|
||||||
|
median_ratio = conf.get("median_ratio", None)
|
||||||
|
return LognormalDistribution(
|
||||||
|
average=conf["average"], median_ratio=median_ratio, max_val=max_val
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use mean/sigma directly (for full control over the distribution)
|
||||||
|
verify_field_exists(conf, "mean", section, subsection)
|
||||||
|
verify_field_exists(conf, "sigma", section, subsection)
|
||||||
|
return LognormalDistribution(
|
||||||
|
mean=conf["mean"], sigma=conf["sigma"], max_val=max_val
|
||||||
|
)
|
||||||
|
|
||||||
|
elif distribution == "uniform":
|
||||||
|
verify_field_exists(conf, "min", section, subsection)
|
||||||
|
verify_field_exists(conf, "max", section, subsection)
|
||||||
|
|
||||||
|
min_value = conf["min"]
|
||||||
|
max_value = conf["max"]
|
||||||
|
|
||||||
|
assert min_value > 0
|
||||||
|
assert min_value <= max_value
|
||||||
|
|
||||||
|
is_integer = isinstance(min_value, int) and isinstance(max_value, int)
|
||||||
|
return UniformDistribution(min_value, max_value, is_integer)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown distribution: {distribution}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_input_json_file(conf: dict) -> GenConvArgs:
|
||||||
|
# Validate the input file
|
||||||
|
assert isinstance(conf, dict)
|
||||||
|
required_fields = [
|
||||||
|
"filetype",
|
||||||
|
"num_conversations",
|
||||||
|
"text_files",
|
||||||
|
"prompt_input",
|
||||||
|
"prompt_output",
|
||||||
|
]
|
||||||
|
for field in required_fields:
|
||||||
|
assert field in conf, f"Missing field {field} in input {conf}"
|
||||||
|
|
||||||
|
assert conf["filetype"] == "generate_conversations"
|
||||||
|
|
||||||
|
assert conf["num_conversations"] > 0, "num_conversations should be larger than zero"
|
||||||
|
|
||||||
|
text_files = conf["text_files"]
|
||||||
|
|
||||||
|
assert isinstance(text_files, list), "Field 'text_files' should be a list"
|
||||||
|
assert len(text_files) > 0, (
|
||||||
|
"Field 'text_files' should be a list with at least one file"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the parameters for the prompt input/output workload
|
||||||
|
input_num_turns = get_random_distribution(conf, "prompt_input", "num_turns")
|
||||||
|
input_num_tokens = get_random_distribution(conf, "prompt_input", "num_tokens")
|
||||||
|
input_common_prefix_num_tokens = get_random_distribution(
|
||||||
|
conf, "prompt_input", "common_prefix_num_tokens", optional=True
|
||||||
|
)
|
||||||
|
input_prefix_num_tokens = get_random_distribution(
|
||||||
|
conf, "prompt_input", "prefix_num_tokens"
|
||||||
|
)
|
||||||
|
output_num_tokens = get_random_distribution(conf, "prompt_output", "num_tokens")
|
||||||
|
|
||||||
|
print_stats: bool = conf.get("print_stats", False)
|
||||||
|
assert isinstance(print_stats, bool), (
|
||||||
|
"Field 'print_stats' should be either 'true' or 'false'"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = GenConvArgs(
|
||||||
|
num_conversations=conf["num_conversations"],
|
||||||
|
text_files=text_files,
|
||||||
|
input_num_turns=input_num_turns,
|
||||||
|
input_common_prefix_num_tokens=input_common_prefix_num_tokens,
|
||||||
|
input_prefix_num_tokens=input_prefix_num_tokens,
|
||||||
|
input_num_tokens=input_num_tokens,
|
||||||
|
output_num_tokens=output_num_tokens,
|
||||||
|
print_stats=print_stats,
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def print_conv_stats(conversations: ConversationsMap, tokenizer: AutoTokenizer) -> None:
|
||||||
|
# Collect statistics
|
||||||
|
conv_stats: list[dict[Any, Any]] = []
|
||||||
|
req_stats: list[int] = []
|
||||||
|
|
||||||
|
print("\nCollecting statistics...")
|
||||||
|
for messages in conversations.values():
|
||||||
|
# messages is a list of dicts
|
||||||
|
user_tokens: list[int] = []
|
||||||
|
assistant_tokens: list[int] = []
|
||||||
|
request_tokens: list[int] = []
|
||||||
|
|
||||||
|
req_tokens = 0
|
||||||
|
for m in messages:
|
||||||
|
content = m["content"]
|
||||||
|
num_tokens = len(tokenizer(content).input_ids)
|
||||||
|
|
||||||
|
if m["role"] == "user":
|
||||||
|
user_tokens.append(num_tokens)
|
||||||
|
# New user prompt including all chat history
|
||||||
|
req_tokens += num_tokens
|
||||||
|
request_tokens.append(req_tokens)
|
||||||
|
|
||||||
|
elif m["role"] == "assistant":
|
||||||
|
assistant_tokens.append(num_tokens)
|
||||||
|
# Update assistant answer
|
||||||
|
# (will be part of chat history for the next user prompt)
|
||||||
|
req_tokens += num_tokens
|
||||||
|
|
||||||
|
item_stats = {
|
||||||
|
"conversation_turns": len(messages),
|
||||||
|
"user_tokens": mean(user_tokens),
|
||||||
|
"assistant_tokens": mean(assistant_tokens),
|
||||||
|
}
|
||||||
|
|
||||||
|
conv_stats.append(item_stats)
|
||||||
|
req_stats.extend(request_tokens)
|
||||||
|
|
||||||
|
# Print statistics
|
||||||
|
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99]
|
||||||
|
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
df = pd.DataFrame(conv_stats)
|
||||||
|
print(df.describe(percentiles=percentiles).transpose())
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
print(f"{Color.YELLOW}Request statistics:{Color.RESET}")
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
df = pd.DataFrame(req_stats, columns=["request_tokens"])
|
||||||
|
print(df.describe(percentiles=percentiles).transpose())
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_conversations(
|
||||||
|
args: GenConvArgs, tokenizer: AutoTokenizer
|
||||||
|
) -> ConversationsMap:
|
||||||
|
# Text for all user prompts
|
||||||
|
# (text from the input text files will be appended to this line)
|
||||||
|
base_prompt_text = "Please rewrite the following text and add more content: "
|
||||||
|
base_prompt_token_count = len(
|
||||||
|
tokenizer.encode(base_prompt_text, add_special_tokens=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"{Color.PURPLE}Generating conversations...{Color.RESET}")
|
||||||
|
logger.info(args)
|
||||||
|
|
||||||
|
list_of_tokens = []
|
||||||
|
|
||||||
|
for filename in args.text_files:
|
||||||
|
# Load text file that will be used to generate prompts
|
||||||
|
with open(filename) as file:
|
||||||
|
data = file.read()
|
||||||
|
tokens_in_file = tokenizer.encode(data, add_special_tokens=False)
|
||||||
|
list_of_tokens.extend(tokens_in_file)
|
||||||
|
logger.info(
|
||||||
|
f"Loaded {len(tokens_in_file)} tokens from file {filename}, "
|
||||||
|
f"total tokens so far: {len(list_of_tokens)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
conversations: ConversationsMap = {}
|
||||||
|
conv_id = 0
|
||||||
|
|
||||||
|
# Generate number of turns for every conversation
|
||||||
|
turn_count: np.ndarray = args.input_num_turns.sample(args.num_conversations)
|
||||||
|
|
||||||
|
# Turn count should be at least 2 (one user prompt and one assistant answer)
|
||||||
|
turn_count = np.maximum(turn_count, 2)
|
||||||
|
|
||||||
|
# Round up to an even number (every user prompt should have an answer)
|
||||||
|
turn_count = turn_count + (turn_count % 2)
|
||||||
|
|
||||||
|
# Generate number of prefix tokens for every conversation
|
||||||
|
conv_prefix_tokens: np.ndarray = args.input_prefix_num_tokens.sample(
|
||||||
|
args.num_conversations
|
||||||
|
)
|
||||||
|
|
||||||
|
# Used to reduce shared text between conversations
|
||||||
|
# (jump/skip over text sections between conversations)
|
||||||
|
base_offset = 0
|
||||||
|
|
||||||
|
# Common prefix size for all conversations (only 1 sample required)
|
||||||
|
common_prefix_text = ""
|
||||||
|
common_prefix_tokens: int = args.input_common_prefix_num_tokens.sample(1)[0]
|
||||||
|
if common_prefix_tokens > 0:
|
||||||
|
# Using "." at the end to separate sentences
|
||||||
|
common_prefix_text = (
|
||||||
|
tokenizer.decode(list_of_tokens[: common_prefix_tokens - 2]) + "."
|
||||||
|
)
|
||||||
|
base_offset += common_prefix_tokens
|
||||||
|
|
||||||
|
for conv_id in tqdm(
|
||||||
|
range(args.num_conversations),
|
||||||
|
total=args.num_conversations,
|
||||||
|
desc="Generating conversations",
|
||||||
|
unit="conv",
|
||||||
|
):
|
||||||
|
# Generate a single conversation
|
||||||
|
messages: MessagesList = []
|
||||||
|
|
||||||
|
nturns = turn_count[conv_id]
|
||||||
|
|
||||||
|
# User prompt token count per turn (with lower limit)
|
||||||
|
input_token_count: np.ndarray = args.input_num_tokens.sample(nturns).astype(int)
|
||||||
|
input_token_count = np.maximum(input_token_count, base_prompt_token_count)
|
||||||
|
|
||||||
|
# Assistant answer token count per turn (with lower limit)
|
||||||
|
output_token_count: np.ndarray = args.output_num_tokens.sample(nturns).astype(
|
||||||
|
int
|
||||||
|
)
|
||||||
|
output_token_count = np.maximum(output_token_count, 1)
|
||||||
|
|
||||||
|
user_turn = True
|
||||||
|
for turn_id in range(nturns):
|
||||||
|
if user_turn:
|
||||||
|
role = "user"
|
||||||
|
num_tokens = input_token_count[turn_id]
|
||||||
|
|
||||||
|
# Generate the user prompt,
|
||||||
|
# use a unique prefix (the conv_id) for each conversation
|
||||||
|
# (to avoid shared prefix between conversations)
|
||||||
|
content = f"{conv_id} is a nice number... "
|
||||||
|
|
||||||
|
if len(common_prefix_text) > 0 and turn_id == 0:
|
||||||
|
content = common_prefix_text + content
|
||||||
|
|
||||||
|
# Update the number of tokens left for the content
|
||||||
|
num_tokens -= len(tokenizer.encode(content, add_special_tokens=False))
|
||||||
|
|
||||||
|
if turn_id == 0:
|
||||||
|
prefix_num_tokens = conv_prefix_tokens[conv_id]
|
||||||
|
if prefix_num_tokens > 0:
|
||||||
|
# Add prefix text (context) to the first turn
|
||||||
|
start_offset = base_offset
|
||||||
|
end_offset = start_offset + prefix_num_tokens
|
||||||
|
assert len(list_of_tokens) > end_offset, (
|
||||||
|
"Not enough input text to generate "
|
||||||
|
f"{prefix_num_tokens} tokens for the "
|
||||||
|
f"prefix text ({start_offset=}, {end_offset=})"
|
||||||
|
)
|
||||||
|
|
||||||
|
content += f"{conv_id}, " + tokenizer.decode(
|
||||||
|
list_of_tokens[start_offset:end_offset]
|
||||||
|
)
|
||||||
|
base_offset += prefix_num_tokens
|
||||||
|
|
||||||
|
# Add the actual user prompt/question after the prefix text
|
||||||
|
content += base_prompt_text
|
||||||
|
num_tokens -= base_prompt_token_count
|
||||||
|
|
||||||
|
if num_tokens > 0:
|
||||||
|
# Add text from the input file (to reach the desired token count)
|
||||||
|
start_offset = base_offset + turn_id * input_token_count.max()
|
||||||
|
end_offset = start_offset + num_tokens
|
||||||
|
assert len(list_of_tokens) > end_offset, (
|
||||||
|
f"Not enough input text to generate {num_tokens} tokens "
|
||||||
|
f"for the prompt ({start_offset=}, {end_offset=})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert tokens back to text
|
||||||
|
content += tokenizer.decode(list_of_tokens[start_offset:end_offset])
|
||||||
|
else:
|
||||||
|
role = "assistant"
|
||||||
|
# This content will not be used as input to the LLM server
|
||||||
|
# (actual answers will be used instead).
|
||||||
|
# Content is only required to determine the min_tokens/max_tokens
|
||||||
|
# (inputs to the LLM server).
|
||||||
|
num_tokens = output_token_count[turn_id]
|
||||||
|
assert len(list_of_tokens) > num_tokens, (
|
||||||
|
f"Not enough input text to generate {num_tokens} "
|
||||||
|
"tokens for assistant content"
|
||||||
|
)
|
||||||
|
content = tokenizer.decode(list_of_tokens[:num_tokens])
|
||||||
|
|
||||||
|
# Append the user/assistant message to the list of messages
|
||||||
|
messages.append({"role": role, "content": content})
|
||||||
|
user_turn = not user_turn
|
||||||
|
|
||||||
|
# Add the new conversation
|
||||||
|
conversations[f"CONV_ID_{conv_id}"] = messages
|
||||||
|
|
||||||
|
# Increase base offset for the next conversation
|
||||||
|
base_offset += nturns
|
||||||
|
|
||||||
|
if args.print_stats:
|
||||||
|
print_conv_stats(conversations, tokenizer)
|
||||||
|
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
|
||||||
|
def conversations_list_to_dict(input_list: ShareGptConversations) -> ConversationsMap:
|
||||||
|
conversations: ConversationsMap = {}
|
||||||
|
|
||||||
|
for item in input_list:
|
||||||
|
conv_id: str = item["id"]
|
||||||
|
assert isinstance(conv_id, str)
|
||||||
|
|
||||||
|
assert conv_id not in conversations, (
|
||||||
|
f"Conversation ID {conv_id} found more than once in the input"
|
||||||
|
)
|
||||||
|
|
||||||
|
messages: MessagesList = item["messages"]
|
||||||
|
assert isinstance(messages, list), (
|
||||||
|
f"Conversation messages should be a list (ID: {conv_id})"
|
||||||
|
)
|
||||||
|
assert len(messages) > 0, f"Conversation with no messages (ID: {conv_id})"
|
||||||
|
|
||||||
|
conversations[conv_id] = messages
|
||||||
|
|
||||||
|
logger.info(f"Using {len(conversations)} unique conversations (IDs)")
|
||||||
|
assert len(conversations) == len(input_list)
|
||||||
|
|
||||||
|
# Print statistics about the selected conversations
|
||||||
|
stats: list[dict[str, Any]] = []
|
||||||
|
for conv_data in conversations.values():
|
||||||
|
stats.append({"num_turns": len(conv_data)})
|
||||||
|
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
print(f"{Color.YELLOW}Conversations statistics:{Color.RESET}")
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
|
||||||
|
conv_stats = pd.DataFrame(stats).describe(percentiles=percentiles)
|
||||||
|
print(conv_stats.transpose())
|
||||||
|
print(TEXT_SEPARATOR)
|
||||||
|
|
||||||
|
return conversations
|
||||||
|
|
||||||
|
|
||||||
|
def conversations_dict_to_list(input_dict: ConversationsMap) -> ShareGptConversations:
|
||||||
|
output: ShareGptConversations = []
|
||||||
|
for conv_id, conv_data in input_dict.items():
|
||||||
|
new_item = {"id": conv_id, "messages": conv_data}
|
||||||
|
output.append(new_item)
|
||||||
|
|
||||||
|
return output
|
||||||
28
benchmarks/multi_turn/bench_utils.py
Normal file
28
benchmarks/multi_turn/bench_utils.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class Color(Enum):
|
||||||
|
RED = "\033[91m"
|
||||||
|
GREEN = "\033[92m"
|
||||||
|
BLUE = "\033[94m"
|
||||||
|
PURPLE = "\033[95m"
|
||||||
|
CYAN = "\033[96m"
|
||||||
|
YELLOW = "\033[93m"
|
||||||
|
RESET = "\033[0m"
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
TEXT_SEPARATOR = "-" * 100
|
||||||
|
|
||||||
|
# Configure the logger
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] - %(message)s",
|
||||||
|
datefmt="%d-%m-%Y %H:%M:%S",
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
1666
benchmarks/multi_turn/benchmark_serving_multi_turn.py
Normal file
1666
benchmarks/multi_turn/benchmark_serving_multi_turn.py
Normal file
File diff suppressed because it is too large
Load Diff
354
benchmarks/multi_turn/convert_sharegpt_to_openai.py
Normal file
354
benchmarks/multi_turn/convert_sharegpt_to_openai.py
Normal file
@@ -0,0 +1,354 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Download dataset from:
|
||||||
|
https://huggingface.co/datasets/philschmid/sharegpt-raw/blob/main/sharegpt_20230401_clean_lang_split.json
|
||||||
|
|
||||||
|
Convert to OpenAI API:
|
||||||
|
export INPUT_FILE=sharegpt_20230401_clean_lang_split.json
|
||||||
|
python convert_sharegpt_to_openai.py $INPUT_FILE sharegpt_conv_128.json --max-items=128
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from statistics import mean
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pandas as pd # type: ignore
|
||||||
|
import tqdm # type: ignore
|
||||||
|
from transformers import AutoTokenizer # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
def has_non_english_chars(text: str) -> bool:
|
||||||
|
return not text.isascii()
|
||||||
|
|
||||||
|
|
||||||
|
def content_is_valid(
|
||||||
|
content: str, min_content_len: int | None, max_content_len: int | None
|
||||||
|
) -> bool:
|
||||||
|
if min_content_len and len(content) < min_content_len:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if max_content_len and len(content) > max_content_len:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return has_non_english_chars(content)
|
||||||
|
|
||||||
|
|
||||||
|
def print_stats(
|
||||||
|
conversations: "list[dict[Any, Any]]", tokenizer: AutoTokenizer | None = None
|
||||||
|
) -> None:
|
||||||
|
# Collect statistics
|
||||||
|
stats = []
|
||||||
|
|
||||||
|
print("\nCollecting statistics...")
|
||||||
|
for item in tqdm.tqdm(conversations):
|
||||||
|
# item has "id" and "messages"
|
||||||
|
messages = item["messages"]
|
||||||
|
|
||||||
|
user_turns = 0
|
||||||
|
assistant_turns = 0
|
||||||
|
user_words = 0
|
||||||
|
assistant_words = 0
|
||||||
|
conv_chars = 0
|
||||||
|
|
||||||
|
user_tokens: list[int] = []
|
||||||
|
assistant_tokens: list[int] = []
|
||||||
|
|
||||||
|
for m in messages:
|
||||||
|
content = m["content"]
|
||||||
|
conv_chars += len(content)
|
||||||
|
content_num_words = content.count(" ") + 1
|
||||||
|
|
||||||
|
num_tokens = 0
|
||||||
|
if tokenizer:
|
||||||
|
num_tokens = len(tokenizer(m["content"]).input_ids)
|
||||||
|
|
||||||
|
if m["role"] == "user":
|
||||||
|
user_turns += 1
|
||||||
|
user_words += content_num_words
|
||||||
|
if tokenizer:
|
||||||
|
user_tokens.append(num_tokens)
|
||||||
|
|
||||||
|
elif m["role"] == "assistant":
|
||||||
|
assistant_turns += 1
|
||||||
|
assistant_words += content_num_words
|
||||||
|
if tokenizer:
|
||||||
|
assistant_tokens.append(num_tokens)
|
||||||
|
|
||||||
|
# assert user_turns == assistant_turns, \
|
||||||
|
# f"Invalid conversation ID {item['id']}"
|
||||||
|
|
||||||
|
conv_words = user_words + assistant_words
|
||||||
|
item_stats = {
|
||||||
|
"user_turns": user_turns,
|
||||||
|
"assistant_turns": assistant_turns,
|
||||||
|
"user_words": user_words,
|
||||||
|
"assistant_words": assistant_words,
|
||||||
|
"conv_turns": len(messages),
|
||||||
|
"conv_words": conv_words,
|
||||||
|
"conv_characters": conv_chars,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(user_tokens) > 0:
|
||||||
|
item_stats["user_tokens"] = int(mean(user_tokens))
|
||||||
|
|
||||||
|
if len(assistant_tokens) > 0:
|
||||||
|
item_stats["assistant_tokens"] = int(mean(assistant_tokens))
|
||||||
|
|
||||||
|
stats.append(item_stats)
|
||||||
|
|
||||||
|
print("\nStatistics:")
|
||||||
|
percentiles = [0.25, 0.5, 0.75, 0.9, 0.99, 0.999, 0.9999]
|
||||||
|
df = pd.DataFrame(stats)
|
||||||
|
print(df.describe(percentiles=percentiles).transpose())
|
||||||
|
|
||||||
|
|
||||||
|
def convert_sharegpt_to_openai(
|
||||||
|
seed: int,
|
||||||
|
input_file: str,
|
||||||
|
output_file: str,
|
||||||
|
max_items: int | None,
|
||||||
|
min_content_len: int | None = None,
|
||||||
|
max_content_len: int | None = None,
|
||||||
|
min_turns: int | None = None,
|
||||||
|
max_turns: int | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
if min_turns and max_turns:
|
||||||
|
assert min_turns <= max_turns
|
||||||
|
|
||||||
|
if min_content_len and max_content_len:
|
||||||
|
# Verify that min is not larger than max if both were given
|
||||||
|
assert min_content_len <= max_content_len
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"Input parameters:\n{seed=}, {max_items=}, {min_content_len=},"
|
||||||
|
f" {max_content_len=}, {min_turns=}, {max_turns=}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
tokenizer = None
|
||||||
|
if model is not None:
|
||||||
|
print(f"Loading tokenizer from: {model}")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||||
|
|
||||||
|
# Read the ShareGPT JSON file
|
||||||
|
print(f"Reading file: {input_file}")
|
||||||
|
with open(input_file, encoding="utf-8") as f:
|
||||||
|
# Should be a list of dicts
|
||||||
|
# Each dict should have "id" (string) and "conversations" (list of dicts)
|
||||||
|
sharegpt_data = json.load(f)
|
||||||
|
|
||||||
|
assert isinstance(sharegpt_data, list), "Input file should contain a list of dicts"
|
||||||
|
|
||||||
|
print(f"Total items in input file: {len(sharegpt_data):,}")
|
||||||
|
|
||||||
|
print(f"Shuffling dataset with seed {seed}")
|
||||||
|
random.shuffle(sharegpt_data)
|
||||||
|
|
||||||
|
# Map conversation ID to the all the messages
|
||||||
|
conversation_parts: dict[str, list[Any]] = {}
|
||||||
|
|
||||||
|
for item in tqdm.tqdm(sharegpt_data):
|
||||||
|
assert "id" in item, "Missing key 'id'"
|
||||||
|
assert "conversations" in item, "Missing key 'conversations'"
|
||||||
|
|
||||||
|
# Conversation ID (e.g: "hiWPlMD") and part/session (0, 1, 2, etc.)
|
||||||
|
conv_id, _ = item["id"].split("_")
|
||||||
|
new_turns = item["conversations"]
|
||||||
|
|
||||||
|
if conv_id not in conversation_parts:
|
||||||
|
# Start new conversation
|
||||||
|
conversation_parts[conv_id] = []
|
||||||
|
elif len(conversation_parts[conv_id]) > 0 and len(new_turns) > 0:
|
||||||
|
prev_turns = conversation_parts[conv_id][-1]
|
||||||
|
if prev_turns[-1]["from"] == new_turns[0]["from"]:
|
||||||
|
new_turns = new_turns[1:]
|
||||||
|
|
||||||
|
if len(new_turns) > 0:
|
||||||
|
# We assume that parts are in order in the ShareGPT dataset
|
||||||
|
conversation_parts[conv_id].append(new_turns)
|
||||||
|
|
||||||
|
dataset: list[dict[str, Any]] = []
|
||||||
|
for conv_id, conv_parts in conversation_parts.items():
|
||||||
|
new_item = {"id": conv_id}
|
||||||
|
|
||||||
|
conversations: list[dict[str, str]] = []
|
||||||
|
|
||||||
|
# Merge all parts
|
||||||
|
for conv_part in conv_parts:
|
||||||
|
conversations.extend(conv_part)
|
||||||
|
|
||||||
|
if len(conversations) > 0:
|
||||||
|
new_item["conversations"] = conversations
|
||||||
|
dataset.append(new_item)
|
||||||
|
|
||||||
|
print(f"Total unique conversations (IDs) in input file: {len(dataset):,}")
|
||||||
|
|
||||||
|
# Final output data
|
||||||
|
final_openai_dataset: list[dict] = []
|
||||||
|
|
||||||
|
# Filter conversations from the ShareGPT dataset and convert to OpenAI format
|
||||||
|
for item in tqdm.tqdm(dataset):
|
||||||
|
messages: list[dict] = []
|
||||||
|
|
||||||
|
assert "id" in item, "Missing key 'id'"
|
||||||
|
assert "conversations" in item, "Missing key 'conversations'"
|
||||||
|
|
||||||
|
conv_id = item["id"]
|
||||||
|
conversations = item["conversations"]
|
||||||
|
|
||||||
|
if min_turns is not None and len(conversations) < min_turns:
|
||||||
|
# Skip short conversations
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Convert each message in the conversation, up to max_turns if specified
|
||||||
|
for i, turn in enumerate(conversations):
|
||||||
|
assert "from" in turn and "value" in turn, (
|
||||||
|
f"Invalid conversation ID {conv_id} - missing 'from' or 'value'"
|
||||||
|
)
|
||||||
|
|
||||||
|
role = None
|
||||||
|
turn_from = turn["from"]
|
||||||
|
|
||||||
|
if turn_from in {"human", "user"}:
|
||||||
|
role = "user"
|
||||||
|
elif turn_from in {"gpt", "bing", "chatgpt", "bard"}:
|
||||||
|
role = "assistant"
|
||||||
|
elif turn_from == "system":
|
||||||
|
role = "system"
|
||||||
|
|
||||||
|
assert role is not None, (
|
||||||
|
f"Invalid conversation ID {conv_id} - 'from'='{turn_from}' is invalid"
|
||||||
|
)
|
||||||
|
|
||||||
|
if i == 0 and role != "user":
|
||||||
|
# If the first message is from assistant (gpt), skip it.
|
||||||
|
# this happens when the conversation is a follow-up
|
||||||
|
# to a previous conversation (from the same user).
|
||||||
|
continue
|
||||||
|
|
||||||
|
if max_turns is not None and i >= max_turns:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Convert message to OpenAI format (with "role" and "content")
|
||||||
|
content = turn["value"]
|
||||||
|
messages.append({"role": role, "content": content})
|
||||||
|
|
||||||
|
# Add the converted conversation to the OpenAI format
|
||||||
|
if len(messages) > 0:
|
||||||
|
valid_messages = True
|
||||||
|
|
||||||
|
# First turn should always be from the user
|
||||||
|
user_turn = True
|
||||||
|
|
||||||
|
for m in messages:
|
||||||
|
# Make sure that turns alternate between user and assistant
|
||||||
|
if (user_turn and m["role"] != "user") or (
|
||||||
|
not user_turn and m["role"] != "assistant"
|
||||||
|
):
|
||||||
|
valid_messages = False
|
||||||
|
break
|
||||||
|
|
||||||
|
user_turn = not user_turn
|
||||||
|
|
||||||
|
content = m["content"]
|
||||||
|
valid_messages = content_is_valid(
|
||||||
|
content, min_content_len, max_content_len
|
||||||
|
)
|
||||||
|
if not valid_messages:
|
||||||
|
break
|
||||||
|
|
||||||
|
if valid_messages is True:
|
||||||
|
final_openai_dataset.append({"id": conv_id, "messages": messages})
|
||||||
|
|
||||||
|
assert len(final_openai_dataset) > 0, "Final number of conversations is zero"
|
||||||
|
|
||||||
|
print_stats(final_openai_dataset)
|
||||||
|
|
||||||
|
print_stats_again = False
|
||||||
|
if max_items is not None and len(final_openai_dataset) > max_items:
|
||||||
|
print(f"\n\nSampling {max_items} items from the dataset...")
|
||||||
|
print_stats_again = True
|
||||||
|
final_openai_dataset = random.sample(final_openai_dataset, max_items)
|
||||||
|
|
||||||
|
if print_stats_again:
|
||||||
|
# Print stats after the dataset changed
|
||||||
|
print_stats(final_openai_dataset, tokenizer)
|
||||||
|
|
||||||
|
# Write the converted data to a new JSON file
|
||||||
|
final_size = len(final_openai_dataset)
|
||||||
|
print(f"\nTotal conversations converted (after filtering): {final_size:,}")
|
||||||
|
print(f"\nWriting file: {output_file}")
|
||||||
|
with open(output_file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(final_openai_dataset, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Convert ShareGPT dataset to OpenAI API format"
|
||||||
|
)
|
||||||
|
parser.add_argument("input_file", help="Path to the input ShareGPT JSON file")
|
||||||
|
parser.add_argument(
|
||||||
|
"output_file", help="Path to the output OpenAI format JSON file"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed", type=int, default=0, help="Seed for random number generators"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-items",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum number of items in the output file",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-turns",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Minimum number of turns per conversation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-turns",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Maximum number of turns per conversation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-content-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Min number of characters in the messages' content",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-content-len",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Max number of characters in the messages' content",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="LLM model, only the tokenizer will be used",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
convert_sharegpt_to_openai(
|
||||||
|
args.seed,
|
||||||
|
args.input_file,
|
||||||
|
args.output_file,
|
||||||
|
args.max_items,
|
||||||
|
args.min_content_len,
|
||||||
|
args.max_content_len,
|
||||||
|
args.min_turns,
|
||||||
|
args.max_turns,
|
||||||
|
args.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
34
benchmarks/multi_turn/generate_multi_turn.json
Normal file
34
benchmarks/multi_turn/generate_multi_turn.json
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
{
|
||||||
|
"filetype": "generate_conversations",
|
||||||
|
"num_conversations": 24,
|
||||||
|
"text_files": ["pg1184.txt"],
|
||||||
|
"print_stats": false,
|
||||||
|
"prompt_input": {
|
||||||
|
"num_turns": {
|
||||||
|
"distribution": "uniform",
|
||||||
|
"min": 12,
|
||||||
|
"max": 18
|
||||||
|
},
|
||||||
|
"common_prefix_num_tokens": {
|
||||||
|
"distribution": "constant",
|
||||||
|
"value": 500
|
||||||
|
},
|
||||||
|
"prefix_num_tokens": {
|
||||||
|
"distribution": "lognormal",
|
||||||
|
"average": 1000,
|
||||||
|
"max": 5000
|
||||||
|
},
|
||||||
|
"num_tokens": {
|
||||||
|
"distribution": "uniform",
|
||||||
|
"min": 120,
|
||||||
|
"max": 160
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"prompt_output": {
|
||||||
|
"num_tokens": {
|
||||||
|
"distribution": "uniform",
|
||||||
|
"min": 80,
|
||||||
|
"max": 120
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
6
benchmarks/multi_turn/requirements.txt
Normal file
6
benchmarks/multi_turn/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
numpy>=1.24
|
||||||
|
pandas>=2.0.0
|
||||||
|
aiohttp>=3.10
|
||||||
|
transformers>=4.46
|
||||||
|
xlsxwriter>=3.2.1
|
||||||
|
tqdm>=4.66
|
||||||
64
benchmarks/overheads/benchmark_hashing.py
Normal file
64
benchmarks/overheads/benchmark_hashing.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import cProfile
|
||||||
|
import pstats
|
||||||
|
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||||
|
|
||||||
|
# A very long prompt, total number of tokens is about 15k.
|
||||||
|
LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
|
||||||
|
LONG_PROMPT = " ".join(LONG_PROMPT)
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
llm = LLM(
|
||||||
|
model=args.model,
|
||||||
|
enforce_eager=True,
|
||||||
|
enable_prefix_caching=True,
|
||||||
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
|
||||||
|
profiler = cProfile.Profile()
|
||||||
|
|
||||||
|
print("------warm up------")
|
||||||
|
for i in range(3):
|
||||||
|
output = llm.generate(LONG_PROMPT, sampling_params)
|
||||||
|
print(output[0].outputs[0].text)
|
||||||
|
|
||||||
|
print("------start generating------")
|
||||||
|
for i in range(3):
|
||||||
|
profiler.runctx(
|
||||||
|
"llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
|
||||||
|
)
|
||||||
|
|
||||||
|
# analyze the runtime of hashing function
|
||||||
|
stats = pstats.Stats(profiler)
|
||||||
|
stats.sort_stats("cumulative")
|
||||||
|
total_time = 0
|
||||||
|
total_calls = 0
|
||||||
|
for func in stats.stats:
|
||||||
|
if "hash_of_block" in func[2]:
|
||||||
|
total_time = stats.stats[func][3]
|
||||||
|
total_calls = stats.stats[func][0]
|
||||||
|
percentage = (total_time / stats.total_tt) * 100
|
||||||
|
print(
|
||||||
|
f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = FlexibleArgumentParser(
|
||||||
|
description="Benchmark the performance of hashing function in"
|
||||||
|
"automatic prefix caching."
|
||||||
|
)
|
||||||
|
parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
|
||||||
|
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||||
|
parser.add_argument("--output-len", type=int, default=10)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-prefix-caching", action="store_true", help="enable prefix caching"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args)
|
||||||
129
benchmarks/run_structured_output_benchmark.sh
Executable file
129
benchmarks/run_structured_output_benchmark.sh
Executable file
@@ -0,0 +1,129 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# default values
|
||||||
|
MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"}
|
||||||
|
BACKEND=${BACKEND:-"vllm"}
|
||||||
|
DATASET=${DATASET:-"xgrammar_bench"}
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"}
|
||||||
|
PORT=${PORT:-8000}
|
||||||
|
STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1}
|
||||||
|
TOTAL_SECONDS=${TOTAL_SECONDS:-90}
|
||||||
|
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300}
|
||||||
|
TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"}
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
echo "Usage: $0 [options]"
|
||||||
|
echo "Options:"
|
||||||
|
echo " --model MODEL Model to benchmark (default: $MODEL)"
|
||||||
|
echo " --backend BACKEND Backend to use (default: $BACKEND)"
|
||||||
|
echo " --dataset DATASET Dataset to use (default: $DATASET)"
|
||||||
|
echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)"
|
||||||
|
echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)"
|
||||||
|
echo " --port PORT Port to use (default: $PORT)"
|
||||||
|
echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)"
|
||||||
|
echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)"
|
||||||
|
echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)"
|
||||||
|
echo " -h, --help Show this help message and exit"
|
||||||
|
exit 0
|
||||||
|
}
|
||||||
|
|
||||||
|
# parse command line arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
--model)
|
||||||
|
MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--backend)
|
||||||
|
BACKEND="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--dataset)
|
||||||
|
DATASET="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--max-new-tokens)
|
||||||
|
MAX_NEW_TOKENS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--output-dir)
|
||||||
|
OUTPUT_DIR="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--port)
|
||||||
|
PORT="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--structured-output-ratio)
|
||||||
|
STRUCTURED_OUTPUT_RATIO="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--tokenizer-mode)
|
||||||
|
TOKENIZER_MODE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--total-seconds)
|
||||||
|
TOTAL_SECONDS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
usage
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown argument: $1\n"
|
||||||
|
usage
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Create output directory if it doesn't exist
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
# Define QPS values to test
|
||||||
|
QPS_VALUES=(25 20 15 10 5 1)
|
||||||
|
|
||||||
|
# Common parameters
|
||||||
|
COMMON_PARAMS="--backend $BACKEND \
|
||||||
|
--model $MODEL \
|
||||||
|
--dataset $DATASET \
|
||||||
|
--structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
|
||||||
|
--save-results \
|
||||||
|
--result-dir $OUTPUT_DIR \
|
||||||
|
--output-len $MAX_NEW_TOKENS \
|
||||||
|
--port $PORT \
|
||||||
|
--tokenizer-mode $TOKENIZER_MODE"
|
||||||
|
|
||||||
|
echo "Starting structured output benchmark with model: $MODEL"
|
||||||
|
echo "Backend: $BACKEND"
|
||||||
|
echo "Dataset: $DATASET"
|
||||||
|
echo "Results will be saved to: $OUTPUT_DIR"
|
||||||
|
echo "----------------------------------------"
|
||||||
|
|
||||||
|
# Run benchmarks with different QPS values
|
||||||
|
for qps in "${QPS_VALUES[@]}"; do
|
||||||
|
echo "Running benchmark with QPS: $qps"
|
||||||
|
|
||||||
|
# Get git hash and branch for the filename
|
||||||
|
GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
|
||||||
|
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
|
||||||
|
|
||||||
|
# Construct filename for this run
|
||||||
|
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
|
||||||
|
|
||||||
|
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
|
||||||
|
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
|
||||||
|
echo "Running benchmark with $NUM_PROMPTS prompts"
|
||||||
|
|
||||||
|
# Run the benchmark
|
||||||
|
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
|
||||||
|
--request-rate $qps \
|
||||||
|
--result-filename "$FILENAME" \
|
||||||
|
--num-prompts $NUM_PROMPTS
|
||||||
|
|
||||||
|
echo "Completed benchmark with QPS: $qps"
|
||||||
|
echo "----------------------------------------"
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "All benchmarks completed!"
|
||||||
|
echo "Results saved to: $OUTPUT_DIR"
|
||||||
19
benchmarks/structured_schemas/structured_schema_1.json
Normal file
19
benchmarks/structured_schemas/structured_schema_1.json
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string" },
|
||||||
|
"email": { "type": "string" },
|
||||||
|
"street": { "type": "string" },
|
||||||
|
"city": { "type": "string" },
|
||||||
|
"state": { "type": "string" },
|
||||||
|
"zip": { "type": "string" },
|
||||||
|
"phone": { "type": "string" },
|
||||||
|
"website": { "type": "string" },
|
||||||
|
"company": { "type": "string" },
|
||||||
|
"age": { "type": "integer" }
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"name",
|
||||||
|
"email"
|
||||||
|
]
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user